diff --git a/api/kiwi_vpn_api/db/models.py b/api/kiwi_vpn_api/db/models.py index 32343c1..6376a2e 100644 --- a/api/kiwi_vpn_api/db/models.py +++ b/api/kiwi_vpn_api/db/models.py @@ -29,6 +29,9 @@ class UserCapability(ORMBaseModel): ) capability = Column(String, primary_key=True) + def __str__(self) -> str: + return self.capability + class DistinguishedName(ORMBaseModel): __tablename__ = "distinguished_names" diff --git a/api/kiwi_vpn_api/db/schemas.py b/api/kiwi_vpn_api/db/schemas.py index e6eed79..6c3dbed 100644 --- a/api/kiwi_vpn_api/db/schemas.py +++ b/api/kiwi_vpn_api/db/schemas.py @@ -1,6 +1,7 @@ from __future__ import annotations from datetime import datetime +from enum import Enum from passlib.context import CryptContext from pydantic import BaseModel, validator @@ -26,22 +27,12 @@ class Certificate(CertificateBase): orm_mode = True +class UserCapability(Enum): + admin = "admin" + + class UserBase(BaseModel): name: str - capabilities: list[str] - - @validator("capabilities", pre=True) - @classmethod - def unify_capabilities( - cls, - value: list[models.UserCapability | str] - ) -> list[str]: - return [ - capability.capability - if isinstance(capability, models.UserCapability) - else str(capability) - for capability in value - ] class UserCreate(UserBase): @@ -50,10 +41,22 @@ class UserCreate(UserBase): class User(UserBase): certificates: list[Certificate] + capabilities: list[UserCapability] class Config: orm_mode = True + @validator("capabilities", pre=True) + @classmethod + def unify_capabilities( + cls, + value: list[models.UserCapability | str] + ) -> list[UserCapability]: + return [ + UserCapability(str(capability)) + for capability in value + ] + @classmethod def from_db( cls, @@ -105,10 +108,7 @@ class User(UserBase): user = models.User( name=user.name, password=crypt_context.hash(user.password), - capabilities=[ - models.UserCapability(capability=capability) - for capability in user.capabilities - ] + capabilities=[models.UserCapability(capability="admin")], ) db.add(user) diff --git a/api/kiwi_vpn_api/routers/admin.py b/api/kiwi_vpn_api/routers/admin.py index 87c14b5..0c80e3d 100644 --- a/api/kiwi_vpn_api/routers/admin.py +++ b/api/kiwi_vpn_api/routers/admin.py @@ -26,7 +26,7 @@ async def install( Connection.connect(await config.db.db_engine) async for db in Connection.get(): - user.capabilities.append("admin") + # user.capabilities.append("admin") schemas.User.create( db=db, @@ -52,7 +52,8 @@ async def set_config( if current_config is None: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) - if current_user is None or "admin" not in current_user.capabilities: + if (current_user is None + or schemas.UserCapability.admin not in current_user.capabilities): raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) await new_config.save() diff --git a/api/kiwi_vpn_api/routers/user.py b/api/kiwi_vpn_api/routers/user.py index 0f4552c..ed38653 100644 --- a/api/kiwi_vpn_api/routers/user.py +++ b/api/kiwi_vpn_api/routers/user.py @@ -69,7 +69,8 @@ async def add_user( if current_config is None: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) - if current_user is None or "admin" not in current_user.capabilities: + if (current_user is None + or schemas.UserCapability.admin not in current_user.capabilities): raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) new_user = schemas.User.create(