JWT handling

This commit is contained in:
Jörn-Michael Miehe 2022-03-19 02:22:49 +00:00
parent affff321ab
commit 2e9093caaf
2 changed files with 59 additions and 47 deletions

View file

@ -2,8 +2,10 @@ from __future__ import annotations
import functools import functools
import json import json
from datetime import datetime, timedelta
from enum import Enum from enum import Enum
from jose import JWTError, jwt
from jose.constants import ALGORITHMS from jose.constants import ALGORITHMS
from passlib.context import CryptContext from passlib.context import CryptContext
from pydantic import BaseModel, BaseSettings, Field from pydantic import BaseModel, BaseSettings, Field
@ -45,6 +47,53 @@ class JWTConfig(BaseModel):
hash_algorithm: str = ALGORITHMS.HS256 hash_algorithm: str = ALGORITHMS.HS256
expiry_minutes: int = 30 expiry_minutes: int = 30
async def encode(
self,
username: str,
expiry_minutes: int | None = None,
) -> str:
if expiry_minutes is None:
expiry_minutes = self.expiry_minutes
return jwt.encode(
{
"sub": username,
"exp": datetime.utcnow() + timedelta(minutes=expiry_minutes),
},
self.secret,
algorithm=self.hash_algorithm,
)
async def decode(
self,
token: str,
) -> str | None:
# decode JWT token
try:
payload = jwt.decode(
token,
self.secret,
algorithms=[self.hash_algorithm],
)
except JWTError:
return None
# check expiry
expiry = payload.get("exp")
if expiry is None:
return None
if datetime.fromtimestamp(expiry) < datetime.utcnow():
return None
# get username
username = payload.get("sub")
if username is None:
return None
return username
class CryptoConfig(BaseModel): class CryptoConfig(BaseModel):
schemes: list[str] = ["bcrypt"] schemes: list[str] = ["bcrypt"]

View file

@ -1,12 +1,10 @@
from datetime import datetime, timedelta
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import JWTError, jwt
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from ..config import Config, JWTConfig from ..config import Config
from ..db import Connection, schemas from ..db import Connection, schemas
router = APIRouter(prefix="/user") router = APIRouter(prefix="/user")
@ -20,28 +18,6 @@ class Token(BaseModel):
token_type: str token_type: str
class TokenData(BaseModel):
username: str | None = None
def create_access_token(
data: dict,
jwt_config: JWTConfig,
expires_delta: timedelta | None = None,
):
to_encode = data.copy()
if expires_delta is None:
expires_delta = timedelta(minutes=jwt_config.expiry_minutes)
to_encode.update({"exp": datetime.utcnow() + expires_delta})
return jwt.encode(
to_encode,
jwt_config.secret,
algorithm=jwt_config.hash_algorithm,
)
@router.post("/auth", response_model=Token) @router.post("/auth", response_model=Token)
async def login( async def login(
form_data: OAuth2PasswordRequestForm = Depends(), form_data: OAuth2PasswordRequestForm = Depends(),
@ -62,11 +38,7 @@ async def login(
headers={"WWW-Authenticate": "Bearer"}, headers={"WWW-Authenticate": "Bearer"},
) )
access_token = create_access_token( access_token = await config.jwt.encode(user.name)
data={"sub": user.name},
jwt_config=config.jwt,
)
return {"access_token": access_token, "token_type": "bearer"} return {"access_token": access_token, "token_type": "bearer"}
@ -75,25 +47,16 @@ async def dep_get_current_user(
db: Session = Depends(Connection.get), db: Session = Depends(Connection.get),
config: Config = Depends(Config.get), config: Config = Depends(Config.get),
): ):
credentials_exception = HTTPException( username = await config.jwt.decode(token)
status_code=status.HTTP_401_UNAUTHORIZED, user = schemas.User.get(db, username)
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)
try:
payload = jwt.decode(token, config.jwt.secret, algorithms=[
config.jwt.hash_algorithm])
username: str = payload.get("sub")
if username is None:
raise credentials_exception
token_data = TokenData(username=username)
except JWTError:
raise credentials_exception
user = schemas.User.get(db, token_data.username)
if user is None: if user is None:
raise credentials_exception raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)
return user return user