add models.User.from_db()

This commit is contained in:
Jörn-Michael Miehe 2022-03-26 01:49:47 +00:00
parent 557bceed1f
commit 94fbab278c
2 changed files with 36 additions and 30 deletions

View file

@ -7,7 +7,7 @@ from __future__ import annotations
from sqlalchemy import (Column, DateTime, ForeignKey, Integer, String, from sqlalchemy import (Column, DateTime, ForeignKey, Integer, String,
UniqueConstraint) UniqueConstraint)
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship from sqlalchemy.orm import Session, relationship
ORMBaseModel = declarative_base() ORMBaseModel = declarative_base()
@ -29,6 +29,7 @@ class User(ORMBaseModel):
name = Column(String, primary_key=True, index=True) name = Column(String, primary_key=True, index=True)
password = Column(String, nullable=False) password = Column(String, nullable=False)
email = Column(String)
country = Column(String(2)) country = Column(String(2))
state = Column(String) state = Column(String)
@ -36,8 +37,6 @@ class User(ORMBaseModel):
organization = Column(String) organization = Column(String)
organizational_unit = Column(String) organizational_unit = Column(String)
email = Column(String)
capabilities: list[UserCapability] = relationship( capabilities: list[UserCapability] = relationship(
"UserCapability", lazy="joined", cascade="all, delete-orphan" "UserCapability", lazy="joined", cascade="all, delete-orphan"
) )
@ -45,6 +44,23 @@ class User(ORMBaseModel):
"Device", lazy="select", back_populates="owner" "Device", lazy="select", back_populates="owner"
) )
@classmethod
def from_db(
cls,
db: Session,
name: str,
) -> User | None:
"""
Load user from database by name.
"""
return (
db
.query(cls)
.filter(cls.name == name)
.first()
)
class Device(ORMBaseModel): class Device(ORMBaseModel):
__tablename__ = "devices" __tablename__ = "devices"

View file

@ -8,7 +8,7 @@ from typing import Any
from passlib.context import CryptContext from passlib.context import CryptContext
from pydantic import BaseModel, Field, validator from pydantic import BaseModel, Field, validator
from sqlalchemy.exc import IntegrityError, InvalidRequestError from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from .. import models from .. import models
@ -18,23 +18,22 @@ from .user_capability import UserCapability
class UserBase(BaseModel): class UserBase(BaseModel):
name: str name: str
country: str
state: str
city: str
organization: str
organizational_unit: str
email: str email: str
capabilities: list[UserCapability] = []
country: str | None = Field(default=None, repr=False)
state: str | None = Field(default=None, repr=False)
city: str | None = Field(default=None, repr=False)
organization: str | None = Field(default=None, repr=False)
organizational_unit: str | None = Field(default=None, repr=False)
class UserCreate(UserBase): class UserCreate(UserBase):
password: str password: str
class User(UserBase): class User(UserBase):
capabilities: list[UserCapability] = []
devices: list[Device] = Field( devices: list[Device] = Field(
default=[], repr=False default=[], repr=False
) )
@ -64,15 +63,11 @@ class User(UserBase):
Load user from database by name. Load user from database by name.
""" """
try: if (db_user := models.User.from_db(db, name)) is None:
db_user = models.User(name=name)
db.refresh(db_user)
return cls.from_orm(db_user)
except InvalidRequestError:
return None return None
return cls.from_orm(db_user)
@classmethod @classmethod
def create( def create(
cls, cls,
@ -88,6 +83,7 @@ class User(UserBase):
db_user = models.User( db_user = models.User(
name=user.name, name=user.name,
password=crypt_context.hash(user.password), password=crypt_context.hash(user.password),
email=user.email,
capabilities=[ capabilities=[
capability.model capability.model
for capability in user.capabilities for capability in user.capabilities
@ -117,10 +113,7 @@ class User(UserBase):
Authenticate with name/password against users in database. Authenticate with name/password against users in database.
""" """
db_user = models.User(name=self.name) if (db_user := models.User.from_db(db, self.name)) is None:
db.refresh(db_user)
if db_user is None:
# nonexistent user, fake doing password verification # nonexistent user, fake doing password verification
crypt_context.dummy_verify() crypt_context.dummy_verify()
return False return False
@ -141,8 +134,8 @@ class User(UserBase):
Update this user in the database. Update this user in the database.
""" """
db_user = models.User(name=self.name) if (db_user := models.User.from_db(db, self.name)) is None:
db.refresh(db_user) return None
for capability in db_user.capabilities: for capability in db_user.capabilities:
db.delete(capability) db.delete(capability)
@ -162,10 +155,7 @@ class User(UserBase):
Delete this user from the database. Delete this user from the database.
""" """
db_user = models.User(name=self.name) if (db_user := models.User.from_db(db, self.name)) is None:
db.refresh(db_user)
if db_user is None:
# nonexistent user # nonexistent user
return False return False