diff --git a/api/kiwi_vpn_api/routers/_common.py b/api/kiwi_vpn_api/routers/_common.py index 528a83a..91def45 100644 --- a/api/kiwi_vpn_api/routers/_common.py +++ b/api/kiwi_vpn_api/routers/_common.py @@ -89,7 +89,7 @@ async def get_current_user_if_exists( return current_user -async def current_user_is_admin( +async def get_current_user_if_admin( current_user: User = Depends(get_current_user_if_exists), ) -> User: """ @@ -99,6 +99,8 @@ async def current_user_is_admin( if not current_user.can(UserCapabilityType.admin): raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) + return current_user + async def get_user_by_name( user_name: str, diff --git a/api/kiwi_vpn_api/routers/admin.py b/api/kiwi_vpn_api/routers/admin.py index 609edb7..29c442f 100644 --- a/api/kiwi_vpn_api/routers/admin.py +++ b/api/kiwi_vpn_api/routers/admin.py @@ -8,7 +8,7 @@ from sqlmodel import select from ..config import Config from ..db import Connection, User, UserCapabilityType, UserCreate -from ._common import Responses, current_user_is_admin +from ._common import Responses, get_current_user_if_admin router = APIRouter(prefix="/admin", tags=["admin"]) @@ -79,7 +79,7 @@ async def create_initial_admin( ) async def set_config( config: Config, - _: User = Depends(current_user_is_admin), + _: User = Depends(get_current_user_if_admin), ): """ PUT ./config: Edit `kiwi-vpn` main config. diff --git a/api/kiwi_vpn_api/routers/user.py b/api/kiwi_vpn_api/routers/user.py index 5a07bd5..8bd9b74 100644 --- a/api/kiwi_vpn_api/routers/user.py +++ b/api/kiwi_vpn_api/routers/user.py @@ -8,7 +8,7 @@ from pydantic import BaseModel from ..config import Config from ..db import User, UserCapabilityType, UserCreate, UserRead -from ._common import (Responses, current_user_is_admin, +from ._common import (Responses, get_current_user_if_admin, get_current_user_if_exists, get_user_by_name) router = APIRouter(prefix="/user", tags=["user"]) @@ -81,7 +81,7 @@ async def get_current_user( ) async def add_user( user: UserCreate, - _: User = Depends(current_user_is_admin), + _: User = Depends(get_current_user_if_admin), ) -> User: """ POST ./: Create a new user in the database. @@ -114,7 +114,7 @@ async def add_user( response_model=User, ) async def remove_user( - _: User = Depends(current_user_is_admin), + current_user: User = Depends(get_current_user_if_admin), user: User = Depends(get_user_by_name), ): """ @@ -140,7 +140,7 @@ async def remove_user( ) async def extend_capabilities( capabilities: list[UserCapabilityType], - _: User = Depends(current_user_is_admin), + _: User = Depends(get_current_user_if_admin), user: User = Depends(get_user_by_name), ): """ @@ -163,7 +163,7 @@ async def extend_capabilities( ) async def remove_capabilities( capabilities: list[UserCapabilityType], - _: User = Depends(current_user_is_admin), + _: User = Depends(get_current_user_if_admin), user: User = Depends(get_user_by_name), ): """