This commit is contained in:
Jörn-Michael Miehe 2022-03-23 01:14:02 +00:00
parent ae73c8ff70
commit 2ed09a5b3f
2 changed files with 39 additions and 19 deletions

View file

@ -9,7 +9,7 @@ from sqlalchemy.orm import Session
from ..config import Config from ..config import Config
from ..db import Connection from ..db import Connection
from ..db.schemas import User, UserCapability from ..db.schemas import User
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="user/authenticate") oauth2_scheme = OAuth2PasswordBearer(tokenUrl="user/authenticate")
@ -65,19 +65,47 @@ async def get_current_user(
return user return user
async def get_current_admin_user( async def get_current_user_if_admin(
current_config: Config | None = Depends(Config.load), current_config: Config | None = Depends(Config.load),
current_user: User | None = Depends(get_current_user), current_user: User | None = Depends(get_current_user),
) -> User: ) -> User:
""" """
Check if the currently logged-in user is an admin. Get the currently logged-in user if it is an admin.
""" """
# fail if not installed # fail if not installed
if current_config is None: if current_config is None:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
# fail if not requested by an admin # fail if not requested by a user
if (current_user is None if current_user is None:
or UserCapability.admin not in current_user.capabilities): raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
# fail if not requested by an admin
if not current_user.is_admin():
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
async def get_current_user_if_admin_or_self(
user_name: str,
current_config: Config | None = Depends(Config.load),
current_user: User | None = Depends(get_current_user),
) -> User:
"""
Get the currently logged-in user.
Fails a) if the currently logged-in user is not the requested user,
and b) if it is not an admin.
"""
# fail if not installed
if current_config is None:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
# fail if not requested by a user
if current_user is None:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
# fail if not requested by an admin or self
if not (current_user.is_admin() or current_user.name == user_name):
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)

View file

@ -10,7 +10,7 @@ from sqlalchemy.orm import Session
from ..config import Config from ..config import Config
from ..db import Connection from ..db import Connection
from ..db.schemas import User, UserCapability, UserCreate from ..db.schemas import User, UserCapability, UserCreate
from ._common import Responses, get_current_admin_user, get_current_user from ._common import Responses, get_current_user, get_current_user_if_admin
router = APIRouter(prefix="/user") router = APIRouter(prefix="/user")
@ -82,7 +82,7 @@ async def get_current_user(
async def add_user( async def add_user(
user: UserCreate, user: UserCreate,
current_config: Config | None = Depends(Config.load), current_config: Config | None = Depends(Config.load),
_: User = Depends(get_current_admin_user), _: User = Depends(get_current_user_if_admin),
db: Session | None = Depends(Connection.get), db: Session | None = Depends(Connection.get),
): ):
""" """
@ -112,12 +112,11 @@ async def add_user(
status.HTTP_401_UNAUTHORIZED: Responses.NEEDS_USER, status.HTTP_401_UNAUTHORIZED: Responses.NEEDS_USER,
status.HTTP_403_FORBIDDEN: Responses.NEEDS_ADMIN, status.HTTP_403_FORBIDDEN: Responses.NEEDS_ADMIN,
}, },
response_model=User,
) )
async def extend_capabilities( async def extend_capabilities(
user_name: str, user_name: str,
capabilities: list[UserCapability], capabilities: list[UserCapability],
_: User = Depends(get_current_admin_user), _: User = Depends(get_current_user_if_admin),
db: Session | None = Depends(Connection.get), db: Session | None = Depends(Connection.get),
): ):
""" """
@ -133,9 +132,6 @@ async def extend_capabilities(
user.capabilities.extend(capabilities) user.capabilities.extend(capabilities)
user.update(db) user.update(db)
# return the modified user
return user
@router.delete( @router.delete(
"/{user_name}/capabilities", "/{user_name}/capabilities",
@ -145,16 +141,15 @@ async def extend_capabilities(
status.HTTP_401_UNAUTHORIZED: Responses.NEEDS_USER, status.HTTP_401_UNAUTHORIZED: Responses.NEEDS_USER,
status.HTTP_403_FORBIDDEN: Responses.NEEDS_ADMIN, status.HTTP_403_FORBIDDEN: Responses.NEEDS_ADMIN,
}, },
response_model=User,
) )
async def remove_capabilities( async def remove_capabilities(
user_name: str, user_name: str,
capabilities: list[UserCapability], capabilities: list[UserCapability],
_: User | None = Depends(get_current_admin_user), _: User = Depends(get_current_user_if_admin),
db: Session | None = Depends(Connection.get), db: Session | None = Depends(Connection.get),
): ):
""" """
DELETE ./{user_name}/capabilities: Add capabilities to a user. DELETE ./{user_name}/capabilities: Remove capabilities from a user.
""" """
# get and change the user # get and change the user
@ -167,6 +162,3 @@ async def remove_capabilities(
user.capabilities.remove(capability) user.capabilities.remove(capability)
user.update(db) user.update(db)
# return the modified user
return user