diff --git a/api/kiwi_vpn_api/routers/_common.py b/api/kiwi_vpn_api/routers/_common.py index 0c08753..d4e01c1 100644 --- a/api/kiwi_vpn_api/routers/_common.py +++ b/api/kiwi_vpn_api/routers/_common.py @@ -9,7 +9,7 @@ from sqlalchemy.orm import Session from ..config import Config from ..db import Connection -from ..db.schemas import User, UserCapability +from ..db.schemas import User oauth2_scheme = OAuth2PasswordBearer(tokenUrl="user/authenticate") @@ -65,19 +65,47 @@ async def get_current_user( return user -async def get_current_admin_user( +async def get_current_user_if_admin( current_config: Config | None = Depends(Config.load), current_user: User | None = Depends(get_current_user), ) -> User: """ - Check if the currently logged-in user is an admin. + Get the currently logged-in user if it is an admin. """ # fail if not installed if current_config is None: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) - # fail if not requested by an admin - if (current_user is None - or UserCapability.admin not in current_user.capabilities): + # fail if not requested by a user + if current_user is None: + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) + + # fail if not requested by an admin + if not current_user.is_admin(): + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) + + +async def get_current_user_if_admin_or_self( + user_name: str, + current_config: Config | None = Depends(Config.load), + current_user: User | None = Depends(get_current_user), +) -> User: + """ + Get the currently logged-in user. + + Fails a) if the currently logged-in user is not the requested user, + and b) if it is not an admin. + """ + + # fail if not installed + if current_config is None: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) + + # fail if not requested by a user + if current_user is None: + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) + + # fail if not requested by an admin or self + if not (current_user.is_admin() or current_user.name == user_name): raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) diff --git a/api/kiwi_vpn_api/routers/user.py b/api/kiwi_vpn_api/routers/user.py index 62838a2..b66b19e 100644 --- a/api/kiwi_vpn_api/routers/user.py +++ b/api/kiwi_vpn_api/routers/user.py @@ -10,7 +10,7 @@ from sqlalchemy.orm import Session from ..config import Config from ..db import Connection from ..db.schemas import User, UserCapability, UserCreate -from ._common import Responses, get_current_admin_user, get_current_user +from ._common import Responses, get_current_user, get_current_user_if_admin router = APIRouter(prefix="/user") @@ -82,7 +82,7 @@ async def get_current_user( async def add_user( user: UserCreate, current_config: Config | None = Depends(Config.load), - _: User = Depends(get_current_admin_user), + _: User = Depends(get_current_user_if_admin), db: Session | None = Depends(Connection.get), ): """ @@ -112,12 +112,11 @@ async def add_user( status.HTTP_401_UNAUTHORIZED: Responses.NEEDS_USER, status.HTTP_403_FORBIDDEN: Responses.NEEDS_ADMIN, }, - response_model=User, ) async def extend_capabilities( user_name: str, capabilities: list[UserCapability], - _: User = Depends(get_current_admin_user), + _: User = Depends(get_current_user_if_admin), db: Session | None = Depends(Connection.get), ): """ @@ -133,9 +132,6 @@ async def extend_capabilities( user.capabilities.extend(capabilities) user.update(db) - # return the modified user - return user - @router.delete( "/{user_name}/capabilities", @@ -145,16 +141,15 @@ async def extend_capabilities( status.HTTP_401_UNAUTHORIZED: Responses.NEEDS_USER, status.HTTP_403_FORBIDDEN: Responses.NEEDS_ADMIN, }, - response_model=User, ) async def remove_capabilities( user_name: str, capabilities: list[UserCapability], - _: User | None = Depends(get_current_admin_user), + _: User = Depends(get_current_user_if_admin), db: Session | None = Depends(Connection.get), ): """ - DELETE ./{user_name}/capabilities: Add capabilities to a user. + DELETE ./{user_name}/capabilities: Remove capabilities from a user. """ # get and change the user @@ -167,6 +162,3 @@ async def remove_capabilities( user.capabilities.remove(capability) user.update(db) - - # return the modified user - return user