diff --git a/api/kiwi_vpn_api/db/models.py b/api/kiwi_vpn_api/db/models.py index 1097400..fc41d6f 100644 --- a/api/kiwi_vpn_api/db/models.py +++ b/api/kiwi_vpn_api/db/models.py @@ -7,7 +7,7 @@ from __future__ import annotations from sqlalchemy import (Column, DateTime, ForeignKey, Integer, String, UniqueConstraint) from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import relationship +from sqlalchemy.orm import Session, relationship ORMBaseModel = declarative_base() @@ -29,6 +29,7 @@ class User(ORMBaseModel): name = Column(String, primary_key=True, index=True) password = Column(String, nullable=False) + email = Column(String) country = Column(String(2)) state = Column(String) @@ -36,8 +37,6 @@ class User(ORMBaseModel): organization = Column(String) organizational_unit = Column(String) - email = Column(String) - capabilities: list[UserCapability] = relationship( "UserCapability", lazy="joined", cascade="all, delete-orphan" ) @@ -45,6 +44,23 @@ class User(ORMBaseModel): "Device", lazy="select", back_populates="owner" ) + @classmethod + def from_db( + cls, + db: Session, + name: str, + ) -> User | None: + """ + Load user from database by name. + """ + + return ( + db + .query(cls) + .filter(cls.name == name) + .first() + ) + class Device(ORMBaseModel): __tablename__ = "devices" diff --git a/api/kiwi_vpn_api/db/schemata/user.py b/api/kiwi_vpn_api/db/schemata/user.py index 5789f05..b31e31e 100644 --- a/api/kiwi_vpn_api/db/schemata/user.py +++ b/api/kiwi_vpn_api/db/schemata/user.py @@ -8,7 +8,7 @@ from typing import Any from passlib.context import CryptContext from pydantic import BaseModel, Field, validator -from sqlalchemy.exc import IntegrityError, InvalidRequestError +from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session from .. import models @@ -18,23 +18,22 @@ from .user_capability import UserCapability class UserBase(BaseModel): name: str - - country: str - state: str - city: str - organization: str - organizational_unit: str - email: str + capabilities: list[UserCapability] = [] + + country: str | None = Field(default=None, repr=False) + state: str | None = Field(default=None, repr=False) + city: str | None = Field(default=None, repr=False) + organization: str | None = Field(default=None, repr=False) + organizational_unit: str | None = Field(default=None, repr=False) + class UserCreate(UserBase): password: str class User(UserBase): - capabilities: list[UserCapability] = [] - devices: list[Device] = Field( default=[], repr=False ) @@ -64,15 +63,11 @@ class User(UserBase): Load user from database by name. """ - try: - db_user = models.User(name=name) - db.refresh(db_user) - - return cls.from_orm(db_user) - - except InvalidRequestError: + if (db_user := models.User.from_db(db, name)) is None: return None + return cls.from_orm(db_user) + @classmethod def create( cls, @@ -88,6 +83,7 @@ class User(UserBase): db_user = models.User( name=user.name, password=crypt_context.hash(user.password), + email=user.email, capabilities=[ capability.model for capability in user.capabilities @@ -117,10 +113,7 @@ class User(UserBase): Authenticate with name/password against users in database. """ - db_user = models.User(name=self.name) - db.refresh(db_user) - - if db_user is None: + if (db_user := models.User.from_db(db, self.name)) is None: # nonexistent user, fake doing password verification crypt_context.dummy_verify() return False @@ -141,8 +134,8 @@ class User(UserBase): Update this user in the database. """ - db_user = models.User(name=self.name) - db.refresh(db_user) + if (db_user := models.User.from_db(db, self.name)) is None: + return None for capability in db_user.capabilities: db.delete(capability) @@ -162,10 +155,7 @@ class User(UserBase): Delete this user from the database. """ - db_user = models.User(name=self.name) - db.refresh(db_user) - - if db_user is None: + if (db_user := models.User.from_db(db, self.name)) is None: # nonexistent user return False