from datetime import datetime, timedelta, timezone

from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import JWTError, jwt
from pydantic import BaseModel
from sqlalchemy.orm import Session

from ..config import Config, JWTConfig
from ..db import Connection, schemas

router = APIRouter(prefix="/user")

# OAuth2 bearer scheme; tokenUrl points at the /auth endpoint below.
# The leading "." makes it a relative URL ("./user/auth") — NOTE(review):
# presumably so it resolves when the app is mounted under a sub-path; confirm.
SCHEME = OAuth2PasswordBearer(
    tokenUrl=f".{router.prefix}/auth",
)


class Token(BaseModel):
    """Response body returned by the ``/auth`` endpoint."""

    access_token: str
    token_type: str


class TokenData(BaseModel):
    """Claims extracted from a decoded access token."""

    username: str | None = None


def _credentials_exception() -> HTTPException:
    """Build the 401 error raised whenever authentication fails.

    Shared by ``login`` and ``dep_get_current_user`` so both report failures
    with the same status, message and ``WWW-Authenticate`` header.
    """
    return HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )


def create_access_token(
    data: dict,
    jwt_config: JWTConfig,
    expires_delta: timedelta | None = None,
) -> str:
    """Encode ``data`` as a signed JWT with an ``exp`` claim.

    Args:
        data: Claims to embed; copied, so the caller's dict is not mutated.
            An ``"exp"`` key in it is overwritten.
        jwt_config: Supplies the signing secret, the algorithm and the
            default lifetime (``expiry_minutes``).
        expires_delta: Token lifetime; defaults to
            ``jwt_config.expiry_minutes`` minutes.

    Returns:
        The encoded token string.
    """
    to_encode = data.copy()
    if expires_delta is None:
        expires_delta = timedelta(minutes=jwt_config.expiry_minutes)
    # Timezone-aware "now": datetime.utcnow() is deprecated (3.12+) and naive.
    # jose converts datetimes via utctimetuple(), so the encoded value is the
    # same UTC timestamp either way.
    to_encode["exp"] = datetime.now(timezone.utc) + expires_delta
    return jwt.encode(
        to_encode,
        jwt_config.secret,
        algorithm=jwt_config.hash_algorithm,
    )


@router.post("/auth", response_model=Token)
async def login(
    form_data: OAuth2PasswordRequestForm = Depends(),
    config: Config = Depends(Config.get),
    db: Session = Depends(Connection.get),
) -> dict[str, str]:
    """Exchange a username/password form for a bearer access token.

    Raises:
        HTTPException: 401 when ``schemas.User.verify`` returns ``None``.
            The generic message does not reveal whether the username or the
            password was wrong.
    """
    user = schemas.User.verify(
        db=db,
        name=form_data.username,
        password=form_data.password,
        crypt_context=config.crypt_context,
    )
    if user is None:
        raise _credentials_exception()
    access_token = create_access_token(
        data={"sub": user.name},
        jwt_config=config.jwt,
    )
    return {"access_token": access_token, "token_type": "bearer"}


async def dep_get_current_user(
    token: str = Depends(SCHEME),
    db: Session = Depends(Connection.get),
    config: Config = Depends(Config.get),
):
    """FastAPI dependency resolving the request's bearer token to a user.

    Returns:
        Whatever ``schemas.User.get`` returns for the token's ``sub`` claim.

    Raises:
        HTTPException: 401 if the token fails to decode or verify, has no
            ``sub`` claim, or names a user that ``schemas.User.get`` cannot
            find.
    """
    # Keep the try body to the single call that can raise JWTError.
    try:
        payload = jwt.decode(
            token,
            config.jwt.secret,
            algorithms=[config.jwt.hash_algorithm],
        )
    except JWTError as exc:
        raise _credentials_exception() from exc

    username: str | None = payload.get("sub")
    if username is None:
        raise _credentials_exception()
    token_data = TokenData(username=username)

    user = schemas.User.get(db, token_data.username)
    if user is None:
        raise _credentials_exception()
    return user


@router.get("/current", response_model=schemas.User)
async def get_current_user(
    current_user: schemas.User = Depends(dep_get_current_user),
):
    """Return the user authenticated by the request's bearer token."""
    return current_user