diff --git a/api/.vscode/settings.json b/api/.vscode/settings.json index cd32962..2607a64 100644 --- a/api/.vscode/settings.json +++ b/api/.vscode/settings.json @@ -11,5 +11,6 @@ "editor.formatOnSave": true, "editor.codeActionsOnSave": { "source.organizeImports": true - } + }, + "git.closeDiffOnOperation": true } \ No newline at end of file diff --git a/api/kiwi_vpn_api/config.py b/api/kiwi_vpn_api/config.py index 4b2cf24..2ba228a 100644 --- a/api/kiwi_vpn_api/config.py +++ b/api/kiwi_vpn_api/config.py @@ -20,8 +20,6 @@ from jose import JWTError, jwt from jose.constants import ALGORITHMS from passlib.context import CryptContext from pydantic import BaseModel, BaseSettings, Field, validator -from sqlalchemy import create_engine -from sqlalchemy.engine import Engine class Settings(BaseSettings): @@ -31,18 +29,21 @@ class Settings(BaseSettings): production_mode: bool = False data_dir: Path = Path("./tmp") + config_file_name: Path = Path("config.json") + api_v1_prefix: str = "api/v1" openapi_url: str = "/openapi.json" docs_url: str | None = "/docs" redoc_url: str | None = "/redoc" - @staticmethod + @classmethod + @property @functools.lru_cache - def get() -> Settings: - return Settings() + def _(cls) -> Settings: + return cls() @property def config_file(self) -> Path: - return self.data_dir.joinpath("config.json") + return self.data_dir.joinpath(self.config_file_name) class DBType(Enum): @@ -63,23 +64,20 @@ class DBConfig(BaseModel): user: str | None = None password: str | None = None host: str | None = None - database: str | None = Settings.get().data_dir.joinpath("vpn.db") + database: str | None = Settings._.data_dir.joinpath("vpn.db") mysql_driver: str = "pymysql" mysql_args: list[str] = ["charset=utf8mb4"] @property - async def db_engine(self) -> Engine: + def uri(self) -> str: """ - Construct an SQLAlchemy engine + Construct a database connection string """ if self.type is DBType.sqlite: # SQLite backend - return create_engine( - f"sqlite:///{self.database}", - connect_args={"check_same_thread": False}, - ) + return f"sqlite:///{self.database}" elif self.type is DBType.mysql: # MySQL backend @@ -88,12 +86,9 @@ class DBConfig(BaseModel): else: args_str = "" - return create_engine( - f"mysql+{self.mysql_driver}://" - f"{self.user}:{self.password}@{self.host}" - f"/{self.database}{args_str}", - pool_recycle=3600, - ) + return (f"mysql+{self.mysql_driver}://" + f"{self.user}:{self.password}@{self.host}" + f"/{self.database}{args_str}") class JWTConfig(BaseModel): @@ -181,7 +176,7 @@ class CryptoConfig(BaseModel): schemes: list[str] = ["bcrypt"] @property - async def crypt_context(self) -> CryptContext: + def crypt_context(self) -> CryptContext: return CryptContext( schemes=self.schemes, deprecated="auto", @@ -197,23 +192,38 @@ class Config(BaseModel): jwt: JWTConfig = Field(default_factory=JWTConfig) crypto: CryptoConfig = Field(default_factory=CryptoConfig) - @staticmethod - async def load() -> Config | None: + __singleton: Config | None = None + + @classmethod + def load(cls) -> Config | None: """ Load configuration from config file """ + if cls.__singleton is not None: + return cls.__singleton + try: - with open(Settings.get().config_file, "r") as config_file: - return Config.parse_obj(json.load(config_file)) + with open(Settings._.config_file, "r") as config_file: + cls.__singleton = Config.parse_obj(json.load(config_file)) + return cls.__singleton except FileNotFoundError: return None - async def save(self) -> None: + @classmethod + @property + def _(cls) -> Config | None: + """ + Shorthand for load() + """ + + return cls.load() + + def save(self) -> None: """ Save configuration to config file """ - with open(Settings.get().config_file, "w") as config_file: + with open(Settings._.config_file, "w") as config_file: config_file.write(self.json(indent=2)) diff --git a/api/kiwi_vpn_api/db/__init__.py b/api/kiwi_vpn_api/db/__init__.py index 0e8041b..aa0cb3c 100644 --- a/api/kiwi_vpn_api/db/__init__.py +++ b/api/kiwi_vpn_api/db/__init__.py @@ -1,4 +1,20 @@ -from . import models, schemas -from .connection import Connection +""" +Package `db`: ORM and schemas for database content. +""" -__all__ = ["Connection", "models", "schemas"] +from .connection import Connection +from .device import Device, DeviceBase, DeviceCreate +from .user import User, UserBase, UserCreate, UserRead +from .user_capability import UserCapabilityType + +__all__ = [ + "Connection", + "Device", + "DeviceBase", + "DeviceCreate", + "User", + "UserBase", + "UserCreate", + "UserRead", + "UserCapabilityType", +] diff --git a/api/kiwi_vpn_api/db/connection.py b/api/kiwi_vpn_api/db/connection.py index 97592e1..5826daf 100644 --- a/api/kiwi_vpn_api/db/connection.py +++ b/api/kiwi_vpn_api/db/connection.py @@ -1,75 +1,34 @@ """ -Utilities for handling SQLAlchemy database connections. +Database connection management. """ -from typing import Generator - -from sqlalchemy.engine import Engine -from sqlalchemy.orm import Session, sessionmaker - -from .models import ORMBaseModel - - -class SessionManager: - """ - Simple context manager for an ORM session. - """ - - __session: Session - - def __init__(self, session: Session) -> None: - self.__session = session - - def __enter__(self) -> Session: - return self.__session - - def __exit__(self, *args) -> None: - self.__session.close() +from sqlmodel import Session, SQLModel, create_engine class Connection: """ - Namespace for the database connection. + Namespace for the database connection """ - engine: Engine | None = None - session_local: sessionmaker | None = None + engine = None @classmethod - def connect(cls, engine: Engine) -> None: + def connect(cls, connection_url: str) -> None: """ Connect ORM to a database engine. """ - cls.engine = engine - cls.session_local = sessionmaker( - autocommit=False, autoflush=False, bind=engine, - ) - ORMBaseModel.metadata.create_all(bind=engine) + cls.engine = create_engine(connection_url) + SQLModel.metadata.create_all(cls.engine) @classmethod - def use(cls) -> SessionManager | None: + @property + def session(cls) -> Session | None: """ Create an ORM session using a context manager. """ - if cls.session_local is None: + if cls.engine is None: return None - return SessionManager(cls.session_local()) - - @classmethod - async def get(cls) -> Generator[Session | None, None, None]: - """ - Create an ORM session using a FastAPI compatible async generator. - """ - - if cls.session_local is None: - yield None - - else: - db = cls.session_local() - try: - yield db - finally: - db.close() + return Session(cls.engine) diff --git a/api/kiwi_vpn_api/db/device.py b/api/kiwi_vpn_api/db/device.py new file mode 100644 index 0000000..ff791e0 --- /dev/null +++ b/api/kiwi_vpn_api/db/device.py @@ -0,0 +1,101 @@ +""" +Python representation of `device` table. +""" + +from __future__ import annotations + +from datetime import datetime +from typing import TYPE_CHECKING + +from sqlalchemy.exc import IntegrityError +from sqlmodel import Field, Relationship, SQLModel, UniqueConstraint + +from .connection import Connection + +if TYPE_CHECKING: + from .user import User + + +class DeviceBase(SQLModel): + """ + Common to all representations of devices + """ + + name: str + type: str + expiry: datetime | None + + +class DeviceCreate(DeviceBase): + """ + Representation of a newly created device + """ + + owner_name: str | None + + +class DeviceRead(DeviceBase): + """ + Representation of a device read via the API + """ + + owner_name: str | None + + +class Device(DeviceBase, table=True): + """ + Representation of `device` table + """ + + __table_args__ = (UniqueConstraint( + "owner_name", + "name", + ),) + + id: int | None = Field(primary_key=True) + owner_name: str | None = Field(foreign_key="user.name") + + # no idea, but "User" (in quotes) doesn't work here + # might be a future problem? + owner: User = Relationship( + back_populates="devices", + ) + + @classmethod + def create(cls, **kwargs) -> Device | None: + """ + Create a new device in the database. + """ + + try: + with Connection.session as db: + device = cls.from_orm(DeviceCreate(**kwargs)) + + db.add(device) + db.commit() + db.refresh(device) + + return device + + except IntegrityError: + # device already existed + return None + + def update(self) -> None: + """ + Update this device in the database. + """ + + with Connection.session as db: + db.add(self) + db.commit() + db.refresh(self) + + def delete(self) -> bool: + """ + Delete this device from the database. + """ + + with Connection.session as db: + db.delete(self) + db.commit() diff --git a/api/kiwi_vpn_api/db/models.py b/api/kiwi_vpn_api/db/models.py deleted file mode 100644 index f9b6afa..0000000 --- a/api/kiwi_vpn_api/db/models.py +++ /dev/null @@ -1,106 +0,0 @@ -""" -SQLAlchemy representation of database contents. -""" - -from __future__ import annotations - -import datetime - -from sqlalchemy import (Boolean, Column, DateTime, ForeignKey, Integer, String, - UniqueConstraint) -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import Session, relationship - -ORMBaseModel = declarative_base() - - -class User(ORMBaseModel): - __tablename__ = "users" - - name = Column(String, primary_key=True, index=True) - password = Column(String, nullable=False) - - capabilities: list[UserCapability] = relationship( - "UserCapability", lazy="joined", cascade="all, delete-orphan" - ) - certificates: list[Certificate] = relationship( - "Certificate", lazy="select", back_populates="owner" - ) - distinguished_names: list[DistinguishedName] = relationship( - "DistinguishedName", lazy="select", back_populates="owner" - ) - - @classmethod - def load(cls, db: Session, name: str) -> User | None: - """ - Load user from database by name. - """ - - return (db - .query(User) - .filter(User.name == name) - .first()) - - -class UserCapability(ORMBaseModel): - __tablename__ = "user_capabilities" - - user_name = Column( - String, - ForeignKey("users.name"), - primary_key=True, - index=True, - ) - capability = Column(String, primary_key=True) - - -class DistinguishedName(ORMBaseModel): - __tablename__ = "distinguished_names" - - id = Column(Integer, primary_key=True, autoincrement=True) - - owner_name = Column(String, ForeignKey("users.name")) - cn_only = Column(Boolean, default=True, nullable=False) - country = Column(String(2)) - state = Column(String) - city = Column(String) - organization = Column(String) - organizational_unit = Column(String) - email = Column(String) - common_name = Column(String, nullable=False) - - owner: User = relationship( - "User", lazy="joined", back_populates="distinguished_names" - ) - - UniqueConstraint( - country, - state, - city, - organization, - organizational_unit, - email, - common_name, - ) - - -class Certificate(ORMBaseModel): - __tablename__ = "certificates" - - id = Column(Integer, primary_key=True, autoincrement=True) - - owner_name = Column(String, ForeignKey("users.name")) - dn_id = Column( - Integer, - ForeignKey("distinguished_names.id"), - nullable=False, - ) - expiry = Column(DateTime, default=datetime.datetime.now) - - distinguished_name: DistinguishedName = relationship( - "DistinguishedName", lazy="joined" - ) - - owner: User = relationship( - "User", lazy="joined", back_populates="certificates" - ) diff --git a/api/kiwi_vpn_api/db/schemas.py b/api/kiwi_vpn_api/db/schemas.py deleted file mode 100644 index 132aab1..0000000 --- a/api/kiwi_vpn_api/db/schemas.py +++ /dev/null @@ -1,275 +0,0 @@ -""" -Pydantic representation of database contents. -""" - - -from __future__ import annotations - -from datetime import datetime -from enum import Enum -from typing import Any - -from passlib.context import CryptContext -from pydantic import BaseModel, Field, constr, validator -from sqlalchemy.exc import IntegrityError -from sqlalchemy.orm import Session - -from . import models - -########## -# table: distinguished_names -########## - - -class DistinguishedNameBase(BaseModel): - cn_only: bool - country: constr(max_length=2) | None - state: str | None - city: str | None - organization: str | None - organizational_unit: str | None - email: str | None - common_name: str - - -class DistinguishedNameCreate(DistinguishedNameBase): - pass - - -class DistinguishedName(DistinguishedNameBase): - class Config: - orm_mode = True - - @classmethod - def create( - cls, - db: Session, - dn: DistinguishedNameCreate, - owner: User, - ) -> User | None: - """ - Create a new distinguished name in the database. - """ - - try: - db_owner = models.User.load( - db=db, - name=owner.name, - ) - - dn = models.DistinguishedName( - cn_only=dn.cn_only, - country=dn.country, - state=dn.state, - city=dn.city, - organization=dn.organization, - organizational_unit=dn.organizational_unit, - email=dn.email, - common_name=dn.common_name, - owner=db_owner, - ) - - db.add(dn) - db.commit() - db.refresh(dn) - - return cls.from_orm(dn) - - except IntegrityError: - # distinguished name already existed - pass - -########## -# table: certificates -########## - - -class CertificateBase(BaseModel): - expiry: datetime - - -class CertificateCreate(CertificateBase): - pass - - -class Certificate(CertificateBase): - distinguished_name: DistinguishedName - - class Config: - orm_mode = True - -########## -# table: user_capabilities -########## - - -class UserCapability(Enum): - admin = "admin" - - def __repr__(self) -> str: - return self.value - - @classmethod - def from_value(cls, value) -> UserCapability: - """ - Create UserCapability from various formats - """ - - if isinstance(value, cls): - # value is already a UserCapability, use that - return value - - elif isinstance(value, models.UserCapability): - # create from db format - return cls(value.capability) - - else: - # create from string representation - return cls(str(value)) - -########## -# table: users -########## - - -class UserBase(BaseModel): - name: str - - -class UserCreate(UserBase): - password: str - - -class User(UserBase): - capabilities: list[UserCapability] = [] - - distinguished_names: list[DistinguishedName] = Field( - default=[], repr=False - ) - - certificates: list[Certificate] = Field( - default=[], repr=False - ) - - class Config: - orm_mode = True - - @validator("capabilities", pre=True) - @classmethod - def unify_capabilities(cls, value: list[Any]) -> list[UserCapability]: - """ - Import the capabilities from various formats - """ - - return [ - UserCapability.from_value(capability) - for capability in value - ] - - @classmethod - def from_db( - cls, - db: Session, - name: str, - ) -> User | None: - """ - Load user from database by name. - """ - - if (db_user := models.User.load(db, name)) is None: - return None - - return cls.from_orm(db_user) - - @classmethod - def create( - cls, - db: Session, - user: UserCreate, - crypt_context: CryptContext, - ) -> User | None: - """ - Create a new user in the database. - """ - - try: - user = models.User( - name=user.name, - password=crypt_context.hash(user.password), - capabilities=[], - ) - - db.add(user) - db.commit() - db.refresh(user) - - return cls.from_orm(user) - - except IntegrityError: - # user already existed - pass - - def is_admin(self) -> bool: - return UserCapability.admin in self.capabilities - - def authenticate( - self, - db: Session, - password: str, - crypt_context: CryptContext, - ) -> User | None: - """ - Authenticate with name/password against users in database. - """ - - if (db_user := models.User.load(db, self.name)) is None: - # nonexistent user, fake doing password verification - crypt_context.dummy_verify() - return False - - if not crypt_context.verify(password, db_user.password): - # password hash mismatch - return False - - self.from_orm(db_user) - - return True - - def update( - self, - db: Session, - ) -> None: - """ - Update this user in the database. - """ - - old_dbuser = models.User.load(db, self.name) - old_user = self.from_orm(old_dbuser) - - for capability in self.capabilities: - if capability not in old_user.capabilities: - old_dbuser.capabilities.append( - models.UserCapability(capability=capability.value) - ) - - for capability in old_dbuser.capabilities: - if UserCapability.from_value(capability) not in self.capabilities: - db.delete(capability) - - db.commit() - - def delete( - self, - db: Session, - ) -> bool: - """ - Delete this user from the database. - """ - - if (db_user := models.User.load(db, self.name)) is None: - # nonexistent user - return False - - db.delete(db_user) - db.commit() - return True diff --git a/api/kiwi_vpn_api/db/user.py b/api/kiwi_vpn_api/db/user.py new file mode 100644 index 0000000..5e1a4bb --- /dev/null +++ b/api/kiwi_vpn_api/db/user.py @@ -0,0 +1,199 @@ +""" +Python representation of `user` table. +""" + +from __future__ import annotations + +from typing import Any, Sequence + +from pydantic import root_validator +from sqlalchemy.exc import IntegrityError +from sqlmodel import Field, Relationship, SQLModel + +from ..config import Config +from .connection import Connection +from .device import Device +from .user_capability import UserCapability, UserCapabilityType + + +class UserBase(SQLModel): + """ + Common to all representations of users + """ + + name: str = Field(primary_key=True) + email: str | None = Field(default=None) + + country: str | None = Field(default=None) + state: str | None = Field(default=None) + city: str | None = Field(default=None) + organization: str | None = Field(default=None) + organizational_unit: str | None = Field(default=None) + + +class UserCreate(UserBase): + """ + Representation of a newly created user + """ + + password: str | None = Field(default=None) + password_clear: str | None = Field(default=None) + + @root_validator + @classmethod + def hash_password(cls, values: dict[str, Any]) -> dict[str, Any]: + """ + Ensure the `password` value of this user gets set. + """ + + if (values.get("password")) is not None: + # password is set + return values + + if (password_clear := values.get("password_clear")) is None: + raise ValueError("No password to hash") + + if (current_config := Config._) is None: + raise ValueError("Not configured") + + values["password"] = current_config.crypto.crypt_context.hash( + password_clear) + + return values + + +class UserRead(UserBase): + """ + Representation of a user read via the API + """ + + pass + + +class User(UserBase, table=True): + """ + Representation of `user` table + """ + + password: str + + capabilities: list[UserCapability] = Relationship( + back_populates="user", + sa_relationship_kwargs={ + "lazy": "joined", + "cascade": "all, delete-orphan", + }, + ) + + devices: list[Device] = Relationship( + back_populates="owner", + ) + + @classmethod + def create(cls, **kwargs) -> User | None: + """ + Create a new user in the database. + """ + + try: + with Connection.session as db: + user = cls.from_orm(UserCreate(**kwargs)) + + db.add(user) + db.commit() + db.refresh(user) + + return user + + except IntegrityError: + # user already existed + return None + + @classmethod + def get(cls, name: str) -> User | None: + """ + Load user from database by name. + """ + + with Connection.session as db: + return db.get(cls, name) + + @classmethod + def authenticate( + cls, + name: str, + password: str, + ) -> User | None: + """ + Authenticate with name/password against users in database. + """ + + crypt_context = Config._.crypto.crypt_context + + if (user := cls.get(name)) is None: + # nonexistent user, fake doing password verification + crypt_context.dummy_verify() + return None + + if not crypt_context.verify(password, user.password): + # password hash mismatch + return None + + return user + + def update(self) -> None: + """ + Update this user in the database. + """ + + with Connection.session as db: + db.add(self) + db.commit() + db.refresh(self) + + def delete(self) -> None: + """ + Delete this user from the database. + """ + + with Connection.session as db: + db.delete(self) + db.commit() + + def get_capabilities(self) -> set[UserCapabilityType]: + """ + Return the capabilities of this user. + """ + + return set( + capability._ + for capability in self.capabilities + ) + + def can( + self, + capability: UserCapabilityType, + ) -> bool: + """ + Check if this user has a capability. + """ + + return ( + capability in self.get_capabilities() + or UserCapabilityType.admin in self.get_capabilities() + ) + + def set_capabilities( + self, + capabilities: Sequence[UserCapabilityType], + ) -> None: + """ + Change the capabilities of this user. + """ + + self.capabilities = [ + UserCapability( + user_name=self.name, + capability_name=capability.value, + ) for capability in capabilities + ] diff --git a/api/kiwi_vpn_api/db/user_capability.py b/api/kiwi_vpn_api/db/user_capability.py new file mode 100644 index 0000000..479fec4 --- /dev/null +++ b/api/kiwi_vpn_api/db/user_capability.py @@ -0,0 +1,56 @@ +""" +Python representation of `user_capability` table. +""" + +from enum import Enum +from typing import TYPE_CHECKING + +from sqlmodel import Field, Relationship, SQLModel + +if TYPE_CHECKING: + from .user import User + + +class UserCapabilityType(Enum): + """ + Allowed values for capabilities + """ + + admin = "admin" + login = "login" + issue = "issue" + renew = "renew" + + def __repr__(self) -> str: + return self.value + + +class UserCapabilityBase(SQLModel): + """ + Common to all representations of capabilities + """ + + capability_name: str = Field(primary_key=True) + + @property + def _(self) -> UserCapabilityType: + """ + Transform into a `Capability`. + """ + + return UserCapabilityType(self.capability_name) + + def __repr__(self) -> str: + return self.capability_name + + +class UserCapability(UserCapabilityBase, table=True): + """ + Representation of `user_capability` table + """ + + user_name: str = Field(primary_key=True, foreign_key="user.name") + + user: "User" = Relationship( + back_populates="capabilities", + ) diff --git a/api/kiwi_vpn_api/main.py b/api/kiwi_vpn_api/main.py index 82c2fa8..86ae4f2 100755 --- a/api/kiwi_vpn_api/main.py +++ b/api/kiwi_vpn_api/main.py @@ -13,13 +13,9 @@ import uvicorn from fastapi import FastAPI from .config import Config, Settings -from .db import Connection -from .db.schemas import User +from .db import Connection, User from .routers import main_router -settings = Settings.get() - - app = FastAPI( title="kiwi-vpn API", description="This API enables the `kiwi-vpn` service.", @@ -31,30 +27,29 @@ app = FastAPI( "name": "MIT License", "url": "https://opensource.org/licenses/mit-license.php", }, - openapi_url=settings.openapi_url, - docs_url=settings.docs_url if not settings.production_mode else None, - redoc_url=settings.redoc_url if not settings.production_mode else None, + openapi_url=Settings._.openapi_url, + docs_url=Settings._.docs_url if not Settings._.production_mode else None, + redoc_url=Settings._.redoc_url if not Settings._.production_mode else None, ) -app.include_router(main_router) +app.include_router(main_router, prefix=f"/{Settings._.api_v1_prefix}") @app.on_event("startup") async def on_startup() -> None: # check if configured - if (current_config := await Config.load()) is not None: + if (current_config := Config._) is not None: # connect to database - Connection.connect(await current_config.db.db_engine) + Connection.connect(current_config.db.uri) # some testing - with Connection.use() as db: - print(User.from_db(db, "admin")) - print(User.from_db(db, "nonexistent")) + print(User.get("admin")) + print(User.get("nonexistent")) def main() -> None: uvicorn.run( - "kiwi_vpn_api.main:app", + app="kiwi_vpn_api.main:app", host="0.0.0.0", port=8000, reload=True, diff --git a/api/kiwi_vpn_api/routers/__init__.py b/api/kiwi_vpn_api/routers/__init__.py index 22bb142..6cb693f 100644 --- a/api/kiwi_vpn_api/routers/__init__.py +++ b/api/kiwi_vpn_api/routers/__init__.py @@ -1,10 +1,18 @@ +""" +Package `routers`: Each module contains the path operations for their prefixes. + +This file: Main API router definition. +""" + from fastapi import APIRouter from . import admin, user -main_router = APIRouter(prefix="/api/v1") +main_router = APIRouter() main_router.include_router(admin.router) main_router.include_router(user.router) -__all__ = ["main_router"] +__all__ = [ + "main_router", +] diff --git a/api/kiwi_vpn_api/routers/_common.py b/api/kiwi_vpn_api/routers/_common.py index f9424bc..ba8c496 100644 --- a/api/kiwi_vpn_api/routers/_common.py +++ b/api/kiwi_vpn_api/routers/_common.py @@ -5,13 +5,13 @@ Common dependencies for routers. from fastapi import Depends, HTTPException, status from fastapi.security import OAuth2PasswordBearer -from sqlalchemy.orm import Session -from ..config import Config -from ..db import Connection -from ..db.schemas import User +from ..config import Config, Settings +from ..db import User, UserCapabilityType -oauth2_scheme = OAuth2PasswordBearer(tokenUrl="user/authenticate") +oauth2_scheme = OAuth2PasswordBearer( + tokenUrl=f"{Settings._.api_v1_prefix}/user/authenticate" +) class Responses: @@ -56,7 +56,6 @@ class Responses: async def get_current_user( token: str = Depends(oauth2_scheme), - db: Session | None = Depends(Connection.get), current_config: Config | None = Depends(Config.load), ) -> User | None: """ @@ -68,13 +67,11 @@ async def get_current_user( raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) username = await current_config.jwt.decode_token(token) - user = User.from_db(db, username) - return user + return User.get(username) async def get_current_user_if_exists( - current_config: Config | None = Depends(Config.load), current_user: User | None = Depends(get_current_user), ) -> User: """ @@ -89,7 +86,6 @@ async def get_current_user_if_exists( async def get_current_user_if_admin( - current_config: Config | None = Depends(Config.load), current_user: User = Depends(get_current_user_if_exists), ) -> User: """ @@ -97,7 +93,7 @@ async def get_current_user_if_admin( """ # fail if not requested by an admin - if not current_user.is_admin(): + if not current_user.can(UserCapabilityType.admin): raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) return current_user @@ -105,7 +101,6 @@ async def get_current_user_if_admin( async def get_current_user_if_admin_or_self( user_name: str, - current_config: Config | None = Depends(Config.load), current_user: User = Depends(get_current_user_if_exists), ) -> User: """ @@ -116,7 +111,8 @@ async def get_current_user_if_admin_or_self( """ # fail if not requested by an admin or self - if not (current_user.is_admin() or current_user.name == user_name): + if not (current_user.can(UserCapabilityType.admin) + or current_user.name == user_name): raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) return current_user diff --git a/api/kiwi_vpn_api/routers/admin.py b/api/kiwi_vpn_api/routers/admin.py index f3e6daf..ca0ad6e 100644 --- a/api/kiwi_vpn_api/routers/admin.py +++ b/api/kiwi_vpn_api/routers/admin.py @@ -4,49 +4,68 @@ from fastapi import APIRouter, Depends, HTTPException, status +from sqlmodel import select from ..config import Config -from ..db import Connection -from ..db.schemas import User, UserCapability, UserCreate -from ._common import Responses, get_current_user +from ..db import Connection, User, UserCapabilityType, UserCreate +from ._common import Responses, get_current_user_if_admin router = APIRouter(prefix="/admin", tags=["admin"]) @router.put( - "/install", + "/install/config", responses={ status.HTTP_200_OK: Responses.OK, status.HTTP_400_BAD_REQUEST: Responses.INSTALLED, }, ) -async def install( +async def initial_configure( config: Config, - admin_user: UserCreate, current_config: Config | None = Depends(Config.load), ): """ - PUT ./install: Install `kiwi-vpn`. + PUT ./install/config: Configure `kiwi-vpn`. """ - # fail if already installed + # fail if already configured if current_config is not None: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) # create config file, connect to database - await config.save() - Connection.connect(await config.db.db_engine) + config.save() + Connection.connect(current_config.db.uri) + + +@router.put( + "/install/admin", + responses={ + status.HTTP_200_OK: Responses.OK, + status.HTTP_400_BAD_REQUEST: Responses.NOT_INSTALLED, + status.HTTP_409_CONFLICT: Responses.ENTRY_EXISTS, + }, +) +async def create_initial_admin( + admin_user: UserCreate, + current_config: Config | None = Depends(Config.load), +): + """ + PUT ./install/admin: Create the first administrative user. + """ + + # fail if not configured + if current_config is None: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) + + # fail if any user exists + with Connection.session as db: + if db.exec(select(User).limit(1)).first() is not None: + raise HTTPException(status_code=status.HTTP_409_CONFLICT) # create an administrative user - with Connection.use() as db: - new_user = User.create( - db=db, - user=admin_user, - crypt_context=await config.crypto.crypt_context, - ) - - new_user.capabilities.append(UserCapability.admin) - new_user.update(db) + new_user = User.create(**admin_user.dict()) + new_user.set_capabilities([UserCapabilityType.admin]) + new_user.update() @router.put( @@ -59,23 +78,13 @@ async def install( }, ) async def set_config( - new_config: Config, - current_config: Config | None = Depends(Config.load), - current_user: User | None = Depends(get_current_user), + config: Config, + _: User = Depends(get_current_user_if_admin), ): """ PUT ./config: Edit `kiwi-vpn` main config. """ - # fail if not installed - if current_config is None: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) - - # fail if not requested by an admin - if (current_user is None - or UserCapability.admin not in current_user.capabilities): - raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) - # update config file, reconnect to database - await new_config.save() - Connection.connect(await new_config.db.db_engine) + config.save() + Connection.connect(config.db.uri) diff --git a/api/kiwi_vpn_api/routers/dn.py b/api/kiwi_vpn_api/routers/dn.py deleted file mode 100644 index ae1eae2..0000000 --- a/api/kiwi_vpn_api/routers/dn.py +++ /dev/null @@ -1,58 +0,0 @@ -""" -/dn endpoints. -""" - - -from fastapi import APIRouter, Depends, HTTPException, status -from sqlalchemy.orm import Session - -from ..db import Connection -from ..db.schemas import DistinguishedName, DistinguishedNameCreate, User -from ._common import Responses, get_current_user_if_admin_or_self - -router = APIRouter(prefix="/dn") - - -@router.post( - "", - responses={ - status.HTTP_200_OK: Responses.OK, - status.HTTP_400_BAD_REQUEST: Responses.NOT_INSTALLED, - status.HTTP_401_UNAUTHORIZED: Responses.NEEDS_USER, - status.HTTP_403_FORBIDDEN: Responses.NEEDS_ADMIN, - status.HTTP_404_NOT_FOUND: Responses.ENTRY_DOESNT_EXIST, - status.HTTP_409_CONFLICT: Responses.ENTRY_EXISTS, - }, -) -async def add_distinguished_name( - user_name: str, - distinguished_name: DistinguishedNameCreate, - _: User = Depends(get_current_user_if_admin_or_self), - db: Session | None = Depends(Connection.get), -): - """ - POST ./: Create a new distinguished name in the database. - """ - - owner = User.from_db( - db=db, - name=user_name, - ) - - # fail if owner doesn't exist - if owner is None: - raise HTTPException(status_code=status.HTTP_409_CONFLICT) - - # actually create the new user - new_dn = DistinguishedName.create( - db=db, - dn=distinguished_name, - owner=owner, - ) - - # fail if creation was unsuccessful - if new_dn is None: - raise HTTPException(status_code=status.HTTP_409_CONFLICT) - - # return the created user on success - return new_dn diff --git a/api/kiwi_vpn_api/routers/user.py b/api/kiwi_vpn_api/routers/user.py index 63e2b22..1db468c 100644 --- a/api/kiwi_vpn_api/routers/user.py +++ b/api/kiwi_vpn_api/routers/user.py @@ -5,11 +5,9 @@ from fastapi import APIRouter, Depends, HTTPException, status from fastapi.security import OAuth2PasswordRequestForm from pydantic import BaseModel -from sqlalchemy.orm import Session from ..config import Config -from ..db import Connection -from ..db.schemas import User, UserCapability, UserCreate +from ..db import User, UserCapabilityType, UserCreate, UserRead from ._common import Responses, get_current_user, get_current_user_if_admin router = APIRouter(prefix="/user", tags=["user"]) @@ -28,7 +26,6 @@ class Token(BaseModel): async def login( form_data: OAuth2PasswordRequestForm = Depends(), current_config: Config | None = Depends(Config.load), - db: Session | None = Depends(Connection.get), ): """ POST ./authenticate: Authenticate a user. Issues a bearer token. @@ -39,12 +36,10 @@ async def login( raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) # try logging in - user = User(name=form_data.username) - if not user.authenticate( - db=db, + if not (user := User.authenticate( + name=form_data.username, password=form_data.password, - crypt_context=await current_config.crypto.crypt_context, - ): + )): # authentication failed raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -52,12 +47,16 @@ async def login( headers={"WWW-Authenticate": "Bearer"}, ) + if not user.can(UserCapabilityType.login): + # user cannot login + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) + # authentication succeeded access_token = await current_config.jwt.create_token(user.name) return {"access_token": access_token, "token_type": "bearer"} -@router.get("/current", response_model=User) +@router.get("/current", response_model=UserRead) async def get_current_user( current_user: User | None = Depends(get_current_user), ): @@ -81,20 +80,15 @@ async def get_current_user( ) async def add_user( user: UserCreate, - current_config: Config | None = Depends(Config.load), _: User = Depends(get_current_user_if_admin), - db: Session | None = Depends(Connection.get), ): """ POST ./: Create a new user in the database. """ # actually create the new user - new_user = User.create( - db=db, - user=user, - crypt_context=await current_config.crypto.crypt_context, - ) + new_user = User.create(**user.dict()) + new_user.set_capabilities([UserCapabilityType.login]) # fail if creation was unsuccessful if new_user is None: @@ -118,22 +112,21 @@ async def add_user( async def remove_user( user_name: str, _: User = Depends(get_current_user_if_admin), - db: Session | None = Depends(Connection.get), ): """ DELETE ./{user_name}: Remove a user from the database. """ # get the user - user = User.from_db( - db=db, - name=user_name, - ) + user = User.get(user_name) - # fail if deletion was unsuccessful - if user is None or not user.delete(db): + # fail if user not found + if user is None: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) + # delete user + user.delete() + @router.post( "/{user_name}/capabilities", @@ -146,22 +139,19 @@ async def remove_user( ) async def extend_capabilities( user_name: str, - capabilities: list[UserCapability], + capabilities: list[UserCapabilityType], _: User = Depends(get_current_user_if_admin), - db: Session | None = Depends(Connection.get), ): """ POST ./{user_name}/capabilities: Add capabilities to a user. """ # get and change the user - user = User.from_db( - db=db, - name=user_name, - ) + user = User.get(user_name) - user.capabilities.extend(capabilities) - user.update(db) + user.set_capabilities(user.get_capabilities() | set(capabilities)) + + user.update() @router.delete( @@ -175,21 +165,16 @@ async def extend_capabilities( ) async def remove_capabilities( user_name: str, - capabilities: list[UserCapability], + capabilities: list[UserCapabilityType], _: User = Depends(get_current_user_if_admin), - db: Session | None = Depends(Connection.get), ): """ DELETE ./{user_name}/capabilities: Remove capabilities from a user. """ # get and change the user - user = User.from_db( - db=db, - name=user_name, - ) + user = User.get(user_name) - for capability in capabilities: - user.capabilities.remove(capability) + user.set_capabilities(user.get_capabilities() - set(capabilities)) - user.update(db) + user.update() diff --git a/api/plan.md b/api/plan.md new file mode 100644 index 0000000..0822af0 --- /dev/null +++ b/api/plan.md @@ -0,0 +1,25 @@ +## Server props +- default DN parts: country, state, city, org, OU +- "customizable" flags for DN parts +- flag: use client-to-client +- force cipher, tls-cipher, auth params +- server name +- default certification length +- default certificate algo + +## User props +- username +- password +- custom DN parts: country, state, city, org, OU +- email + +## User caps +- admin: administrator +- login: can log into the web interface +- issue: can certify own devices without approval +- renew: can renew certificates for own devices + +## Device props +- name +- type (icon) +- expiry diff --git a/api/poetry.lock b/api/poetry.lock index 3f5e84d..b778921 100644 --- a/api/poetry.lock +++ b/api/poetry.lock @@ -428,6 +428,30 @@ postgresql_psycopg2cffi = ["psycopg2cffi"] pymysql = ["pymysql (<1)", "pymysql"] sqlcipher = ["sqlcipher3-binary"] +[[package]] +name = "sqlalchemy2-stubs" +version = "0.0.2a21" +description = "Typing Stubs for SQLAlchemy 1.4" +category = "main" +optional = false +python-versions = ">=3.6" + +[package.dependencies] +typing-extensions = ">=3.7.4" + +[[package]] +name = "sqlmodel" +version = "0.0.6" +description = "SQLModel, SQL databases in Python, designed for simplicity, compatibility, and robustness." +category = "main" +optional = false +python-versions = ">=3.6.1,<4.0.0" + +[package.dependencies] +pydantic = ">=1.8.2,<2.0.0" +SQLAlchemy = ">=1.4.17,<1.5.0" +sqlalchemy2-stubs = "*" + [[package]] name = "starlette" version = "0.17.1" @@ -477,7 +501,7 @@ standard = ["websockets (>=10.0)", "httptools (>=0.4.0)", "watchgod (>=0.6)", "p [metadata] lock-version = "1.1" python-versions = "^3.10" -content-hash = "432d2933102f8a0091cec1b5484944a0211ca74c5dc9b65877d99d7bd160e4bb" +content-hash = "a580f9fe4c68667e4cbdf385ac11d5c7a2925e3c990b7164faa922ec8b6f9555" [metadata.files] anyio = [ @@ -834,6 +858,14 @@ sqlalchemy = [ {file = "SQLAlchemy-1.4.32-cp39-cp39-win_amd64.whl", hash = "sha256:b3f1d9b3aa09ab9adc7f8c4b40fc3e081eb903054c9a6f9ae1633fe15ae503b4"}, {file = "SQLAlchemy-1.4.32.tar.gz", hash = "sha256:6fdd2dc5931daab778c2b65b03df6ae68376e028a3098eb624d0909d999885bc"}, ] +sqlalchemy2-stubs = [ + {file = "sqlalchemy2-stubs-0.0.2a21.tar.gz", hash = "sha256:207e3d8a36fc032d325f4eec89e0c6760efe81d07e978513d8c9b14f108dcd0c"}, + {file = "sqlalchemy2_stubs-0.0.2a21-py3-none-any.whl", hash = "sha256:bd4a3d5ca7ff9d01b2245e1b26304d6b2ec4daf43a01faf40db9e09245679433"}, +] +sqlmodel = [ + {file = "sqlmodel-0.0.6-py3-none-any.whl", hash = "sha256:c5fd8719e09da348cd32ce2a5b6a44f289d3029fa8f1c9818229b6f34f1201b4"}, + {file = "sqlmodel-0.0.6.tar.gz", hash = "sha256:3b4f966b9671b24d85529d274e6c4dbc7753b468e35d2d6a40bd75cad1f66813"}, +] starlette = [ {file = "starlette-0.17.1-py3-none-any.whl", hash = "sha256:26a18cbda5e6b651c964c12c88b36d9898481cd428ed6e063f5f29c418f73050"}, {file = "starlette-0.17.1.tar.gz", hash = "sha256:57eab3cc975a28af62f6faec94d355a410634940f10b30d68d31cb5ec1b44ae8"}, diff --git a/api/pyproject.toml b/api/pyproject.toml index a105d62..a9a7855 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -13,6 +13,7 @@ python-jose = {extras = ["cryptography"], version = "^3.3.0"} passlib = {extras = ["argon2", "bcrypt"], version = "^1.7.4"} SQLAlchemy = "^1.4.32" pyOpenSSL = "^22.0.0" +sqlmodel = "^0.0.6" [tool.poetry.dev-dependencies] pytest = "^7.1.0"