refactor CRUD utils
This commit is contained in:
parent
b42a5b44f3
commit
8194027c36
6 changed files with 72 additions and 58 deletions
|
@ -0,0 +1 @@
|
||||||
|
from .connection import Connection
|
|
@ -1,31 +0,0 @@
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
from passlib.context import CryptContext
|
|
||||||
|
|
||||||
from . import models, schemas
|
|
||||||
|
|
||||||
|
|
||||||
def get_user(db: Session, name: str):
|
|
||||||
return (db
|
|
||||||
.query(models.User)
|
|
||||||
.filter(models.User.name == name).first())
|
|
||||||
|
|
||||||
|
|
||||||
def create_user(
|
|
||||||
db: Session,
|
|
||||||
user: schemas.UserCreate,
|
|
||||||
capabilities: list[str],
|
|
||||||
crypt_context: CryptContext,
|
|
||||||
):
|
|
||||||
db_user = models.User(
|
|
||||||
name=user.name,
|
|
||||||
password=crypt_context.hash(user.password),
|
|
||||||
capabilities=[
|
|
||||||
models.UserCapability(capability=capability)
|
|
||||||
for capability in capabilities
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
db.add(db_user)
|
|
||||||
db.commit()
|
|
||||||
db.refresh(db_user)
|
|
||||||
return db_user
|
|
|
@ -1,5 +1,12 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pydantic import BaseModel
|
|
||||||
|
from passlib.context import CryptContext
|
||||||
|
from pydantic import BaseModel, validator
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from . import models
|
||||||
|
|
||||||
|
|
||||||
class CertificateBase(BaseModel):
|
class CertificateBase(BaseModel):
|
||||||
|
@ -20,6 +27,15 @@ class Certificate(CertificateBase):
|
||||||
|
|
||||||
class UserBase(BaseModel):
|
class UserBase(BaseModel):
|
||||||
name: str
|
name: str
|
||||||
|
capabilities: list[str]
|
||||||
|
|
||||||
|
@validator("capabilities", pre=True)
|
||||||
|
@classmethod
|
||||||
|
def caps_from_orm(cls, value: list[models.UserCapability]) -> list[str]:
|
||||||
|
return [
|
||||||
|
capability.capability
|
||||||
|
for capability in value
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class UserCreate(UserBase):
|
class UserCreate(UserBase):
|
||||||
|
@ -27,12 +43,46 @@ class UserCreate(UserBase):
|
||||||
|
|
||||||
|
|
||||||
class User(UserBase):
|
class User(UserBase):
|
||||||
capabilities: list[str]
|
|
||||||
certificates: list[Certificate]
|
certificates: list[Certificate]
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
orm_mode = True
|
orm_mode = True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get(
|
||||||
|
cls,
|
||||||
|
db: Session,
|
||||||
|
name: str,
|
||||||
|
) -> User:
|
||||||
|
user = (db
|
||||||
|
.query(models.User)
|
||||||
|
.filter(models.User.name == name)
|
||||||
|
.first())
|
||||||
|
|
||||||
|
return cls.from_orm(user)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create(
|
||||||
|
cls,
|
||||||
|
db: Session,
|
||||||
|
user: UserCreate,
|
||||||
|
crypt_context: CryptContext,
|
||||||
|
) -> User:
|
||||||
|
user = models.User(
|
||||||
|
name=user.name,
|
||||||
|
password=crypt_context.hash(user.password),
|
||||||
|
capabilities=[
|
||||||
|
models.UserCapability(capability=capability)
|
||||||
|
for capability in user.capabilities
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
db.add(user)
|
||||||
|
db.commit()
|
||||||
|
db.refresh(user)
|
||||||
|
|
||||||
|
return cls.from_orm(user)
|
||||||
|
|
||||||
|
|
||||||
class DistinguishedNameBase(BaseModel):
|
class DistinguishedNameBase(BaseModel):
|
||||||
cn_only: bool
|
cn_only: bool
|
||||||
|
|
|
@ -4,7 +4,7 @@ import uvicorn
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
|
|
||||||
from .config import Config, Settings
|
from .config import Config, Settings
|
||||||
from .db.connection import Connection
|
from .db import Connection, schemas
|
||||||
from .routers import admin, user
|
from .routers import admin, user
|
||||||
|
|
||||||
settings = Settings.get()
|
settings = Settings.get()
|
||||||
|
@ -29,26 +29,21 @@ api = FastAPI(
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
app.mount("/api", api)
|
app.mount("/api", api)
|
||||||
|
|
||||||
|
api.include_router(admin.router)
|
||||||
|
api.include_router(user.router)
|
||||||
|
|
||||||
|
|
||||||
@app.on_event("startup")
|
@app.on_event("startup")
|
||||||
async def on_startup():
|
async def on_startup() -> None:
|
||||||
# always include admin router
|
|
||||||
api.include_router(admin.router)
|
|
||||||
|
|
||||||
if (current_config := await Config.get()) is not None:
|
if (current_config := await Config.get()) is not None:
|
||||||
Connection.connect(current_config.db_engine)
|
Connection.connect(current_config.db_engine)
|
||||||
|
|
||||||
# async for db in connection.get():
|
async for db in Connection.get():
|
||||||
# user = crud.get_user(db, "admin")
|
user = schemas.User.get(db, "admin")
|
||||||
# print(user.name)
|
print(str(user))
|
||||||
# for cap in user.capabilities:
|
|
||||||
# print(cap.capability)
|
|
||||||
|
|
||||||
# include other routers
|
|
||||||
api.include_router(user.router)
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main() -> None:
|
||||||
uvicorn.run(
|
uvicorn.run(
|
||||||
"kiwi_vpn_api.main:app",
|
"kiwi_vpn_api.main:app",
|
||||||
host="0.0.0.0",
|
host="0.0.0.0",
|
||||||
|
|
|
@ -4,8 +4,7 @@ from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from ..config import Config
|
from ..config import Config
|
||||||
from ..db import crud, schemas
|
from ..db import Connection, schemas
|
||||||
from ..db.connection import Connection
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/admin")
|
router = APIRouter(prefix="/admin")
|
||||||
|
|
||||||
|
@ -32,9 +31,8 @@ async def set_config(
|
||||||
if new_config.jwt.secret is None:
|
if new_config.jwt.secret is None:
|
||||||
new_config.jwt.secret = token_hex(32)
|
new_config.jwt.secret = token_hex(32)
|
||||||
|
|
||||||
Connection.connect(new_config.db_engine)
|
|
||||||
|
|
||||||
Config.set(new_config)
|
Config.set(new_config)
|
||||||
|
Connection.connect(new_config.db_engine)
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
|
@ -58,12 +56,12 @@ async def add_user(
|
||||||
if current_config is None:
|
if current_config is None:
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
|
||||||
|
|
||||||
crud.create_user(
|
schemas.User.create(
|
||||||
db=db,
|
db=db,
|
||||||
user=schemas.UserCreate(
|
user=schemas.UserCreate(
|
||||||
name=user_name,
|
name=user_name,
|
||||||
password=user_password,
|
password=user_password,
|
||||||
|
capabilities=["admin"],
|
||||||
),
|
),
|
||||||
capabilities=["admin"],
|
|
||||||
crypt_context=current_config.crypt_context,
|
crypt_context=current_config.crypt_context,
|
||||||
)
|
)
|
||||||
|
|
|
@ -7,11 +7,12 @@ from pydantic import BaseModel
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from ..config import Config
|
from ..config import Config
|
||||||
from ..db import crud
|
from ..db import Connection, schemas
|
||||||
from ..db.connection import Connection
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/user")
|
router = APIRouter(prefix="/user")
|
||||||
SCHEME = OAuth2PasswordBearer(tokenUrl=f".{router.prefix}/token")
|
SCHEME = OAuth2PasswordBearer(
|
||||||
|
tokenUrl=f".{router.prefix}/token",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Token(BaseModel):
|
class Token(BaseModel):
|
||||||
|
@ -37,7 +38,7 @@ def create_access_token(
|
||||||
|
|
||||||
|
|
||||||
@router.post("/auth", response_model=Token)
|
@router.post("/auth", response_model=Token)
|
||||||
async def login_for_access_token(
|
async def login(
|
||||||
form_data: OAuth2PasswordRequestForm = Depends(),
|
form_data: OAuth2PasswordRequestForm = Depends(),
|
||||||
config: Config = Depends(Config.get),
|
config: Config = Depends(Config.get),
|
||||||
db: Session = Depends(Connection.get),
|
db: Session = Depends(Connection.get),
|
||||||
|
@ -48,7 +49,7 @@ async def login_for_access_token(
|
||||||
headers={"WWW-Authenticate": "Bearer"},
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
)
|
)
|
||||||
|
|
||||||
user = crud.get_user(db, form_data.username)
|
user = schemas.User.get(db, form_data.username)
|
||||||
if user is None:
|
if user is None:
|
||||||
config.crypt_context.dummy_verify()
|
config.crypt_context.dummy_verify()
|
||||||
raise credentials_exception
|
raise credentials_exception
|
||||||
|
|
Loading…
Reference in a new issue