Capability -> UserCapabilityType

This commit is contained in:
Jörn-Michael Miehe 2022-03-28 21:41:49 +00:00
parent 5b623e885c
commit 499c97a28a
6 changed files with 30 additions and 18 deletions

View file

@ -5,7 +5,16 @@ Package `db`: ORM and schemas for database content.
from .connection import Connection from .connection import Connection
from .device import Device, DeviceBase, DeviceCreate from .device import Device, DeviceBase, DeviceCreate
from .user import User, UserBase, UserCreate, UserRead from .user import User, UserBase, UserCreate, UserRead
from .user_capability import Capability from .user_capability import UserCapabilityType
__all__ = ["Capability", "Connection", "Device", "DeviceBase", "DeviceCreate", __all__ = [
"User", "UserBase", "UserCreate", "UserRead"] "Connection",
"Device",
"DeviceBase",
"DeviceCreate",
"User",
"UserBase",
"UserCreate",
"UserRead",
"UserCapabilityType",
]

View file

@ -13,7 +13,7 @@ from sqlmodel import Field, Relationship, SQLModel
from ..config import Config from ..config import Config
from .connection import Connection from .connection import Connection
from .device import Device from .device import Device
from .user_capability import Capability, UserCapability from .user_capability import UserCapabilityType, UserCapability
class UserBase(SQLModel): class UserBase(SQLModel):
@ -162,7 +162,7 @@ class User(UserBase, table=True):
db.delete(self) db.delete(self)
db.commit() db.commit()
def get_capabilities(self) -> set[Capability]: def get_capabilities(self) -> set[UserCapabilityType]:
""" """
Return the capabilities of this user. Return the capabilities of this user.
""" """
@ -172,14 +172,14 @@ class User(UserBase, table=True):
for capability in self.capabilities for capability in self.capabilities
) )
def can(self, capability: Capability) -> bool: def can(self, capability: UserCapabilityType) -> bool:
""" """
Check if this user has a capability. Check if this user has a capability.
""" """
return capability in self.get_capabilities() return capability in self.get_capabilities()
def set_capabilities(self, capabilities: set[Capability]) -> None: def set_capabilities(self, capabilities: set[UserCapabilityType]) -> None:
""" """
Change the capabilities of this user. Change the capabilities of this user.
""" """

View file

@ -11,7 +11,7 @@ if TYPE_CHECKING:
from .user import User from .user import User
class Capability(Enum): class UserCapabilityType(Enum):
""" """
Allowed values for capabilities Allowed values for capabilities
""" """
@ -33,12 +33,12 @@ class UserCapabilityBase(SQLModel):
capability_name: str = Field(primary_key=True) capability_name: str = Field(primary_key=True)
@property @property
def _(self) -> Capability: def _(self) -> UserCapabilityType:
""" """
Transform into a `Capability`. Transform into a `Capability`.
""" """
return Capability(self.capability_name) return UserCapabilityType(self.capability_name)
def __repr__(self) -> str: def __repr__(self) -> str:
return self.capability_name return self.capability_name

View file

@ -7,7 +7,7 @@ from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer from fastapi.security import OAuth2PasswordBearer
from ..config import Config, Settings from ..config import Config, Settings
from ..db import Capability, User from ..db import UserCapabilityType, User
oauth2_scheme = OAuth2PasswordBearer( oauth2_scheme = OAuth2PasswordBearer(
tokenUrl=f"{Settings._.api_v1_prefix}/user/authenticate" tokenUrl=f"{Settings._.api_v1_prefix}/user/authenticate"
@ -93,7 +93,7 @@ async def get_current_user_if_admin(
""" """
# fail if not requested by an admin # fail if not requested by an admin
if not current_user.can(Capability.admin): if not current_user.can(UserCapabilityType.admin):
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
return current_user return current_user
@ -111,7 +111,7 @@ async def get_current_user_if_admin_or_self(
""" """
# fail if not requested by an admin or self # fail if not requested by an admin or self
if not (current_user.can(Capability.admin) if not (current_user.can(UserCapabilityType.admin)
or current_user.name == user_name): or current_user.name == user_name):
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)

View file

@ -7,7 +7,7 @@ from fastapi import APIRouter, Depends, HTTPException, status
from sqlmodel import select from sqlmodel import select
from ..config import Config from ..config import Config
from ..db import Capability, Connection, User, UserCreate from ..db import Connection, User, UserCapabilityType, UserCreate
from ._common import Responses, get_current_user_if_admin from ._common import Responses, get_current_user_if_admin
router = APIRouter(prefix="/admin", tags=["admin"]) router = APIRouter(prefix="/admin", tags=["admin"])
@ -63,7 +63,10 @@ async def create_initial_admin(
# create an administrative user # create an administrative user
new_user = User.create(**admin_user.dict()) new_user = User.create(**admin_user.dict())
new_user.set_capabilities([Capability.login, Capability.admin]) new_user.set_capabilities([
UserCapabilityType.login,
UserCapabilityType.admin,
])
new_user.update() new_user.update()

View file

@ -7,7 +7,7 @@ from fastapi.security import OAuth2PasswordRequestForm
from pydantic import BaseModel from pydantic import BaseModel
from ..config import Config from ..config import Config
from ..db import Capability, User, UserCreate, UserRead from ..db import UserCapabilityType, User, UserCreate, UserRead
from ._common import Responses, get_current_user, get_current_user_if_admin from ._common import Responses, get_current_user, get_current_user_if_admin
router = APIRouter(prefix="/user", tags=["user"]) router = APIRouter(prefix="/user", tags=["user"])
@ -134,7 +134,7 @@ async def remove_user(
) )
async def extend_capabilities( async def extend_capabilities(
user_name: str, user_name: str,
capabilities: list[Capability], capabilities: list[UserCapabilityType],
_: User = Depends(get_current_user_if_admin), _: User = Depends(get_current_user_if_admin),
): ):
""" """
@ -162,7 +162,7 @@ async def extend_capabilities(
) )
async def remove_capabilities( async def remove_capabilities(
user_name: str, user_name: str,
capabilities: list[Capability], capabilities: list[UserCapabilityType],
_: User = Depends(get_current_user_if_admin), _: User = Depends(get_current_user_if_admin),
): ):
""" """