diff --git a/api/kiwi_vpn_api/db/schemas.py b/api/kiwi_vpn_api/db/schemas.py index 29b945b..dbfb393 100644 --- a/api/kiwi_vpn_api/db/schemas.py +++ b/api/kiwi_vpn_api/db/schemas.py @@ -31,9 +31,14 @@ class UserBase(BaseModel): @validator("capabilities", pre=True) @classmethod - def caps_from_orm(cls, value: list[models.UserCapability]) -> list[str]: + def unify_capabilities( + cls, + value: list[models.UserCapability | str] + ) -> list[str]: return [ capability.capability + if isinstance(capability, models.UserCapability) + else str(capability) for capability in value ] @@ -53,12 +58,37 @@ class User(UserBase): cls, db: Session, name: str, - ) -> User: + ) -> User | None: user = (db .query(models.User) .filter(models.User.name == name) .first()) + if user is None: + return None + + return cls.from_orm(user) + + @classmethod + def verify( + cls, + db: Session, + name: str, + password: str, + crypt_context: CryptContext, + ) -> User | None: + user = (db + .query(models.User) + .filter(models.User.name == name) + .first()) + + if user is None: + crypt_context.dummy_verify() + return None + + if not crypt_context.verify(password, user.password): + return None + return cls.from_orm(user) @classmethod diff --git a/api/kiwi_vpn_api/main.py b/api/kiwi_vpn_api/main.py index b798f81..dd4ee38 100755 --- a/api/kiwi_vpn_api/main.py +++ b/api/kiwi_vpn_api/main.py @@ -38,9 +38,10 @@ async def on_startup() -> None: if (current_config := await Config.get()) is not None: Connection.connect(current_config.db_engine) + # some testing async for db in Connection.get(): - user = schemas.User.get(db, "admin") - print(str(user)) + print(schemas.User.get(db, "admin")) + print(schemas.User.get(db, "nonexistent")) def main() -> None: diff --git a/api/kiwi_vpn_api/routers/user.py b/api/kiwi_vpn_api/routers/user.py index 465c1ef..d49b8db 100644 --- a/api/kiwi_vpn_api/routers/user.py +++ b/api/kiwi_vpn_api/routers/user.py @@ -2,16 +2,16 @@ from datetime import datetime, timedelta from fastapi import APIRouter, Depends, HTTPException, status from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm -from jose import jwt +from jose import JWTError, jwt from pydantic import BaseModel from sqlalchemy.orm import Session -from ..config import Config +from ..config import Config, JWTConfig from ..db import Connection, schemas router = APIRouter(prefix="/user") SCHEME = OAuth2PasswordBearer( - tokenUrl=f".{router.prefix}/token", + tokenUrl=f".{router.prefix}/auth", ) @@ -20,20 +20,25 @@ class Token(BaseModel): token_type: str +class TokenData(BaseModel): + username: str | None = None + + def create_access_token( data: dict, + jwt_config: JWTConfig, expires_delta: timedelta | None = None, - config: Config = Depends(Config.get), ): to_encode = data.copy() + if expires_delta is None: - expires_delta = timedelta(minutes=15) + expires_delta = timedelta(minutes=jwt_config.expiry_minutes) to_encode.update({"exp": datetime.utcnow() + expires_delta}) return jwt.encode( to_encode, - config.jwt.secret, - algorithm=config.jwt.hash_algorithm, + jwt_config.secret, + algorithm=jwt_config.hash_algorithm, ) @@ -42,23 +47,58 @@ async def login( form_data: OAuth2PasswordRequestForm = Depends(), config: Config = Depends(Config.get), db: Session = Depends(Connection.get), +): + user = schemas.User.verify( + db=db, + name=form_data.username, + password=form_data.password, + crypt_context=config.crypt_context, + ) + + if user is None: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Could not validate credentials", + headers={"WWW-Authenticate": "Bearer"}, + ) + + access_token = create_access_token( + data={"sub": user.name}, + jwt_config=config.jwt, + ) + + return {"access_token": access_token, "token_type": "bearer"} + + +async def dep_get_current_user( + token: str = Depends(SCHEME), + db: Session = Depends(Connection.get), + config: Config = Depends(Config.get), ): credentials_exception = HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate credentials", headers={"WWW-Authenticate": "Bearer"}, ) + try: + payload = jwt.decode(token, config.jwt.secret, algorithms=[ + config.jwt.hash_algorithm]) + username: str = payload.get("sub") + if username is None: + raise credentials_exception + token_data = TokenData(username=username) + except JWTError: + raise credentials_exception + + user = schemas.User.get(db, token_data.username) - user = schemas.User.get(db, form_data.username) if user is None: - config.crypt_context.dummy_verify() raise credentials_exception + return user - if not user.verify(form_data.password): - raise credentials_exception - access_token = create_access_token( - data={"sub": user.name}, - expires_delta=timedelta(minutes=config.jwt.expiry_minutes), - ) - return {"access_token": access_token, "token_type": "bearer"} +@router.get("/current", response_model=schemas.User) +async def get_current_user( + current_user: schemas.User = Depends(dep_get_current_user), +): + return current_user