Compare commits

...

5 commits

17 changed files with 39 additions and 628 deletions

View file

@ -1,4 +1,7 @@
from . import models, schemata from .capability import Capability
from .connection import Connection from .connection import Connection
from .device import Device, DeviceBase, DeviceCreate
from .user import User, UserBase, UserCreate, UserRead
__all__ = ["Connection", "models", "schemata"] __all__ = ["Capability", "Connection", "Device", "DeviceBase", "DeviceCreate",
"User", "UserBase", "UserCreate", "UserRead"]

View file

@ -1,30 +1,4 @@
""" from sqlmodel import Session, SQLModel, create_engine
Utilities for handling SQLAlchemy database connections.
"""
from typing import Generator
from sqlalchemy.engine import Engine
from sqlalchemy.orm import Session, sessionmaker
from .models import ORMBaseModel
class SessionManager:
"""
Simple context manager for an ORM session.
"""
__session: Session
def __init__(self, session: Session) -> None:
self.__session = session
def __enter__(self) -> Session:
return self.__session
def __exit__(self, *args) -> None:
self.__session.close()
class Connection: class Connection:
@ -32,44 +6,25 @@ class Connection:
Namespace for the database connection. Namespace for the database connection.
""" """
engine: Engine | None = None engine = None
session_local: sessionmaker | None = None
@classmethod @classmethod
def connect(cls, engine: Engine) -> None: def connect(cls, connection_url: str) -> None:
""" """
Connect ORM to a database engine. Connect ORM to a database engine.
""" """
cls.engine = engine cls.engine = create_engine(connection_url)
cls.session_local = sessionmaker( SQLModel.metadata.create_all(cls.engine)
autocommit=False, autoflush=False, bind=engine,
)
ORMBaseModel.metadata.create_all(bind=engine)
@classmethod @classmethod
def use(cls) -> SessionManager | None: @property
def session(cls) -> Session | None:
""" """
Create an ORM session using a context manager. Create an ORM session using a context manager.
""" """
if cls.session_local is None: if cls.engine is None:
return None return None
return SessionManager(cls.session_local()) return Session(cls.engine)
@classmethod
async def get(cls) -> Generator[Session | None, None, None]:
"""
Create an ORM session using a FastAPI compatible async generator.
"""
if cls.session_local is None:
yield None
else:
db = cls.session_local()
try:
yield db
finally:
db.close()

View file

@ -1,82 +0,0 @@
"""
SQLAlchemy representation of database contents.
"""
from __future__ import annotations
from sqlalchemy import (Column, DateTime, ForeignKey, Integer, String,
UniqueConstraint)
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session, relationship
ORMBaseModel = declarative_base()
class UserCapability(ORMBaseModel):
__tablename__ = "user_capabilities"
user_name = Column(
String,
ForeignKey("users.name"),
primary_key=True,
index=True,
)
capability = Column(String, primary_key=True)
class User(ORMBaseModel):
__tablename__ = "users"
name = Column(String, primary_key=True, index=True)
password = Column(String, nullable=False)
email = Column(String)
country = Column(String(2))
state = Column(String)
city = Column(String)
organization = Column(String)
organizational_unit = Column(String)
capabilities: list[UserCapability] = relationship(
"UserCapability", lazy="joined", cascade="all, delete-orphan"
)
devices: list[Device] = relationship(
"Device", lazy="select", back_populates="owner"
)
@classmethod
def from_db(
cls,
db: Session,
name: str,
) -> User | None:
"""
Load user from database by name.
"""
return (
db
.query(cls)
.filter(cls.name == name)
.first()
)
class Device(ORMBaseModel):
__tablename__ = "devices"
id = Column(Integer, primary_key=True, autoincrement=True)
owner_name = Column(String, ForeignKey("users.name"))
name = Column(String)
type = Column(String)
expiry = Column(DateTime)
owner: User = relationship(
"User", lazy="joined", back_populates="devices"
)
UniqueConstraint(
owner_name,
name,
)

View file

@ -1,6 +0,0 @@
from .device import Device, DeviceBase, DeviceCreate
from .user import User, UserBase, UserCreate
from .user_capability import UserCapability
__all__ = ["Device", "DeviceBase", "DeviceCreate",
"User", "UserBase", "UserCreate", "UserCapability"]

View file

@ -1,79 +0,0 @@
"""
Pydantic representation of database contents.
"""
from __future__ import annotations
from datetime import datetime
from pydantic import BaseModel
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from .. import models
class DeviceBase(BaseModel):
name: str
type: str
expiry: datetime
class DeviceCreate(DeviceBase):
owner_name: str
class Device(DeviceBase):
class Config:
orm_mode = True
@classmethod
def create(
cls,
db: Session,
device: DeviceCreate,
) -> Device | None:
"""
Create a new device in the database.
"""
try:
db_device = models.Device(
owner_name=device.owner_name,
name=device.name,
type=device.type,
expiry=device.expiry,
)
db.add(db_device)
db.commit()
db.refresh(db_device)
return cls.from_orm(db_device)
except IntegrityError:
# device already existed
return None
def delete(
self,
db: Session,
) -> bool:
"""
Delete this device from the database.
"""
db_device = models.Device(
# owner_name=
name=self.name,
)
db.refresh(db_device)
if db_device is None:
# nonexistent device
return False
db.delete(db_device)
db.commit()
return True

View file

@ -1,164 +0,0 @@
"""
Pydantic representation of database contents.
"""
from __future__ import annotations
from typing import Any
from passlib.context import CryptContext
from pydantic import BaseModel, Field, validator
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from .. import models
from .device import Device
from .user_capability import UserCapability
class UserBase(BaseModel):
name: str
email: str
capabilities: list[UserCapability] = []
country: str | None = Field(default=None, repr=False)
state: str | None = Field(default=None, repr=False)
city: str | None = Field(default=None, repr=False)
organization: str | None = Field(default=None, repr=False)
organizational_unit: str | None = Field(default=None, repr=False)
class UserCreate(UserBase):
password: str
class User(UserBase):
devices: list[Device] = Field(
default=[], repr=False
)
class Config:
orm_mode = True
@validator("capabilities", pre=True)
@classmethod
def unify_capabilities(cls, value: list[Any]) -> list[UserCapability]:
"""
Import the capabilities from various formats
"""
return [
UserCapability.from_value(capability)
for capability in value
]
@classmethod
def from_db(
cls,
db: Session,
name: str,
) -> User | None:
"""
Load user from database by name.
"""
if (db_user := models.User.from_db(db, name)) is None:
return None
return cls.from_orm(db_user)
@classmethod
def create(
cls,
db: Session,
user: UserCreate,
crypt_context: CryptContext,
) -> User | None:
"""
Create a new user in the database.
"""
try:
db_user = models.User(
name=user.name,
password=crypt_context.hash(user.password),
email=user.email,
capabilities=[
capability.model
for capability in user.capabilities
],
)
db.add(db_user)
db.commit()
db.refresh(db_user)
return cls.from_orm(db_user)
except IntegrityError:
# user already existed
return None
def is_admin(self) -> bool:
return UserCapability.admin in self.capabilities
def authenticate(
self,
db: Session,
password: str,
crypt_context: CryptContext,
) -> User | None:
"""
Authenticate with name/password against users in database.
"""
if (db_user := models.User.from_db(db, self.name)) is None:
# nonexistent user, fake doing password verification
crypt_context.dummy_verify()
return False
if not crypt_context.verify(password, db_user.password):
# password hash mismatch
return False
self.from_orm(db_user)
return True
def update(
self,
db: Session,
) -> None:
"""
Update this user in the database.
"""
if (db_user := models.User.from_db(db, self.name)) is None:
return None
for capability in db_user.capabilities:
db.delete(capability)
db_user.capabilities = [
capability.model
for capability in self.capabilities
]
db.commit()
def delete(
self,
db: Session,
) -> bool:
"""
Delete this user from the database.
"""
if (db_user := models.User.from_db(db, self.name)) is None:
# nonexistent user
return False
db.delete(db_user)
db.commit()
return True

