Compare commits

..

4 commits

5 changed files with 63 additions and 30 deletions

View file

@ -6,6 +6,19 @@ from sqlalchemy.orm import Session, sessionmaker
from .models import ORMBaseModel from .models import ORMBaseModel
class SessionManager:
__session: Session
def __init__(self, session: Session) -> None:
self.__session = session
def __enter__(self) -> Session:
return self.__session
def __exit__(self, *args) -> None:
self.__session.close()
class Connection: class Connection:
engine: Engine | None = None engine: Engine | None = None
session_local: sessionmaker | None = None session_local: sessionmaker | None = None
@ -18,6 +31,13 @@ class Connection:
) )
ORMBaseModel.metadata.create_all(bind=engine) ORMBaseModel.metadata.create_all(bind=engine)
@classmethod
def use(cls) -> SessionManager | None:
if cls.session_local is None:
return None
return SessionManager(cls.session_local())
@classmethod @classmethod
async def get(cls) -> Generator[Session | None, None, None]: async def get(cls) -> Generator[Session | None, None, None]:
if cls.session_local is None: if cls.session_local is None:

View file

@ -1,9 +1,11 @@
from __future__ import annotations
import datetime import datetime
from sqlalchemy import (Boolean, Column, DateTime, ForeignKey, Integer, String, from sqlalchemy import (Boolean, Column, DateTime, ForeignKey, Integer, String,
UniqueConstraint) UniqueConstraint)
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship from sqlalchemy.orm import Session, relationship
ORMBaseModel = declarative_base() ORMBaseModel = declarative_base()
@ -17,6 +19,13 @@ class User(ORMBaseModel):
capabilities = relationship("UserCapability", lazy="joined") capabilities = relationship("UserCapability", lazy="joined")
certificates = relationship("Certificate", lazy="joined") certificates = relationship("Certificate", lazy="joined")
@classmethod
def load(cls, db: Session, name: str) -> User | None:
return (db
.query(User)
.filter(User.name == name)
.first())
class UserCapability(ORMBaseModel): class UserCapability(ORMBaseModel):
__tablename__ = "user_capabilities" __tablename__ = "user_capabilities"
@ -29,9 +38,6 @@ class UserCapability(ORMBaseModel):
) )
capability = Column(String, primary_key=True) capability = Column(String, primary_key=True)
def __str__(self) -> str:
return self.capability
class DistinguishedName(ORMBaseModel): class DistinguishedName(ORMBaseModel):
__tablename__ = "distinguished_names" __tablename__ = "distinguished_names"

View file

@ -30,8 +30,19 @@ class Certificate(CertificateBase):
class UserCapability(Enum): class UserCapability(Enum):
admin = "admin" admin = "admin"
def __str__(self) -> str: @classmethod
return self._value_ def from_value(cls, value) -> UserCapability:
if isinstance(value, cls):
# use simple value
return value
elif isinstance(value, models.UserCapability):
# create from db
return cls(value.capability)
else:
# create from string representation
return cls(str(value))
class UserBase(BaseModel): class UserBase(BaseModel):
@ -53,10 +64,10 @@ class User(UserBase):
@classmethod @classmethod
def unify_capabilities( def unify_capabilities(
cls, cls,
value: list[models.UserCapability | str] value: list[models.UserCapability | UserCapability | str]
) -> list[UserCapability]: ) -> list[UserCapability]:
return [ return [
UserCapability(str(capability)) UserCapability.from_value(capability)
for capability in value for capability in value
] ]
@ -66,15 +77,10 @@ class User(UserBase):
db: Session, db: Session,
name: str, name: str,
) -> User | None: ) -> User | None:
user = (db if (db_user := models.User.load(db, name)) is None:
.query(models.User)
.filter(models.User.name == name)
.first())
if user is None:
return None return None
return cls.from_orm(user) return cls.from_orm(db_user)
@classmethod @classmethod
def login( def login(
@ -84,21 +90,16 @@ class User(UserBase):
password: str, password: str,
crypt_context: CryptContext, crypt_context: CryptContext,
) -> User | None: ) -> User | None:
user = (db if (db_user := models.User.load(db, name)) is None:
.query(models.User)
.filter(models.User.name == name)
.first())
if user is None:
# inexistent user, fake doing password verification # inexistent user, fake doing password verification
crypt_context.dummy_verify() crypt_context.dummy_verify()
return None return None
if not crypt_context.verify(password, user.password): if not crypt_context.verify(password, db_user.password):
# password hash mismatch # password hash mismatch
return None return None
return cls.from_orm(user) return cls.from_orm(db_user)
@classmethod @classmethod
def create( def create(
@ -127,10 +128,16 @@ class User(UserBase):
self, self,
db: Session, db: Session,
capabilities: list[UserCapability], capabilities: list[UserCapability],
) -> bool: ) -> None:
# TODO for capability in capabilities:
if capability not in self.capabilities:
cap = models.UserCapability(
user_name=self.name,
capability=capability.value,
)
db.add(cap)
return True db.commit()
class DistinguishedNameBase(BaseModel): class DistinguishedNameBase(BaseModel):

View file

@ -39,7 +39,7 @@ async def on_startup() -> None:
Connection.connect(await current_config.db.db_engine) Connection.connect(await current_config.db.db_engine)
# some testing # some testing
async for db in Connection.get(): with Connection.use() as db:
print(schemas.User.from_db(db, "admin")) print(schemas.User.from_db(db, "admin"))
print(schemas.User.from_db(db, "nonexistent")) print(schemas.User.from_db(db, "nonexistent"))

View file

@ -16,7 +16,7 @@ router = APIRouter(prefix="/admin")
) )
async def install( async def install(
config: Config, config: Config,
user: schemas.UserCreate, admin_user: schemas.UserCreate,
current_config: Config | None = Depends(Config.load), current_config: Config | None = Depends(Config.load),
): ):
if current_config is not None: if current_config is not None:
@ -25,10 +25,10 @@ async def install(
await config.save() await config.save()
Connection.connect(await config.db.db_engine) Connection.connect(await config.db.db_engine)
async for db in Connection.get(): with Connection.use() as db:
admin_user = schemas.User.create( admin_user = schemas.User.create(
db=db, db=db,
user=user, user=admin_user,
crypt_context=await config.crypto.crypt_context, crypt_context=await config.crypto.crypt_context,
) )