""" Common dependencies for routers. """ from fastapi import Depends, HTTPException, status from fastapi.security import OAuth2PasswordBearer from ..config import Config, Settings from ..db import Device, User, UserCapabilityType oauth2_scheme = OAuth2PasswordBearer( tokenUrl=f"{Settings._.api_v1_prefix}/user/authenticate" ) class Responses: """ Just a namespace. Describes API response status codes. """ OK = { "content": None, } INSTALLED = { "description": "kiwi-vpn already installed", "content": None, } NOT_INSTALLED = { "description": "kiwi-vpn not installed", "content": None, } NEEDS_USER = { "description": "Must be logged in", "content": None, } NEEDS_ADMIN = { "description": "Must be admin", "content": None, } NEEDS_REQUESTED_USER = { "description": "Must be the requested user", "content": None, } ENTRY_EXISTS = { "description": "Entry exists in database", "content": None, } ENTRY_DOESNT_EXIST = { "description": "Entry does not exist in database", "content": None, } CANT_TARGET_SELF = { "description": "Operation can't target yourself", "content": None, } async def get_current_user( token: str = Depends(oauth2_scheme), current_config: Config | None = Depends(Config.load), ) -> User | None: """ Get the currently logged-in user from the database. """ # can't connect to an unconfigured database if current_config is None: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) username = await current_config.jwt.decode_token(token) return User.get(username) async def get_current_user_if_exists( current_user: User | None = Depends(get_current_user), ) -> User: """ Get the currently logged-in user if it exists. """ # fail if not requested by a user if current_user is None: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) return current_user async def get_current_user_if_admin( current_user: User = Depends(get_current_user_if_exists), ) -> User: """ Fail if the currently logged-in user is not an admin. """ if not current_user.can(UserCapabilityType.admin): raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) return current_user async def get_user_by_name( user_name: str, current_user: User = Depends(get_current_user_if_exists), ) -> User: """ Get a user by name. Works if a) the currently logged-in user is an admin, or b) if it is the requested user. """ # check if current user is admin if current_user.can(UserCapabilityType.admin): # fail if requested user doesn't exist if (user := User.get(user_name)) is None: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) # check if current user is requested user elif current_user.name == user_name: pass # current user is neither admin nor the requested user else: raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) return user async def get_device_by_id( device_id: int, current_config: Config | None = Depends(Config.load), ) -> Device | None: # can't connect to an unconfigured database if current_config is None: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) return Device.get(device_id) async def get_device_by_id_if_editable( device: Device | None = Depends(get_device_by_id), current_user: User = Depends(get_current_user_if_exists), ) -> Device: # fail if device doesn't exist if device is None: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) # fail if device is not owned by current user if not current_user.owns(device): raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) return device