diff --git a/api/kiwi_vpn_api/routers/_common.py b/api/kiwi_vpn_api/routers/_common.py index 44d1c38..81f74ee 100644 --- a/api/kiwi_vpn_api/routers/_common.py +++ b/api/kiwi_vpn_api/routers/_common.py @@ -85,19 +85,16 @@ async def get_current_user_if_exists( return current_user -async def get_current_user_if_admin( +async def current_user_is_admin( current_user: User = Depends(get_current_user_if_exists), ) -> User: """ - Get the currently logged-in user if it is an admin. + Fail if the currently logged-in user is not an admin. """ - # fail if not requested by an admin if not current_user.can(UserCapabilityType.admin): raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) - return current_user - async def get_user_by_name( user_name: str, diff --git a/api/kiwi_vpn_api/routers/admin.py b/api/kiwi_vpn_api/routers/admin.py index 81978b9..b43f1ae 100644 --- a/api/kiwi_vpn_api/routers/admin.py +++ b/api/kiwi_vpn_api/routers/admin.py @@ -8,7 +8,7 @@ from sqlmodel import select from ..config import Config from ..db import Connection, User, UserCapabilityType, UserCreate -from ._common import Responses, get_current_user_if_admin +from ._common import Responses, current_user_is_admin router = APIRouter(prefix="/admin", tags=["admin"]) @@ -79,7 +79,7 @@ async def create_initial_admin( ) async def set_config( config: Config, - _: User = Depends(get_current_user_if_admin), + _: User = Depends(current_user_is_admin), ): """ PUT ./config: Edit `kiwi-vpn` main config. diff --git a/api/kiwi_vpn_api/routers/user.py b/api/kiwi_vpn_api/routers/user.py index 18222dd..92ce8ec 100644 --- a/api/kiwi_vpn_api/routers/user.py +++ b/api/kiwi_vpn_api/routers/user.py @@ -8,8 +8,8 @@ from pydantic import BaseModel from ..config import Config from ..db import User, UserCapabilityType, UserCreate, UserRead -from ._common import (Responses, get_current_user, get_current_user_if_admin, - get_user_by_name) +from ._common import (Responses, current_user_is_admin, + get_current_user_if_exists, get_user_by_name) router = APIRouter(prefix="/user", tags=["user"]) @@ -59,7 +59,7 @@ async def login( @router.get("/current", response_model=UserRead) async def get_current_user( - current_user: User | None = Depends(get_current_user), + current_user: User = Depends(get_current_user_if_exists), ): """ GET ./current: Respond with the currently logged-in user. @@ -81,7 +81,7 @@ async def get_current_user( ) async def add_user( user: UserCreate, - _: User = Depends(get_current_user_if_admin), + _: User = Depends(current_user_is_admin), ) -> User: """ POST ./: Create a new user in the database. @@ -113,7 +113,7 @@ async def add_user( response_model=User, ) async def remove_user( - _: User = Depends(get_current_user_if_admin), + _: User = Depends(current_user_is_admin), user: User = Depends(get_user_by_name), ): """ @@ -135,7 +135,7 @@ async def remove_user( ) async def extend_capabilities( capabilities: list[UserCapabilityType], - _: User = Depends(get_current_user_if_admin), + _: User = Depends(current_user_is_admin), user: User = Depends(get_user_by_name), ): """ @@ -158,7 +158,7 @@ async def extend_capabilities( ) async def remove_capabilities( capabilities: list[UserCapabilityType], - _: User = Depends(get_current_user_if_admin), + _: User = Depends(current_user_is_admin), user: User = Depends(get_user_by_name), ): """