From 9b5a98e0c07d755fec26a82af8f1bf6a8c31af5c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=B6rn-Michael=20Miehe?= <40151420+ldericher@users.noreply.github.com> Date: Wed, 30 Mar 2022 01:51:43 +0000 Subject: [PATCH] rework User methods --- api/kiwi_vpn_api/db/tag.py | 14 +++++++- api/kiwi_vpn_api/db/user.py | 50 ++++++++++++++++++++--------- api/kiwi_vpn_api/routers/_common.py | 2 +- api/kiwi_vpn_api/routers/admin.py | 2 +- api/kiwi_vpn_api/routers/device.py | 4 +-- api/kiwi_vpn_api/routers/user.py | 10 +++--- 6 files changed, 56 insertions(+), 26 deletions(-) diff --git a/api/kiwi_vpn_api/db/tag.py b/api/kiwi_vpn_api/db/tag.py index b29622b..403f124 100644 --- a/api/kiwi_vpn_api/db/tag.py +++ b/api/kiwi_vpn_api/db/tag.py @@ -2,6 +2,8 @@ Python representation of `tag` table. """ +from __future__ import annotations + from enum import Enum from typing import TYPE_CHECKING @@ -24,6 +26,16 @@ class TagValue(Enum): def __repr__(self) -> str: return self.value + def _(self, user: User) -> Tag: + """ + Transform into a `Tag`. + """ + + return Tag( + user=user, + tag_value=self.value, + ) + class TagBase(SQLModel): """ @@ -51,6 +63,6 @@ class Tag(TagBase, table=True): user_name: str = Field(primary_key=True, foreign_key="user.name") - user: "User" = Relationship( + user: User = Relationship( back_populates="tags", ) diff --git a/api/kiwi_vpn_api/db/user.py b/api/kiwi_vpn_api/db/user.py index 7e0164d..ad1d51f 100644 --- a/api/kiwi_vpn_api/db/user.py +++ b/api/kiwi_vpn_api/db/user.py @@ -4,7 +4,7 @@ Python representation of `user` table. from __future__ import annotations -from typing import Any, Sequence +from typing import Any, Iterable, Sequence from pydantic import root_validator from sqlalchemy.exc import IntegrityError @@ -143,7 +143,7 @@ class User(UserBase, table=True): # password hash mismatch return None - if TagValue.login in user.get_tags(): + if user.has_tag(TagValue.login): # no login permission return None @@ -168,32 +168,50 @@ class User(UserBase, table=True): db.delete(self) db.commit() - def get_tags(self) -> set[TagValue]: + def _get_tags(self) -> Iterable[TagValue]: """ Return the tags of this user. """ - return set( + return ( tag._ for tag in self.tags ) - def set_tags( + def has_tag(self, tag: TagValue) -> bool: + """ + Check if this user has a tag. + """ + + return tag in self._get_tags + + def add_tags( self, tags: Sequence[TagValue], ) -> None: """ - Change the tags of this user. + Add tags to this user. """ self.tags = [ - Tag( - user_name=self.name, - tag_value=tag.value, - ) for tag in tags + tag._(self) + for tag in (set(self._get_tags()) | set(tags)) ] - def may_edit( + def remove_tags( + self, + tags: Sequence[TagValue], + ) -> None: + """ + remove tags from this user. + """ + + self.tags = [ + tag._(self) + for tag in (set(self._get_tags()) - set(tags)) + ] + + def can_edit( self, target: User | Device, ) -> bool: @@ -202,7 +220,7 @@ class User(UserBase, table=True): """ # admin can "edit" everything - if TagValue.admin in self.get_tags(): + if self.has_tag(TagValue.admin): return True # user can "edit" itself @@ -212,7 +230,7 @@ class User(UserBase, table=True): # user can edit its owned devices return target.owner == self - def may_admin( + def can_admin( self, target: User | Device, ) -> bool: @@ -221,7 +239,7 @@ class User(UserBase, table=True): """ # only admin can "admin" anything - if TagValue.admin not in self.get_tags(): + if not self.has_tag(TagValue.admin): return False # admin canot "admin itself"! @@ -231,7 +249,7 @@ class User(UserBase, table=True): # admin can "admin" everything else return True - def may_create( + def can_create( self, target: type, owner: User | None = None, @@ -245,7 +263,7 @@ class User(UserBase, table=True): return False # admin can "create" everything - if TagValue.admin in self.get_tags(): + if self.has_tag(TagValue.admin): return True # user can only create devices for itself diff --git a/api/kiwi_vpn_api/routers/_common.py b/api/kiwi_vpn_api/routers/_common.py index 0732135..1d7bff7 100644 --- a/api/kiwi_vpn_api/routers/_common.py +++ b/api/kiwi_vpn_api/routers/_common.py @@ -88,7 +88,7 @@ async def get_current_user_if_admin( Fail if the currently logged-in user is not an admin. """ - if TagValue.admin not in current_user.get_tags(): + if current_user.has_tag(TagValue.admin): raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) return current_user diff --git a/api/kiwi_vpn_api/routers/admin.py b/api/kiwi_vpn_api/routers/admin.py index e5e3941..7727bad 100644 --- a/api/kiwi_vpn_api/routers/admin.py +++ b/api/kiwi_vpn_api/routers/admin.py @@ -64,7 +64,7 @@ async def create_initial_admin( # create an administrative user new_user = User.create(user=admin_user) - new_user.set_tags([TagValue.admin]) + new_user.add_tags([TagValue.admin]) new_user.update() diff --git a/api/kiwi_vpn_api/routers/device.py b/api/kiwi_vpn_api/routers/device.py index 957e31d..6611cbf 100644 --- a/api/kiwi_vpn_api/routers/device.py +++ b/api/kiwi_vpn_api/routers/device.py @@ -33,7 +33,7 @@ async def add_device( """ # check permission - if not current_user.may_create(Device, owner): + if not current_user.can_create(Device, owner): raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) # create the new device @@ -70,7 +70,7 @@ async def remove_device( """ # check permission - if not current_user.may_edit(device): + if not current_user.can_edit(device): raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) # delete device diff --git a/api/kiwi_vpn_api/routers/user.py b/api/kiwi_vpn_api/routers/user.py index 2572d40..0529b16 100644 --- a/api/kiwi_vpn_api/routers/user.py +++ b/api/kiwi_vpn_api/routers/user.py @@ -8,8 +8,8 @@ from pydantic import BaseModel from ..config import Config from ..db import TagValue, User, UserCreate, UserRead -from ._common import (Responses, get_current_user_if_admin, - get_current_user, get_user_by_name) +from ._common import (Responses, get_current_user, get_current_user_if_admin, + get_user_by_name) router = APIRouter(prefix="/user", tags=["user"]) @@ -90,7 +90,7 @@ async def add_user( if new_user is None: raise HTTPException(status_code=status.HTTP_409_CONFLICT) - new_user.set_tags([TagValue.login]) + new_user.add_tags([TagValue.login]) new_user.update() # return the created user on success @@ -143,7 +143,7 @@ async def extend_tags( POST ./{user_name}/tags: Add tags to a user. """ - user.set_tags(user.get_tags() | set(tags)) + user.add_tags(tags) user.update() @@ -166,6 +166,6 @@ async def remove_tags( DELETE ./{user_name}/tags: Remove tags from a user. """ - user.set_tags(user.get_tags() - set(tags)) + user.remove_tags(tags) user.update()