Compare commits

..

No commits in common. "21b85d7cfa50060da166c98b165cc9376ebe4454" and "5b623e885c38f338df7a7b3bc1bb402d866e5a80" have entirely different histories.

7 changed files with 19 additions and 33 deletions

View file

@ -5,16 +5,7 @@ Package `db`: ORM and schemas for database content.
from .connection import Connection from .connection import Connection
from .device import Device, DeviceBase, DeviceCreate from .device import Device, DeviceBase, DeviceCreate
from .user import User, UserBase, UserCreate, UserRead from .user import User, UserBase, UserCreate, UserRead
from .user_capability import UserCapabilityType from .user_capability import Capability
__all__ = [ __all__ = ["Capability", "Connection", "Device", "DeviceBase", "DeviceCreate",
"Connection", "User", "UserBase", "UserCreate", "UserRead"]
"Device",
"DeviceBase",
"DeviceCreate",
"User",
"UserBase",
"UserCreate",
"UserRead",
"UserCapabilityType",
]

View file

@ -13,7 +13,7 @@ from sqlmodel import Field, Relationship, SQLModel
from ..config import Config from ..config import Config
from .connection import Connection from .connection import Connection
from .device import Device from .device import Device
from .user_capability import UserCapabilityType, UserCapability from .user_capability import Capability, UserCapability
class UserBase(SQLModel): class UserBase(SQLModel):
@ -162,7 +162,7 @@ class User(UserBase, table=True):
db.delete(self) db.delete(self)
db.commit() db.commit()
def get_capabilities(self) -> set[UserCapabilityType]: def get_capabilities(self) -> set[Capability]:
""" """
Return the capabilities of this user. Return the capabilities of this user.
""" """
@ -172,14 +172,14 @@ class User(UserBase, table=True):
for capability in self.capabilities for capability in self.capabilities
) )
def can(self, capability: UserCapabilityType) -> bool: def can(self, capability: Capability) -> bool:
""" """
Check if this user has a capability. Check if this user has a capability.
""" """
return capability in self.get_capabilities() return capability in self.get_capabilities()
def set_capabilities(self, capabilities: set[UserCapabilityType]) -> None: def set_capabilities(self, capabilities: set[Capability]) -> None:
""" """
Change the capabilities of this user. Change the capabilities of this user.
""" """

View file

@ -11,7 +11,7 @@ if TYPE_CHECKING:
from .user import User from .user import User
class UserCapabilityType(Enum): class Capability(Enum):
""" """
Allowed values for capabilities Allowed values for capabilities
""" """
@ -33,12 +33,12 @@ class UserCapabilityBase(SQLModel):
capability_name: str = Field(primary_key=True) capability_name: str = Field(primary_key=True)
@property @property
def _(self) -> UserCapabilityType: def _(self) -> Capability:
""" """
Transform into a `Capability`. Transform into a `Capability`.
""" """
return UserCapabilityType(self.capability_name) return Capability(self.capability_name)
def __repr__(self) -> str: def __repr__(self) -> str:
return self.capability_name return self.capability_name

View file

@ -13,6 +13,4 @@ main_router = APIRouter()
main_router.include_router(admin.router) main_router.include_router(admin.router)
main_router.include_router(user.router) main_router.include_router(user.router)
__all__ = [ __all__ = ["main_router"]
"main_router",
]

View file

@ -7,7 +7,7 @@ from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer from fastapi.security import OAuth2PasswordBearer
from ..config import Config, Settings from ..config import Config, Settings
from ..db import UserCapabilityType, User from ..db import Capability, User
oauth2_scheme = OAuth2PasswordBearer( oauth2_scheme = OAuth2PasswordBearer(
tokenUrl=f"{Settings._.api_v1_prefix}/user/authenticate" tokenUrl=f"{Settings._.api_v1_prefix}/user/authenticate"
@ -93,7 +93,7 @@ async def get_current_user_if_admin(
""" """
# fail if not requested by an admin # fail if not requested by an admin
if not current_user.can(UserCapabilityType.admin): if not current_user.can(Capability.admin):
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
return current_user return current_user
@ -111,7 +111,7 @@ async def get_current_user_if_admin_or_self(
""" """
# fail if not requested by an admin or self # fail if not requested by an admin or self
if not (current_user.can(UserCapabilityType.admin) if not (current_user.can(Capability.admin)
or current_user.name == user_name): or current_user.name == user_name):
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)

View file

@ -7,7 +7,7 @@ from fastapi import APIRouter, Depends, HTTPException, status
from sqlmodel import select from sqlmodel import select
from ..config import Config from ..config import Config
from ..db import Connection, User, UserCapabilityType, UserCreate from ..db import Capability, Connection, User, UserCreate
from ._common import Responses, get_current_user_if_admin from ._common import Responses, get_current_user_if_admin
router = APIRouter(prefix="/admin", tags=["admin"]) router = APIRouter(prefix="/admin", tags=["admin"])
@ -63,10 +63,7 @@ async def create_initial_admin(
# create an administrative user # create an administrative user
new_user = User.create(**admin_user.dict()) new_user = User.create(**admin_user.dict())
new_user.set_capabilities([ new_user.set_capabilities([Capability.login, Capability.admin])
UserCapabilityType.login,
UserCapabilityType.admin,
])
new_user.update() new_user.update()

View file

@ -7,7 +7,7 @@ from fastapi.security import OAuth2PasswordRequestForm
from pydantic import BaseModel from pydantic import BaseModel
from ..config import Config from ..config import Config
from ..db import UserCapabilityType, User, UserCreate, UserRead from ..db import Capability, User, UserCreate, UserRead
from ._common import Responses, get_current_user, get_current_user_if_admin from ._common import Responses, get_current_user, get_current_user_if_admin
router = APIRouter(prefix="/user", tags=["user"]) router = APIRouter(prefix="/user", tags=["user"])
@ -134,7 +134,7 @@ async def remove_user(
) )
async def extend_capabilities( async def extend_capabilities(
user_name: str, user_name: str,
capabilities: list[UserCapabilityType], capabilities: list[Capability],
_: User = Depends(get_current_user_if_admin), _: User = Depends(get_current_user_if_admin),
): ):
""" """
@ -162,7 +162,7 @@ async def extend_capabilities(
) )
async def remove_capabilities( async def remove_capabilities(
user_name: str, user_name: str,
capabilities: list[UserCapabilityType], capabilities: list[Capability],
_: User = Depends(get_current_user_if_admin), _: User = Depends(get_current_user_if_admin),
): ):
""" """