Compare commits

..

No commits in common. "00bdf88b6e234fad2fac38896843013b0e008d06" and "1c1ea694d1e08e091787d700196a1accafc28890" have entirely different histories.

5 changed files with 30 additions and 63 deletions

View file

@ -6,19 +6,6 @@ from sqlalchemy.orm import Session, sessionmaker
from .models import ORMBaseModel from .models import ORMBaseModel
class SessionManager:
__session: Session
def __init__(self, session: Session) -> None:
self.__session = session
def __enter__(self) -> Session:
return self.__session
def __exit__(self, *args) -> None:
self.__session.close()
class Connection: class Connection:
engine: Engine | None = None engine: Engine | None = None
session_local: sessionmaker | None = None session_local: sessionmaker | None = None
@ -31,13 +18,6 @@ class Connection:
) )
ORMBaseModel.metadata.create_all(bind=engine) ORMBaseModel.metadata.create_all(bind=engine)
@classmethod
def use(cls) -> SessionManager | None:
if cls.session_local is None:
return None
return SessionManager(cls.session_local())
@classmethod @classmethod
async def get(cls) -> Generator[Session | None, None, None]: async def get(cls) -> Generator[Session | None, None, None]:
if cls.session_local is None: if cls.session_local is None:

View file

@ -1,11 +1,9 @@
from __future__ import annotations
import datetime import datetime
from sqlalchemy import (Boolean, Column, DateTime, ForeignKey, Integer, String, from sqlalchemy import (Boolean, Column, DateTime, ForeignKey, Integer, String,
UniqueConstraint) UniqueConstraint)
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session, relationship from sqlalchemy.orm import relationship
ORMBaseModel = declarative_base() ORMBaseModel = declarative_base()
@ -19,13 +17,6 @@ class User(ORMBaseModel):
capabilities = relationship("UserCapability", lazy="joined") capabilities = relationship("UserCapability", lazy="joined")
certificates = relationship("Certificate", lazy="joined") certificates = relationship("Certificate", lazy="joined")
@classmethod
def load(cls, db: Session, name: str) -> User | None:
return (db
.query(User)
.filter(User.name == name)
.first())
class UserCapability(ORMBaseModel): class UserCapability(ORMBaseModel):
__tablename__ = "user_capabilities" __tablename__ = "user_capabilities"
@ -38,6 +29,9 @@ class UserCapability(ORMBaseModel):
) )
capability = Column(String, primary_key=True) capability = Column(String, primary_key=True)
def __str__(self) -> str:
return self.capability
class DistinguishedName(ORMBaseModel): class DistinguishedName(ORMBaseModel):
__tablename__ = "distinguished_names" __tablename__ = "distinguished_names"

View file

@ -30,19 +30,8 @@ class Certificate(CertificateBase):
class UserCapability(Enum): class UserCapability(Enum):
admin = "admin" admin = "admin"
@classmethod def __str__(self) -> str:
def from_value(cls, value) -> UserCapability: return self._value_
if isinstance(value, cls):
# use simple value
return value
elif isinstance(value, models.UserCapability):
# create from db
return cls(value.capability)
else:
# create from string representation
return cls(str(value))
class UserBase(BaseModel): class UserBase(BaseModel):
@ -64,10 +53,10 @@ class User(UserBase):
@classmethod @classmethod
def unify_capabilities( def unify_capabilities(
cls, cls,
value: list[models.UserCapability | UserCapability | str] value: list[models.UserCapability | str]
) -> list[UserCapability]: ) -> list[UserCapability]:
return [ return [
UserCapability.from_value(capability) UserCapability(str(capability))
for capability in value for capability in value
] ]
@ -77,10 +66,15 @@ class User(UserBase):
db: Session, db: Session,
name: str, name: str,
) -> User | None: ) -> User | None:
if (db_user := models.User.load(db, name)) is None: user = (db
.query(models.User)
.filter(models.User.name == name)
.first())
if user is None:
return None return None
return cls.from_orm(db_user) return cls.from_orm(user)
@classmethod @classmethod
def login( def login(
@ -90,16 +84,21 @@ class User(UserBase):
password: str, password: str,
crypt_context: CryptContext, crypt_context: CryptContext,
) -> User | None: ) -> User | None:
if (db_user := models.User.load(db, name)) is None: user = (db
.query(models.User)
.filter(models.User.name == name)
.first())
if user is None:
# inexistent user, fake doing password verification # inexistent user, fake doing password verification
crypt_context.dummy_verify() crypt_context.dummy_verify()
return None return None
if not crypt_context.verify(password, db_user.password): if not crypt_context.verify(password, user.password):
# password hash mismatch # password hash mismatch
return None return None
return cls.from_orm(db_user) return cls.from_orm(user)
@classmethod @classmethod
def create( def create(
@ -128,16 +127,10 @@ class User(UserBase):
self, self,
db: Session, db: Session,
capabilities: list[UserCapability], capabilities: list[UserCapability],
) -> None: ) -> bool:
for capability in capabilities: # TODO
if capability not in self.capabilities:
cap = models.UserCapability(
user_name=self.name,
capability=capability.value,
)
db.add(cap)
db.commit() return True
class DistinguishedNameBase(BaseModel): class DistinguishedNameBase(BaseModel):

View file

@ -39,7 +39,7 @@ async def on_startup() -> None:
Connection.connect(await current_config.db.db_engine) Connection.connect(await current_config.db.db_engine)
# some testing # some testing
with Connection.use() as db: async for db in Connection.get():
print(schemas.User.from_db(db, "admin")) print(schemas.User.from_db(db, "admin"))
print(schemas.User.from_db(db, "nonexistent")) print(schemas.User.from_db(db, "nonexistent"))

View file

@ -16,7 +16,7 @@ router = APIRouter(prefix="/admin")
) )
async def install( async def install(
config: Config, config: Config,
admin_user: schemas.UserCreate, user: schemas.UserCreate,
current_config: Config | None = Depends(Config.load), current_config: Config | None = Depends(Config.load),
): ):
if current_config is not None: if current_config is not None:
@ -25,10 +25,10 @@ async def install(
await config.save() await config.save()
Connection.connect(await config.db.db_engine) Connection.connect(await config.db.db_engine)
with Connection.use() as db: async for db in Connection.get():
admin_user = schemas.User.create( admin_user = schemas.User.create(
db=db, db=db,
user=admin_user, user=user,
crypt_context=await config.crypto.crypt_context, crypt_context=await config.crypto.crypt_context,
) )