View file

@ -1,43 +0,0 @@
"""
Pydantic representation of database contents.
"""
from __future__ import annotations
from enum import Enum
from .. import models
class UserCapability(Enum):
admin = "admin"
login = "login"
issue = "issue"
renew = "renew"
def __repr__(self) -> str:
return self.value
@classmethod
def from_value(cls, value) -> UserCapability:
"""
Create UserCapability from various formats
"""
if isinstance(value, cls):
# value is already a UserCapability, use that
return value
elif isinstance(value, models.UserCapability):
# create from db format
return cls(value.capability)
else:
# create from string representation
return cls(str(value))
@property
def model(self) -> models.UserCapability:
return models.UserCapability(
capability=self.value,
)

View file

@ -109,6 +109,9 @@ class User(UserBase, table=True):
db.delete(self) db.delete(self)
db.commit() db.commit()
def can(self, capability: Capability) -> bool:
return capability in self.get_capabilities()
def get_capabilities(self) -> set[Capability]: def get_capabilities(self) -> set[Capability]:
return set( return set(
capability._ capability._

View file

@ -1,7 +0,0 @@
from .capability import Capability
from .connection import Connection
from .device import Device, DeviceBase, DeviceCreate
from .user import User, UserBase, UserCreate, UserRead
__all__ = ["Capability", "Connection", "Device", "DeviceBase", "DeviceCreate",
"User", "UserBase", "UserCreate", "UserRead"]

View file

@ -1,30 +0,0 @@
from sqlmodel import Session, SQLModel, create_engine
class Connection:
"""
Namespace for the database connection.
"""
engine = None
@classmethod
def connect(cls, connection_url: str) -> None:
"""
Connect ORM to a database engine.
"""
cls.engine = create_engine(connection_url)
SQLModel.metadata.create_all(cls.engine)
@classmethod
@property
def session(cls) -> Session | None:
"""
Create an ORM session using a context manager.
"""
if cls.engine is None:
return None
return Session(cls.engine)

View file

@ -13,11 +13,8 @@ import uvicorn
from fastapi import FastAPI from fastapi import FastAPI
from .config import Config, Settings from .config import Config, Settings
from .db import Connection
# from .db.schemata import User # from .db.schemata import User
from .db_new import Capability from .db import Connection
from .db_new import Connection as Connection_new
from .db_new import Device, User
from .routers import main_router from .routers import main_router
settings = Settings.get() settings = Settings.get()
@ -47,46 +44,13 @@ async def on_startup() -> None:
# check if configured # check if configured
if (current_config := await Config.load()) is not None: if (current_config := await Config.load()) is not None:
# connect to database # connect to database
Connection.connect(await current_config.db.db_engine) Connection.connect("sqlite:///tmp/v2.db")
# # some testing # # some testing
# with Connection.use() as db: # with Connection.use() as db:
# print(User.from_db(db, "admin")) # print(User.from_db(db, "admin"))
# print(User.from_db(db, "nonexistent")) # print(User.from_db(db, "nonexistent"))
Connection_new.connect("sqlite:///tmp/v2.db")
User.create(
name="Uwe",
password_clear="ulf",
email="uwe@feh.de",
)
print(User.get(name="Uwe"))
print(User.authenticate("Uwe", "uwe"))
uwe = User.authenticate("Uwe", "ulf")
uwe.set_capabilities([Capability.admin])
uwe.update()
print(uwe.get_capabilities())
uwe.set_capabilities([])
uwe.update()
print(uwe.get_capabilities())
with Connection_new.session as db:
db.add(uwe)
print(uwe.devices)
ipad = Device.create(
owner_name="Uwe",
name="iPad",
type="tablet",
)
# ipad = Device.
print(ipad)
def main() -> None: def main() -> None:
uvicorn.run( uvicorn.run(

View file

@ -1,10 +1,12 @@
from fastapi import APIRouter from fastapi import APIRouter
from . import admin, user from . import admin
# from . import user
main_router = APIRouter(prefix="/api/v1") main_router = APIRouter(prefix="/api/v1")
main_router.include_router(admin.router) main_router.include_router(admin.router)
main_router.include_router(user.router) # main_router.include_router(user.router)
__all__ = ["main_router"] __all__ = ["main_router"]

View file

@ -5,11 +5,9 @@ Common dependencies for routers.
from fastapi import Depends, HTTPException, status from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer from fastapi.security import OAuth2PasswordBearer
from sqlalchemy.orm import Session
from ..config import Config from ..config import Config
from ..db import Connection from ..db import Capability, User
from ..db.schemata import User
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="user/authenticate") oauth2_scheme = OAuth2PasswordBearer(tokenUrl="user/authenticate")
@ -56,7 +54,6 @@ class Responses:
async def get_current_user( async def get_current_user(
token: str = Depends(oauth2_scheme), token: str = Depends(oauth2_scheme),
db: Session | None = Depends(Connection.get),
current_config: Config | None = Depends(Config.load), current_config: Config | None = Depends(Config.load),
) -> User | None: ) -> User | None:
""" """
@ -68,13 +65,11 @@ async def get_current_user(
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
username = await current_config.jwt.decode_token(token) username = await current_config.jwt.decode_token(token)
user = User.from_db(db, username)
return user return User.get(username)
async def get_current_user_if_exists( async def get_current_user_if_exists(
current_config: Config | None = Depends(Config.load),
current_user: User | None = Depends(get_current_user), current_user: User | None = Depends(get_current_user),
) -> User: ) -> User:
""" """
@ -89,7 +84,6 @@ async def get_current_user_if_exists(
async def get_current_user_if_admin( async def get_current_user_if_admin(
current_config: Config | None = Depends(Config.load),
current_user: User = Depends(get_current_user_if_exists), current_user: User = Depends(get_current_user_if_exists),
) -> User: ) -> User:
""" """
@ -97,7 +91,7 @@ async def get_current_user_if_admin(
""" """
# fail if not requested by an admin # fail if not requested by an admin
if not current_user.is_admin(): if not current_user.can(Capability.admin):
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
return current_user return current_user
@ -105,7 +99,6 @@ async def get_current_user_if_admin(
async def get_current_user_if_admin_or_self( async def get_current_user_if_admin_or_self(
user_name: str, user_name: str,
current_config: Config | None = Depends(Config.load),
current_user: User = Depends(get_current_user_if_exists), current_user: User = Depends(get_current_user_if_exists),
) -> User: ) -> User:
""" """
@ -116,7 +109,8 @@ async def get_current_user_if_admin_or_self(
""" """
# fail if not requested by an admin or self # fail if not requested by an admin or self
if not (current_user.is_admin() or current_user.name == user_name): if not (current_user.can(Capability.admin)
or current_user.name == user_name):
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
return current_user return current_user

View file

@ -6,9 +6,8 @@
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException, status
from ..config import Config from ..config import Config
from ..db import Connection from ..db import Capability, Connection, User, UserCreate
from ..db.schemata import User, UserCapability, UserCreate from ._common import Responses, get_current_user_if_admin
from ._common import Responses, get_current_user
router = APIRouter(prefix="/admin", tags=["admin"]) router = APIRouter(prefix="/admin", tags=["admin"])
@ -22,7 +21,7 @@ router = APIRouter(prefix="/admin", tags=["admin"])
) )
async def install( async def install(
config: Config, config: Config,
admin_user: UserCreate, # admin_user: UserCreate,
current_config: Config | None = Depends(Config.load), current_config: Config | None = Depends(Config.load),
): ):
""" """
@ -35,18 +34,13 @@ async def install(
# create config file, connect to database # create config file, connect to database
await config.save() await config.save()
Connection.connect(await config.db.db_engine) Connection.connect("sqlite:///tmp/v2.db")
# create an administrative user # # create an administrative user
with Connection.use() as db: # new_user = User.create(**admin_user)
new_user = User.create( # assert new_user is not None
db=db, # new_user.set_capabilities([Capability.login, Capability.admin])
user=admin_user, # new_user.update()
crypt_context=await config.crypto.crypt_context,
)
new_user.capabilities.append(UserCapability.admin)
new_user.update(db)
@router.put( @router.put(
@ -61,7 +55,7 @@ async def install(
async def set_config( async def set_config(
new_config: Config, new_config: Config,
current_config: Config | None = Depends(Config.load), current_config: Config | None = Depends(Config.load),
current_user: User | None = Depends(get_current_user), _: User | None = Depends(get_current_user_if_admin),
): ):
""" """
PUT ./config: Edit `kiwi-vpn` main config. PUT ./config: Edit `kiwi-vpn` main config.
@ -71,11 +65,6 @@ async def set_config(
if current_config is None: if current_config is None:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
# fail if not requested by an admin
if (current_user is None
or UserCapability.admin not in current_user.capabilities):
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
# update config file, reconnect to database # update config file, reconnect to database
await new_config.save() await new_config.save()
Connection.connect(await new_config.db.db_engine) Connection.connect("sqlite:///tmp/v2.db")

View file

@ -1,88 +0,0 @@
"""
/dn endpoints.
"""
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from ..db import Connection
from ..db.schemata import DistinguishedName, DistinguishedNameCreate, User
from ._common import Responses, get_current_user_if_admin_or_self
router = APIRouter(prefix="/dn")
@router.post(
"",
responses={
status.HTTP_200_OK: Responses.OK,
status.HTTP_400_BAD_REQUEST: Responses.NOT_INSTALLED,
status.HTTP_401_UNAUTHORIZED: Responses.NEEDS_USER,
status.HTTP_403_FORBIDDEN: Responses.NEEDS_ADMIN,
status.HTTP_404_NOT_FOUND: Responses.ENTRY_DOESNT_EXIST,
status.HTTP_409_CONFLICT: Responses.ENTRY_EXISTS,
},
)
async def add_distinguished_name(
user_name: str,
distinguished_name: DistinguishedNameCreate,
_: User = Depends(get_current_user_if_admin_or_self),
db: Session | None = Depends(Connection.get),
):
"""
POST ./: Create a new distinguished name in the database.
"""
owner = User.from_db(
db=db,
name=user_name,
)
# fail if owner doesn't exist
if owner is None:
raise HTTPException(status_code=status.HTTP_409_CONFLICT)
# actually create the new user
new_dn = DistinguishedName.create(
db=db,
dn=distinguished_name,
owner=owner,
)
# fail if creation was unsuccessful
if new_dn is None:
raise HTTPException(status_code=status.HTTP_409_CONFLICT)
# return the created user on success
return new_dn
# @router.delete(
# "",
# responses={
# status.HTTP_200_OK: Responses.OK,
# status.HTTP_400_BAD_REQUEST: Responses.NOT_INSTALLED,
# status.HTTP_401_UNAUTHORIZED: Responses.NEEDS_USER,
# status.HTTP_403_FORBIDDEN: Responses.NEEDS_ADMIN,
# status.HTTP_404_NOT_FOUND: Responses.ENTRY_DOESNT_EXIST,
# },
# )
# async def remove_distinguished_name(
# user_name: str,
# _: User = Depends(get_current_user_if_admin),
# db: Session | None = Depends(Connection.get),
# ):
# """
# DELETE ./{user_name}: Remove a user from the database.
# """
# # get the user
# user = User.from_db(
# db=db,
# name=user_name,
# )
# # fail if deletion was unsuccessful
# if user is None or not user.delete(db):
# raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)