diff --git a/api/kiwi_vpn_api/db/__init__.py b/api/kiwi_vpn_api/db/__init__.py index f418393..099af9b 100644 --- a/api/kiwi_vpn_api/db/__init__.py +++ b/api/kiwi_vpn_api/db/__init__.py @@ -1,7 +1,11 @@ -from .capability import Capability +""" +Package `db`: ORM and schemas for database content. +""" + from .connection import Connection from .device import Device, DeviceBase, DeviceCreate from .user import User, UserBase, UserCreate, UserRead +from .user_capability import Capability __all__ = ["Capability", "Connection", "Device", "DeviceBase", "DeviceCreate", "User", "UserBase", "UserCreate", "UserRead"] diff --git a/api/kiwi_vpn_api/db/connection.py b/api/kiwi_vpn_api/db/connection.py index ff9341a..ed5d4c2 100644 --- a/api/kiwi_vpn_api/db/connection.py +++ b/api/kiwi_vpn_api/db/connection.py @@ -1,9 +1,13 @@ +""" +Database connection management +""" + from sqlmodel import Session, SQLModel, create_engine class Connection: """ - Namespace for the database connection. + Namespace for the database connection """ engine = None diff --git a/api/kiwi_vpn_api/db/device.py b/api/kiwi_vpn_api/db/device.py index f260cbc..2f9e3e1 100644 --- a/api/kiwi_vpn_api/db/device.py +++ b/api/kiwi_vpn_api/db/device.py @@ -1,3 +1,7 @@ +""" +Python representation of `device` table. +""" + from __future__ import annotations from datetime import datetime @@ -13,16 +17,36 @@ if TYPE_CHECKING: class DeviceBase(SQLModel): + """ + Common to all representations of devices + """ + name: str type: str expiry: datetime | None class DeviceCreate(DeviceBase): + """ + Representation of a newly created device + """ + + owner_name: str | None + + +class DeviceRead(DeviceBase): + """ + Representation of a device read via the API + """ + owner_name: str | None class Device(DeviceBase, table=True): + """ + Representation of device table + """ + __table_args__ = (UniqueConstraint( "owner_name", "name", diff --git a/api/kiwi_vpn_api/db/user.py b/api/kiwi_vpn_api/db/user.py index f40253b..cc5e079 100644 --- a/api/kiwi_vpn_api/db/user.py +++ b/api/kiwi_vpn_api/db/user.py @@ -1,3 +1,7 @@ +""" +Python representation of `user` table. +""" + from __future__ import annotations from typing import Any @@ -7,12 +11,16 @@ from sqlalchemy.exc import IntegrityError from sqlmodel import Field, Relationship, SQLModel from ..config import Config -from .capability import Capability, UserCapability from .connection import Connection from .device import Device +from .user_capability import Capability, UserCapability class UserBase(SQLModel): + """ + Common to all representations of users + """ + name: str = Field(primary_key=True) email: str | None = Field(default=None) @@ -23,7 +31,50 @@ class UserBase(SQLModel): organizational_unit: str | None = Field(default=None) +class UserCreate(UserBase): + """ + Representation of a newly created user + """ + + password: str | None = Field(default=None) + password_clear: str | None = Field(default=None) + + @root_validator + @classmethod + def hash_password(cls, values: dict[str, Any]) -> dict[str, Any]: + """ + Ensure the `password` value of this user gets set. + """ + + if (values.get("password")) is not None: + # password is set + return values + + if (password_clear := values.get("password_clear")) is None: + raise ValueError("No password to hash") + + if (current_config := Config._) is None: + raise ValueError("Not configured") + + values["password"] = current_config.crypto.crypt_context.hash( + password_clear) + + return values + + +class UserRead(UserBase): + """ + Representation of a user read via the API + """ + + pass + + class User(UserBase, table=True): + """ + Representation of user table + """ + password: str capabilities: list[UserCapability] = Relationship( @@ -109,46 +160,31 @@ class User(UserBase, table=True): db.delete(self) db.commit() - def can(self, capability: Capability) -> bool: - return capability in self.get_capabilities() - def get_capabilities(self) -> set[Capability]: + """ + Return the capabilities of this user. + """ + return set( capability._ for capability in self.capabilities ) + def can(self, capability: Capability) -> bool: + """ + Check if this user has a capability. + """ + + return capability in self.get_capabilities() + def set_capabilities(self, capabilities: set[Capability]) -> None: + """ + Change the capabilities of this user. + """ + self.capabilities = [ UserCapability( user_name=self.name, capability_name=capability.value, ) for capability in capabilities ] - - -class UserCreate(UserBase): - password: str | None = Field(default=None) - password_clear: str | None = Field(default=None) - - @root_validator - @classmethod - def hash_password(cls, values: dict[str, Any]) -> dict[str, Any]: - if (values.get("password")) is not None: - # password is set - return values - - if (password_clear := values.get("password_clear")) is None: - raise ValueError("No password to hash") - - if (current_config := Config._) is None: - raise ValueError("Not configured") - - values["password"] = current_config.crypto.crypt_context.hash( - password_clear) - - return values - - -class UserRead(UserBase): - pass diff --git a/api/kiwi_vpn_api/db/capability.py b/api/kiwi_vpn_api/db/user_capability.py similarity index 71% rename from api/kiwi_vpn_api/db/capability.py rename to api/kiwi_vpn_api/db/user_capability.py index 9ed8aba..7b98eb7 100644 --- a/api/kiwi_vpn_api/db/capability.py +++ b/api/kiwi_vpn_api/db/user_capability.py @@ -1,3 +1,7 @@ +""" +Python representation of `usercapability` table. +""" + from enum import Enum from typing import TYPE_CHECKING @@ -8,6 +12,10 @@ if TYPE_CHECKING: class Capability(Enum): + """ + Allowed values for capabilities + """ + admin = "admin" login = "login" issue = "issue" @@ -18,10 +26,18 @@ class Capability(Enum): class UserCapabilityBase(SQLModel): + """ + Common to all representations of capabilities + """ + capability_name: str = Field(primary_key=True) @property def _(self) -> Capability: + """ + Transform into a `Capability`. + """ + return Capability(self.capability_name) def __repr__(self) -> str: @@ -29,6 +45,10 @@ class UserCapabilityBase(SQLModel): class UserCapability(UserCapabilityBase, table=True): + """ + Representation of usercapability table + """ + user_name: str = Field(primary_key=True, foreign_key="user.name") user: "User" = Relationship( diff --git a/api/kiwi_vpn_api/routers/__init__.py b/api/kiwi_vpn_api/routers/__init__.py index 1d19caa..0398ff7 100644 --- a/api/kiwi_vpn_api/routers/__init__.py +++ b/api/kiwi_vpn_api/routers/__init__.py @@ -1,3 +1,9 @@ +""" +Package `routers`: Each module contains the path operations for their prefixes. + +This file: Main API router definition. +""" + from fastapi import APIRouter from . import admin, user