less messy cap unification

This commit is contained in:
Jörn-Michael Miehe 2022-03-19 23:56:11 +00:00
parent 1c1ea694d1
commit 54c5e7ae8a
2 changed files with 33 additions and 24 deletions

View file

@ -1,9 +1,11 @@
from __future__ import annotations
import datetime import datetime
from sqlalchemy import (Boolean, Column, DateTime, ForeignKey, Integer, String, from sqlalchemy import (Boolean, Column, DateTime, ForeignKey, Integer, String,
UniqueConstraint) UniqueConstraint)
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship from sqlalchemy.orm import Session, relationship
ORMBaseModel = declarative_base() ORMBaseModel = declarative_base()
@ -17,6 +19,13 @@ class User(ORMBaseModel):
capabilities = relationship("UserCapability", lazy="joined") capabilities = relationship("UserCapability", lazy="joined")
certificates = relationship("Certificate", lazy="joined") certificates = relationship("Certificate", lazy="joined")
@classmethod
def load(cls, db: Session, name: str) -> User | None:
return (db
.query(User)
.filter(User.name == name)
.first())
class UserCapability(ORMBaseModel): class UserCapability(ORMBaseModel):
__tablename__ = "user_capabilities" __tablename__ = "user_capabilities"
@ -29,9 +38,6 @@ class UserCapability(ORMBaseModel):
) )
capability = Column(String, primary_key=True) capability = Column(String, primary_key=True)
def __str__(self) -> str:
return self.capability
class DistinguishedName(ORMBaseModel): class DistinguishedName(ORMBaseModel):
__tablename__ = "distinguished_names" __tablename__ = "distinguished_names"

View file

@ -30,8 +30,19 @@ class Certificate(CertificateBase):
class UserCapability(Enum): class UserCapability(Enum):
admin = "admin" admin = "admin"
def __str__(self) -> str: @classmethod
return self._value_ def from_value(cls, value) -> UserCapability:
if isinstance(value, cls):
# use simple value
return value
elif isinstance(value, models.UserCapability):
# create from db
return cls(value.capability)
else:
# create from string representation
return cls(str(value))
class UserBase(BaseModel): class UserBase(BaseModel):
@ -53,10 +64,10 @@ class User(UserBase):
@classmethod @classmethod
def unify_capabilities( def unify_capabilities(
cls, cls,
value: list[models.UserCapability | str] value: list[models.UserCapability | UserCapability | str]
) -> list[UserCapability]: ) -> list[UserCapability]:
return [ return [
UserCapability(str(capability)) UserCapability.from_value(capability)
for capability in value for capability in value
] ]
@ -66,15 +77,10 @@ class User(UserBase):
db: Session, db: Session,
name: str, name: str,
) -> User | None: ) -> User | None:
user = (db if (db_user := models.User.load(db, name)) is None:
.query(models.User)
.filter(models.User.name == name)
.first())
if user is None:
return None return None
return cls.from_orm(user) return cls.from_orm(db_user)
@classmethod @classmethod
def login( def login(
@ -84,21 +90,16 @@ class User(UserBase):
password: str, password: str,
crypt_context: CryptContext, crypt_context: CryptContext,
) -> User | None: ) -> User | None:
user = (db if (db_user := models.User.load(db, name)) is None:
.query(models.User)
.filter(models.User.name == name)
.first())
if user is None:
# inexistent user, fake doing password verification # inexistent user, fake doing password verification
crypt_context.dummy_verify() crypt_context.dummy_verify()
return None return None
if not crypt_context.verify(password, user.password): if not crypt_context.verify(password, db_user.password):
# password hash mismatch # password hash mismatch
return None return None
return cls.from_orm(user) return cls.from_orm(db_user)
@classmethod @classmethod
def create( def create(
@ -128,7 +129,9 @@ class User(UserBase):
db: Session, db: Session,
capabilities: list[UserCapability], capabilities: list[UserCapability],
) -> bool: ) -> bool:
for capability in capabilities:
# TODO # TODO
pass
return True return True