diff --git a/api/kiwi_vpn_api/db/user.py b/api/kiwi_vpn_api/db/user.py index 5b8668c..8e6e5f7 100644 --- a/api/kiwi_vpn_api/db/user.py +++ b/api/kiwi_vpn_api/db/user.py @@ -203,6 +203,20 @@ class User(UserBase, table=True): ) for capability in capabilities ] + def can_edit( + self, + user: User, + ) -> bool: + """ + Check if this user can edit another user. + """ + + return ( + user.name == self.name + # admin can edit everything + or self.can(UserCapabilityType.admin) + ) + def owns( self, device: Device, diff --git a/api/kiwi_vpn_api/routers/_common.py b/api/kiwi_vpn_api/routers/_common.py index fd380c5..092bb7e 100644 --- a/api/kiwi_vpn_api/routers/_common.py +++ b/api/kiwi_vpn_api/routers/_common.py @@ -105,27 +105,33 @@ async def get_current_user_if_admin( async def get_user_by_name( user_name: str, + current_config: Config | None = Depends(Config.load), +) -> User | None: + """ + Get a user by name. + """ + + # can't connect to an unconfigured database + if current_config is None: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) + + return User.get(user_name) + + +async def get_user_by_name_if_editable( + user: User | None = Depends(get_user_by_name), current_user: User = Depends(get_current_user_if_exists), ) -> User: """ - Get a user by name. - - Works if a) the currently logged-in user is an admin, - or b) if it is the requested user. + Get a user by name if it can be edited by the current user. """ - # check if current user is admin - if current_user.can(UserCapabilityType.admin): - # fail if requested user doesn't exist - if (user := User.get(user_name)) is None: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) + # fail if user doesn't exist + if user is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) - # check if current user is requested user - elif current_user.name == user_name: - pass - - # current user is neither admin nor the requested user - else: + # fail if user isn't editable by the current user + if not current_user.can_edit(user): raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) return user