Compare commits
No commits in common. "00bdf88b6e234fad2fac38896843013b0e008d06" and "1c1ea694d1e08e091787d700196a1accafc28890" have entirely different histories.
00bdf88b6e
...
1c1ea694d1
5 changed files with 30 additions and 63 deletions
|
|
@ -6,19 +6,6 @@ from sqlalchemy.orm import Session, sessionmaker
|
||||||
from .models import ORMBaseModel
|
from .models import ORMBaseModel
|
||||||
|
|
||||||
|
|
||||||
class SessionManager:
|
|
||||||
__session: Session
|
|
||||||
|
|
||||||
def __init__(self, session: Session) -> None:
|
|
||||||
self.__session = session
|
|
||||||
|
|
||||||
def __enter__(self) -> Session:
|
|
||||||
return self.__session
|
|
||||||
|
|
||||||
def __exit__(self, *args) -> None:
|
|
||||||
self.__session.close()
|
|
||||||
|
|
||||||
|
|
||||||
class Connection:
|
class Connection:
|
||||||
engine: Engine | None = None
|
engine: Engine | None = None
|
||||||
session_local: sessionmaker | None = None
|
session_local: sessionmaker | None = None
|
||||||
|
|
@ -31,13 +18,6 @@ class Connection:
|
||||||
)
|
)
|
||||||
ORMBaseModel.metadata.create_all(bind=engine)
|
ORMBaseModel.metadata.create_all(bind=engine)
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def use(cls) -> SessionManager | None:
|
|
||||||
if cls.session_local is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return SessionManager(cls.session_local())
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get(cls) -> Generator[Session | None, None, None]:
|
async def get(cls) -> Generator[Session | None, None, None]:
|
||||||
if cls.session_local is None:
|
if cls.session_local is None:
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,9 @@
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import datetime
|
import datetime
|
||||||
|
|
||||||
from sqlalchemy import (Boolean, Column, DateTime, ForeignKey, Integer, String,
|
from sqlalchemy import (Boolean, Column, DateTime, ForeignKey, Integer, String,
|
||||||
UniqueConstraint)
|
UniqueConstraint)
|
||||||
from sqlalchemy.ext.declarative import declarative_base
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
from sqlalchemy.orm import Session, relationship
|
from sqlalchemy.orm import relationship
|
||||||
|
|
||||||
ORMBaseModel = declarative_base()
|
ORMBaseModel = declarative_base()
|
||||||
|
|
||||||
|
|
@ -19,13 +17,6 @@ class User(ORMBaseModel):
|
||||||
capabilities = relationship("UserCapability", lazy="joined")
|
capabilities = relationship("UserCapability", lazy="joined")
|
||||||
certificates = relationship("Certificate", lazy="joined")
|
certificates = relationship("Certificate", lazy="joined")
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def load(cls, db: Session, name: str) -> User | None:
|
|
||||||
return (db
|
|
||||||
.query(User)
|
|
||||||
.filter(User.name == name)
|
|
||||||
.first())
|
|
||||||
|
|
||||||
|
|
||||||
class UserCapability(ORMBaseModel):
|
class UserCapability(ORMBaseModel):
|
||||||
__tablename__ = "user_capabilities"
|
__tablename__ = "user_capabilities"
|
||||||
|
|
@ -38,6 +29,9 @@ class UserCapability(ORMBaseModel):
|
||||||
)
|
)
|
||||||
capability = Column(String, primary_key=True)
|
capability = Column(String, primary_key=True)
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return self.capability
|
||||||
|
|
||||||
|
|
||||||
class DistinguishedName(ORMBaseModel):
|
class DistinguishedName(ORMBaseModel):
|
||||||
__tablename__ = "distinguished_names"
|
__tablename__ = "distinguished_names"
|
||||||
|
|
|
||||||
|
|
@ -30,19 +30,8 @@ class Certificate(CertificateBase):
|
||||||
class UserCapability(Enum):
|
class UserCapability(Enum):
|
||||||
admin = "admin"
|
admin = "admin"
|
||||||
|
|
||||||
@classmethod
|
def __str__(self) -> str:
|
||||||
def from_value(cls, value) -> UserCapability:
|
return self._value_
|
||||||
if isinstance(value, cls):
|
|
||||||
# use simple value
|
|
||||||
return value
|
|
||||||
|
|
||||||
elif isinstance(value, models.UserCapability):
|
|
||||||
# create from db
|
|
||||||
return cls(value.capability)
|
|
||||||
|
|
||||||
else:
|
|
||||||
# create from string representation
|
|
||||||
return cls(str(value))
|
|
||||||
|
|
||||||
|
|
||||||
class UserBase(BaseModel):
|
class UserBase(BaseModel):
|
||||||
|
|
@ -64,10 +53,10 @@ class User(UserBase):
|
||||||
@classmethod
|
@classmethod
|
||||||
def unify_capabilities(
|
def unify_capabilities(
|
||||||
cls,
|
cls,
|
||||||
value: list[models.UserCapability | UserCapability | str]
|
value: list[models.UserCapability | str]
|
||||||
) -> list[UserCapability]:
|
) -> list[UserCapability]:
|
||||||
return [
|
return [
|
||||||
UserCapability.from_value(capability)
|
UserCapability(str(capability))
|
||||||
for capability in value
|
for capability in value
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
@ -77,10 +66,15 @@ class User(UserBase):
|
||||||
db: Session,
|
db: Session,
|
||||||
name: str,
|
name: str,
|
||||||
) -> User | None:
|
) -> User | None:
|
||||||
if (db_user := models.User.load(db, name)) is None:
|
user = (db
|
||||||
|
.query(models.User)
|
||||||
|
.filter(models.User.name == name)
|
||||||
|
.first())
|
||||||
|
|
||||||
|
if user is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return cls.from_orm(db_user)
|
return cls.from_orm(user)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def login(
|
def login(
|
||||||
|
|
@ -90,16 +84,21 @@ class User(UserBase):
|
||||||
password: str,
|
password: str,
|
||||||
crypt_context: CryptContext,
|
crypt_context: CryptContext,
|
||||||
) -> User | None:
|
) -> User | None:
|
||||||
if (db_user := models.User.load(db, name)) is None:
|
user = (db
|
||||||
|
.query(models.User)
|
||||||
|
.filter(models.User.name == name)
|
||||||
|
.first())
|
||||||
|
|
||||||
|
if user is None:
|
||||||
# inexistent user, fake doing password verification
|
# inexistent user, fake doing password verification
|
||||||
crypt_context.dummy_verify()
|
crypt_context.dummy_verify()
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if not crypt_context.verify(password, db_user.password):
|
if not crypt_context.verify(password, user.password):
|
||||||
# password hash mismatch
|
# password hash mismatch
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return cls.from_orm(db_user)
|
return cls.from_orm(user)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create(
|
def create(
|
||||||
|
|
@ -128,16 +127,10 @@ class User(UserBase):
|
||||||
self,
|
self,
|
||||||
db: Session,
|
db: Session,
|
||||||
capabilities: list[UserCapability],
|
capabilities: list[UserCapability],
|
||||||
) -> None:
|
) -> bool:
|
||||||
for capability in capabilities:
|
# TODO
|
||||||
if capability not in self.capabilities:
|
|
||||||
cap = models.UserCapability(
|
|
||||||
user_name=self.name,
|
|
||||||
capability=capability.value,
|
|
||||||
)
|
|
||||||
db.add(cap)
|
|
||||||
|
|
||||||
db.commit()
|
return True
|
||||||
|
|
||||||
|
|
||||||
class DistinguishedNameBase(BaseModel):
|
class DistinguishedNameBase(BaseModel):
|
||||||
|
|
|
||||||
|
|
@ -39,7 +39,7 @@ async def on_startup() -> None:
|
||||||
Connection.connect(await current_config.db.db_engine)
|
Connection.connect(await current_config.db.db_engine)
|
||||||
|
|
||||||
# some testing
|
# some testing
|
||||||
with Connection.use() as db:
|
async for db in Connection.get():
|
||||||
print(schemas.User.from_db(db, "admin"))
|
print(schemas.User.from_db(db, "admin"))
|
||||||
print(schemas.User.from_db(db, "nonexistent"))
|
print(schemas.User.from_db(db, "nonexistent"))
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,7 @@ router = APIRouter(prefix="/admin")
|
||||||
)
|
)
|
||||||
async def install(
|
async def install(
|
||||||
config: Config,
|
config: Config,
|
||||||
admin_user: schemas.UserCreate,
|
user: schemas.UserCreate,
|
||||||
current_config: Config | None = Depends(Config.load),
|
current_config: Config | None = Depends(Config.load),
|
||||||
):
|
):
|
||||||
if current_config is not None:
|
if current_config is not None:
|
||||||
|
|
@ -25,10 +25,10 @@ async def install(
|
||||||
await config.save()
|
await config.save()
|
||||||
Connection.connect(await config.db.db_engine)
|
Connection.connect(await config.db.db_engine)
|
||||||
|
|
||||||
with Connection.use() as db:
|
async for db in Connection.get():
|
||||||
admin_user = schemas.User.create(
|
admin_user = schemas.User.create(
|
||||||
db=db,
|
db=db,
|
||||||
user=admin_user,
|
user=user,
|
||||||
crypt_context=await config.crypto.crypt_context,
|
crypt_context=await config.crypto.crypt_context,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue