diff --git a/api/kiwi_vpn_api/config.py b/api/kiwi_vpn_api/config.py index 945e367..10e1b97 100644 --- a/api/kiwi_vpn_api/config.py +++ b/api/kiwi_vpn_api/config.py @@ -2,8 +2,10 @@ from __future__ import annotations import functools import json +from datetime import datetime, timedelta from enum import Enum +from jose import JWTError, jwt from jose.constants import ALGORITHMS from passlib.context import CryptContext from pydantic import BaseModel, BaseSettings, Field @@ -45,6 +47,53 @@ class JWTConfig(BaseModel): hash_algorithm: str = ALGORITHMS.HS256 expiry_minutes: int = 30 + async def encode( + self, + username: str, + expiry_minutes: int | None = None, + ) -> str: + if expiry_minutes is None: + expiry_minutes = self.expiry_minutes + + return jwt.encode( + { + "sub": username, + "exp": datetime.utcnow() + timedelta(minutes=expiry_minutes), + }, + self.secret, + algorithm=self.hash_algorithm, + ) + + async def decode( + self, + token: str, + ) -> str | None: + # decode JWT token + try: + payload = jwt.decode( + token, + self.secret, + algorithms=[self.hash_algorithm], + ) + + except JWTError: + return None + + # check expiry + expiry = payload.get("exp") + if expiry is None: + return None + + if datetime.fromtimestamp(expiry) < datetime.utcnow(): + return None + + # get username + username = payload.get("sub") + if username is None: + return None + + return username + class CryptoConfig(BaseModel): schemes: list[str] = ["bcrypt"] diff --git a/api/kiwi_vpn_api/routers/user.py b/api/kiwi_vpn_api/routers/user.py index d49b8db..777653b 100644 --- a/api/kiwi_vpn_api/routers/user.py +++ b/api/kiwi_vpn_api/routers/user.py @@ -1,12 +1,10 @@ -from datetime import datetime, timedelta from fastapi import APIRouter, Depends, HTTPException, status from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm -from jose import JWTError, jwt from pydantic import BaseModel from sqlalchemy.orm import Session -from ..config import Config, JWTConfig +from ..config import Config from ..db import Connection, schemas router = APIRouter(prefix="/user") @@ -20,28 +18,6 @@ class Token(BaseModel): token_type: str -class TokenData(BaseModel): - username: str | None = None - - -def create_access_token( - data: dict, - jwt_config: JWTConfig, - expires_delta: timedelta | None = None, -): - to_encode = data.copy() - - if expires_delta is None: - expires_delta = timedelta(minutes=jwt_config.expiry_minutes) - - to_encode.update({"exp": datetime.utcnow() + expires_delta}) - return jwt.encode( - to_encode, - jwt_config.secret, - algorithm=jwt_config.hash_algorithm, - ) - - @router.post("/auth", response_model=Token) async def login( form_data: OAuth2PasswordRequestForm = Depends(), @@ -62,11 +38,7 @@ async def login( headers={"WWW-Authenticate": "Bearer"}, ) - access_token = create_access_token( - data={"sub": user.name}, - jwt_config=config.jwt, - ) - + access_token = await config.jwt.encode(user.name) return {"access_token": access_token, "token_type": "bearer"} @@ -75,25 +47,16 @@ async def dep_get_current_user( db: Session = Depends(Connection.get), config: Config = Depends(Config.get), ): - credentials_exception = HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Could not validate credentials", - headers={"WWW-Authenticate": "Bearer"}, - ) - try: - payload = jwt.decode(token, config.jwt.secret, algorithms=[ - config.jwt.hash_algorithm]) - username: str = payload.get("sub") - if username is None: - raise credentials_exception - token_data = TokenData(username=username) - except JWTError: - raise credentials_exception - - user = schemas.User.get(db, token_data.username) + username = await config.jwt.decode(token) + user = schemas.User.get(db, username) if user is None: - raise credentials_exception + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Could not validate credentials", + headers={"WWW-Authenticate": "Bearer"}, + ) + return user