diff --git a/api/kiwi_vpn_api/db/models.py b/api/kiwi_vpn_api/db/models.py index 6376a2e..f7a107d 100644 --- a/api/kiwi_vpn_api/db/models.py +++ b/api/kiwi_vpn_api/db/models.py @@ -1,9 +1,11 @@ +from __future__ import annotations + import datetime from sqlalchemy import (Boolean, Column, DateTime, ForeignKey, Integer, String, UniqueConstraint) from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import relationship +from sqlalchemy.orm import Session, relationship ORMBaseModel = declarative_base() @@ -17,6 +19,13 @@ class User(ORMBaseModel): capabilities = relationship("UserCapability", lazy="joined") certificates = relationship("Certificate", lazy="joined") + @classmethod + def load(cls, db: Session, name: str) -> User | None: + return (db + .query(User) + .filter(User.name == name) + .first()) + class UserCapability(ORMBaseModel): __tablename__ = "user_capabilities" @@ -29,9 +38,6 @@ class UserCapability(ORMBaseModel): ) capability = Column(String, primary_key=True) - def __str__(self) -> str: - return self.capability - class DistinguishedName(ORMBaseModel): __tablename__ = "distinguished_names" diff --git a/api/kiwi_vpn_api/db/schemas.py b/api/kiwi_vpn_api/db/schemas.py index 2663e72..0e52185 100644 --- a/api/kiwi_vpn_api/db/schemas.py +++ b/api/kiwi_vpn_api/db/schemas.py @@ -30,8 +30,19 @@ class Certificate(CertificateBase): class UserCapability(Enum): admin = "admin" - def __str__(self) -> str: - return self._value_ + @classmethod + def from_value(cls, value) -> UserCapability: + if isinstance(value, cls): + # use simple value + return value + + elif isinstance(value, models.UserCapability): + # create from db + return cls(value.capability) + + else: + # create from string representation + return cls(str(value)) class UserBase(BaseModel): @@ -53,10 +64,10 @@ class User(UserBase): @classmethod def unify_capabilities( cls, - value: list[models.UserCapability | str] + value: list[models.UserCapability | UserCapability | str] ) -> list[UserCapability]: return [ - UserCapability(str(capability)) + UserCapability.from_value(capability) for capability in value ] @@ -66,15 +77,10 @@ class User(UserBase): db: Session, name: str, ) -> User | None: - user = (db - .query(models.User) - .filter(models.User.name == name) - .first()) - - if user is None: + if (db_user := models.User.load(db, name)) is None: return None - return cls.from_orm(user) + return cls.from_orm(db_user) @classmethod def login( @@ -84,21 +90,16 @@ class User(UserBase): password: str, crypt_context: CryptContext, ) -> User | None: - user = (db - .query(models.User) - .filter(models.User.name == name) - .first()) - - if user is None: + if (db_user := models.User.load(db, name)) is None: # inexistent user, fake doing password verification crypt_context.dummy_verify() return None - if not crypt_context.verify(password, user.password): + if not crypt_context.verify(password, db_user.password): # password hash mismatch return None - return cls.from_orm(user) + return cls.from_orm(db_user) @classmethod def create( @@ -128,7 +129,9 @@ class User(UserBase): db: Session, capabilities: list[UserCapability], ) -> bool: - # TODO + for capability in capabilities: + # TODO + pass return True