Compare commits

...

3 commits

5 changed files with 114 additions and 38 deletions

View file

@ -174,20 +174,6 @@ class User(UserBase, table=True):
for capability in self.capabilities for capability in self.capabilities
) )
def can(
self,
capability: UserCapabilityType,
) -> bool:
"""
Check if this user has a capability.
"""
return (
capability in self.get_capabilities()
# admin can do everything
or UserCapabilityType.admin in self.get_capabilities()
)
def set_capabilities( def set_capabilities(
self, self,
capabilities: Sequence[UserCapabilityType], capabilities: Sequence[UserCapabilityType],
@ -203,6 +189,88 @@ class User(UserBase, table=True):
) for capability in capabilities ) for capability in capabilities
] ]
def _can(
self,
capability: UserCapabilityType,
) -> bool:
"""
Check if this user has a capability.
"""
return (
capability in self.get_capabilities()
# admin can do everything
or UserCapabilityType.admin in self.get_capabilities()
)
def can_edit(
self,
user: User,
) -> bool:
"""
Check if this user can edit another user.
"""
return (
user.name == self.name
# admin can edit everything
or self._can(UserCapabilityType.admin)
)
def is_admin(
self,
) -> bool:
"""
Check if this user is an admin.
"""
# is admin with "admin" capability
return self._can(UserCapabilityType.admin)
def can_login(
self,
) -> bool:
"""
Check if this user can log in.
"""
return (
# can login with "login" capability
self._can(UserCapabilityType.login)
# admins can always login
or self.is_admin()
)
def can_be_edited_by(
self,
user: User,
) -> bool:
"""
Check if this user can be edited by another user.
"""
return (
# user can edit itself
self.name == user.name
# admin can edit every user
or user._can(UserCapabilityType.admin)
)
def can_be_deleted_by(
self,
user: User,
) -> bool:
"""
Check if this user can be deleted by another user.
"""
return (
# only admin can delete users
user._can(UserCapabilityType.admin)
# even admin cannot delete itself
and self.name != user.name
)
def owns( def owns(
self, self,
device: Device, device: Device,
@ -214,5 +282,5 @@ class User(UserBase, table=True):
return ( return (
device.owner_name == self.name device.owner_name == self.name
# admin owns everything # admin owns everything
or self.can(UserCapabilityType.admin) or self._can(UserCapabilityType.admin)
) )

View file

@ -13,7 +13,7 @@ import uvicorn
from fastapi import FastAPI from fastapi import FastAPI
from .config import Config, Settings from .config import Config, Settings
from .db import Connection, User from .db import Connection, User, UserRead
from .routers import main_router from .routers import main_router
app = FastAPI( app = FastAPI(
@ -43,7 +43,7 @@ async def on_startup() -> None:
Connection.connect(current_config.db.uri) Connection.connect(current_config.db.uri)
# some testing # some testing
print(User.get("admin")) print(UserRead.from_orm(User.get("admin")))
print(User.get("nonexistent")) print(User.get("nonexistent"))

View file

@ -7,7 +7,7 @@ from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer from fastapi.security import OAuth2PasswordBearer
from ..config import Config, Settings from ..config import Config, Settings
from ..db import Device, User, UserCapabilityType from ..db import Device, User
oauth2_scheme = OAuth2PasswordBearer( oauth2_scheme = OAuth2PasswordBearer(
tokenUrl=f"{Settings._.api_v1_prefix}/user/authenticate" tokenUrl=f"{Settings._.api_v1_prefix}/user/authenticate"
@ -84,7 +84,8 @@ async def get_current_user_if_exists(
# fail if not requested by a user # fail if not requested by a user
if current_user is None: if current_user is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) # don't use error 404 here: possible user enumeration
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
return current_user return current_user
@ -96,7 +97,7 @@ async def get_current_user_if_admin(
Fail if the currently logged-in user is not an admin. Fail if the currently logged-in user is not an admin.
""" """
if not current_user.can(UserCapabilityType.admin): if not current_user.is_admin():
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
return current_user return current_user
@ -104,27 +105,33 @@ async def get_current_user_if_admin(
async def get_user_by_name( async def get_user_by_name(
user_name: str, user_name: str,
current_config: Config | None = Depends(Config.load),
) -> User | None:
"""
Get a user by name.
"""
# can't connect to an unconfigured database
if current_config is None:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
return User.get(user_name)
async def get_user_by_name_if_editable(
user: User | None = Depends(get_user_by_name),
current_user: User = Depends(get_current_user_if_exists), current_user: User = Depends(get_current_user_if_exists),
) -> User: ) -> User:
""" """
Get a user by name. Get a user by name if it can be edited by the current user.
Works if a) the currently logged-in user is an admin,
or b) if it is the requested user.
""" """
# check if current user is admin # fail if user doesn't exist
if current_user.can(UserCapabilityType.admin): if user is None:
# fail if requested user doesn't exist
if (user := User.get(user_name)) is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
# check if current user is requested user # fail if user isn't editable by the current user
elif current_user.name == user_name: if not current_user.can_edit(user):
pass
# current user is neither admin nor the requested user
else:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
return user return user

View file

@ -5,7 +5,8 @@
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException, status
from ..db import Device, DeviceCreate, DeviceRead, User from ..db import Device, DeviceCreate, DeviceRead, User
from ._common import Responses, get_device_by_id_if_editable, get_user_by_name from ._common import (Responses, get_device_by_id_if_editable,
get_user_by_name_if_editable)
router = APIRouter(prefix="/device", tags=["device"]) router = APIRouter(prefix="/device", tags=["device"])
@ -24,7 +25,7 @@ router = APIRouter(prefix="/device", tags=["device"])
) )
async def add_device( async def add_device(
device: DeviceCreate, device: DeviceCreate,
user: User = Depends(get_user_by_name), user: User = Depends(get_user_by_name_if_editable),
) -> Device: ) -> Device:
""" """
POST ./: Create a new device in the database. POST ./: Create a new device in the database.

View file

@ -48,7 +48,7 @@ async def login(
headers={"WWW-Authenticate": "Bearer"}, headers={"WWW-Authenticate": "Bearer"},
) )
if not user.can(UserCapabilityType.login): if not user.can_login():
# user cannot login # user cannot login
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)