diff --git a/api/kiwi_vpn_api/db/__init__.py b/api/kiwi_vpn_api/db/__init__.py index e69de29..f70650d 100644 --- a/api/kiwi_vpn_api/db/__init__.py +++ b/api/kiwi_vpn_api/db/__init__.py @@ -0,0 +1 @@ +from .connection import Connection diff --git a/api/kiwi_vpn_api/db/crud.py b/api/kiwi_vpn_api/db/crud.py deleted file mode 100644 index 049a686..0000000 --- a/api/kiwi_vpn_api/db/crud.py +++ /dev/null @@ -1,31 +0,0 @@ -from sqlalchemy.orm import Session -from passlib.context import CryptContext - -from . import models, schemas - - -def get_user(db: Session, name: str): - return (db - .query(models.User) - .filter(models.User.name == name).first()) - - -def create_user( - db: Session, - user: schemas.UserCreate, - capabilities: list[str], - crypt_context: CryptContext, -): - db_user = models.User( - name=user.name, - password=crypt_context.hash(user.password), - capabilities=[ - models.UserCapability(capability=capability) - for capability in capabilities - ] - ) - - db.add(db_user) - db.commit() - db.refresh(db_user) - return db_user diff --git a/api/kiwi_vpn_api/db/schemas.py b/api/kiwi_vpn_api/db/schemas.py index 47eab00..29b945b 100644 --- a/api/kiwi_vpn_api/db/schemas.py +++ b/api/kiwi_vpn_api/db/schemas.py @@ -1,5 +1,12 @@ +from __future__ import annotations + from datetime import datetime -from pydantic import BaseModel + +from passlib.context import CryptContext +from pydantic import BaseModel, validator +from sqlalchemy.orm import Session + +from . import models class CertificateBase(BaseModel): @@ -20,6 +27,15 @@ class Certificate(CertificateBase): class UserBase(BaseModel): name: str + capabilities: list[str] + + @validator("capabilities", pre=True) + @classmethod + def caps_from_orm(cls, value: list[models.UserCapability]) -> list[str]: + return [ + capability.capability + for capability in value + ] class UserCreate(UserBase): @@ -27,12 +43,46 @@ class UserCreate(UserBase): class User(UserBase): - capabilities: list[str] certificates: list[Certificate] class Config: orm_mode = True + @classmethod + def get( + cls, + db: Session, + name: str, + ) -> User: + user = (db + .query(models.User) + .filter(models.User.name == name) + .first()) + + return cls.from_orm(user) + + @classmethod + def create( + cls, + db: Session, + user: UserCreate, + crypt_context: CryptContext, + ) -> User: + user = models.User( + name=user.name, + password=crypt_context.hash(user.password), + capabilities=[ + models.UserCapability(capability=capability) + for capability in user.capabilities + ] + ) + + db.add(user) + db.commit() + db.refresh(user) + + return cls.from_orm(user) + class DistinguishedNameBase(BaseModel): cn_only: bool diff --git a/api/kiwi_vpn_api/main.py b/api/kiwi_vpn_api/main.py index 2612862..b798f81 100755 --- a/api/kiwi_vpn_api/main.py +++ b/api/kiwi_vpn_api/main.py @@ -4,7 +4,7 @@ import uvicorn from fastapi import FastAPI from .config import Config, Settings -from .db.connection import Connection +from .db import Connection, schemas from .routers import admin, user settings = Settings.get() @@ -29,26 +29,21 @@ api = FastAPI( app = FastAPI() app.mount("/api", api) +api.include_router(admin.router) +api.include_router(user.router) + @app.on_event("startup") -async def on_startup(): - # always include admin router - api.include_router(admin.router) - +async def on_startup() -> None: if (current_config := await Config.get()) is not None: Connection.connect(current_config.db_engine) - # async for db in connection.get(): - # user = crud.get_user(db, "admin") - # print(user.name) - # for cap in user.capabilities: - # print(cap.capability) - - # include other routers - api.include_router(user.router) + async for db in Connection.get(): + user = schemas.User.get(db, "admin") + print(str(user)) -def main(): +def main() -> None: uvicorn.run( "kiwi_vpn_api.main:app", host="0.0.0.0", diff --git a/api/kiwi_vpn_api/routers/admin.py b/api/kiwi_vpn_api/routers/admin.py index 76f999c..68b2488 100644 --- a/api/kiwi_vpn_api/routers/admin.py +++ b/api/kiwi_vpn_api/routers/admin.py @@ -4,8 +4,7 @@ from fastapi import APIRouter, Depends, HTTPException, status from sqlalchemy.orm import Session from ..config import Config -from ..db import crud, schemas -from ..db.connection import Connection +from ..db import Connection, schemas router = APIRouter(prefix="/admin") @@ -32,9 +31,8 @@ async def set_config( if new_config.jwt.secret is None: new_config.jwt.secret = token_hex(32) - Connection.connect(new_config.db_engine) - Config.set(new_config) + Connection.connect(new_config.db_engine) @router.post( @@ -58,12 +56,12 @@ async def add_user( if current_config is None: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) - crud.create_user( + schemas.User.create( db=db, user=schemas.UserCreate( name=user_name, password=user_password, + capabilities=["admin"], ), - capabilities=["admin"], crypt_context=current_config.crypt_context, ) diff --git a/api/kiwi_vpn_api/routers/user.py b/api/kiwi_vpn_api/routers/user.py index 797ce96..465c1ef 100644 --- a/api/kiwi_vpn_api/routers/user.py +++ b/api/kiwi_vpn_api/routers/user.py @@ -7,11 +7,12 @@ from pydantic import BaseModel from sqlalchemy.orm import Session from ..config import Config -from ..db import crud -from ..db.connection import Connection +from ..db import Connection, schemas router = APIRouter(prefix="/user") -SCHEME = OAuth2PasswordBearer(tokenUrl=f".{router.prefix}/token") +SCHEME = OAuth2PasswordBearer( + tokenUrl=f".{router.prefix}/token", +) class Token(BaseModel): @@ -37,7 +38,7 @@ def create_access_token( @router.post("/auth", response_model=Token) -async def login_for_access_token( +async def login( form_data: OAuth2PasswordRequestForm = Depends(), config: Config = Depends(Config.get), db: Session = Depends(Connection.get), @@ -48,7 +49,7 @@ async def login_for_access_token( headers={"WWW-Authenticate": "Bearer"}, ) - user = crud.get_user(db, form_data.username) + user = schemas.User.get(db, form_data.username) if user is None: config.crypt_context.dummy_verify() raise credentials_exception