Compare commits

..

No commits in common. "00bdf88b6e234fad2fac38896843013b0e008d06" and "1c1ea694d1e08e091787d700196a1accafc28890" have entirely different histories.

5 changed files with 30 additions and 63 deletions

View file

@ -6,19 +6,6 @@ from sqlalchemy.orm import Session, sessionmaker
from .models import ORMBaseModel
class SessionManager:
__session: Session
def __init__(self, session: Session) -> None:
self.__session = session
def __enter__(self) -> Session:
return self.__session
def __exit__(self, *args) -> None:
self.__session.close()
class Connection:
engine: Engine | None = None
session_local: sessionmaker | None = None
@ -31,13 +18,6 @@ class Connection:
)
ORMBaseModel.metadata.create_all(bind=engine)
@classmethod
def use(cls) -> SessionManager | None:
if cls.session_local is None:
return None
return SessionManager(cls.session_local())
@classmethod
async def get(cls) -> Generator[Session | None, None, None]:
if cls.session_local is None:

View file

@ -1,11 +1,9 @@
from __future__ import annotations
import datetime
from sqlalchemy import (Boolean, Column, DateTime, ForeignKey, Integer, String,
UniqueConstraint)
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session, relationship
from sqlalchemy.orm import relationship
ORMBaseModel = declarative_base()
@ -19,13 +17,6 @@ class User(ORMBaseModel):
capabilities = relationship("UserCapability", lazy="joined")
certificates = relationship("Certificate", lazy="joined")
@classmethod
def load(cls, db: Session, name: str) -> User | None:
return (db
.query(User)
.filter(User.name == name)
.first())
class UserCapability(ORMBaseModel):
__tablename__ = "user_capabilities"
@ -38,6 +29,9 @@ class UserCapability(ORMBaseModel):
)
capability = Column(String, primary_key=True)
def __str__(self) -> str:
return self.capability
class DistinguishedName(ORMBaseModel):
__tablename__ = "distinguished_names"

View file

@ -30,19 +30,8 @@ class Certificate(CertificateBase):
class UserCapability(Enum):
admin = "admin"
@classmethod
def from_value(cls, value) -> UserCapability:
if isinstance(value, cls):
# use simple value
return value
elif isinstance(value, models.UserCapability):
# create from db
return cls(value.capability)
else:
# create from string representation
return cls(str(value))
def __str__(self) -> str:
return self._value_
class UserBase(BaseModel):
@ -64,10 +53,10 @@ class User(UserBase):
@classmethod
def unify_capabilities(
cls,
value: list[models.UserCapability | UserCapability | str]
value: list[models.UserCapability | str]
) -> list[UserCapability]:
return [
UserCapability.from_value(capability)
UserCapability(str(capability))
for capability in value
]
@ -77,10 +66,15 @@ class User(UserBase):
db: Session,
name: str,
) -> User | None:
if (db_user := models.User.load(db, name)) is None:
user = (db
.query(models.User)
.filter(models.User.name == name)
.first())
if user is None:
return None
return cls.from_orm(db_user)
return cls.from_orm(user)
@classmethod
def login(
@ -90,16 +84,21 @@ class User(UserBase):
password: str,
crypt_context: CryptContext,
) -> User | None:
if (db_user := models.User.load(db, name)) is None:
user = (db
.query(models.User)
.filter(models.User.name == name)
.first())
if user is None:
# inexistent user, fake doing password verification
crypt_context.dummy_verify()
return None
if not crypt_context.verify(password, db_user.password):
if not crypt_context.verify(password, user.password):
# password hash mismatch
return None
return cls.from_orm(db_user)
return cls.from_orm(user)
@classmethod
def create(
@ -128,16 +127,10 @@ class User(UserBase):
self,
db: Session,
capabilities: list[UserCapability],
) -> None:
for capability in capabilities:
if capability not in self.capabilities:
cap = models.UserCapability(
user_name=self.name,
capability=capability.value,
)
db.add(cap)
) -> bool:
# TODO
db.commit()
return True
class DistinguishedNameBase(BaseModel):

View file

@ -39,7 +39,7 @@ async def on_startup() -> None:
Connection.connect(await current_config.db.db_engine)
# some testing
with Connection.use() as db:
async for db in Connection.get():
print(schemas.User.from_db(db, "admin"))
print(schemas.User.from_db(db, "nonexistent"))

View file

@ -16,7 +16,7 @@ router = APIRouter(prefix="/admin")
)
async def install(
config: Config,
admin_user: schemas.UserCreate,
user: schemas.UserCreate,
current_config: Config | None = Depends(Config.load),
):
if current_config is not None:
@ -25,10 +25,10 @@ async def install(
await config.save()
Connection.connect(await config.db.db_engine)
with Connection.use() as db:
async for db in Connection.get():
admin_user = schemas.User.create(
db=db,
user=admin_user,
user=user,
crypt_context=await config.crypto.crypt_context,
)