"""
Configuration definition.

Converts per-run (environment) variables and config files into the
"python world" using `pydantic`.

Pydantic models might have convenience methods attached.
"""

from __future__ import annotations

import functools
import json
from datetime import datetime, timedelta, timezone
from enum import Enum
from pathlib import Path
from secrets import token_urlsafe
from typing import Any, Callable
from urllib.parse import quote_plus

from jose import JWTError, jwt
from jose.constants import ALGORITHMS
from passlib.context import CryptContext
from pydantic import BaseModel, BaseSettings, constr, validator


class _classproperty:
    """
    Read-only property that is evaluated on the owning class.

    Stand-in for chaining `@classmethod` and `@property`, which is
    deprecated since Python 3.11 and no longer works in Python 3.13.
    Works when accessed on the class (`Settings._`) and on instances.
    """

    def __init__(self, fget: Callable[[Any], Any]) -> None:
        self.fget = fget
        self.__doc__ = fget.__doc__

    def __get__(self, instance: Any, owner: type | None = None) -> Any:
        if owner is None:
            owner = type(instance)
        return self.fget(owner)


class Settings(BaseSettings):
    """
    Per-run settings
    """

    production_mode: bool = False
    data_dir: Path = Path("./tmp")
    config_file_name: Path = Path("config.json")

    api_v1_prefix: str = "api/v1"
    openapi_url: str = "/openapi.json"
    docs_url: str | None = "/docs"
    redoc_url: str | None = "/redoc"

    @classmethod
    @functools.lru_cache
    def load(cls) -> Settings:
        """
        Build the settings; the result is cached per class, so later
        calls return the same instance.
        """
        return cls()

    @_classproperty
    def _(cls) -> Settings:
        """
        Shorthand for load()
        """
        return cls.load()

    @property
    def config_file(self) -> Path:
        """
        Full path of the config file (`data_dir` / `config_file_name`)
        """
        return self.data_dir.joinpath(self.config_file_name)


class DBType(Enum):
    """
    Supported database types
    """

    sqlite = "sqlite"
    mysql = "mysql"


class DBConfig(BaseModel):
    """
    Database connection configuration
    """

    type: DBType = DBType.sqlite
    user: str | None = None
    password: str | None = None
    host: str | None = None
    # default is computed once, at import time, from the per-run settings
    database: str | None = str(Settings._.data_dir.joinpath("kiwi-vpn.db"))

    mysql_driver: str = "pymysql"
    mysql_args: list[str] = ["charset=utf8mb4"]

    @property
    def uri(self) -> str:
        """
        Construct a database connection string

        Returns an empty string for an unsupported database type.
        """
        if self.type is DBType.sqlite:
            # SQLite backend
            return f"sqlite:///{self.database}"

        if self.type is DBType.mysql:
            # MySQL backend
            args_str = ("?" + "&".join(self.mysql_args)) if self.mysql_args else ""

            # percent-encode credentials: characters such as "@", ":" or
            # "/" in a password would otherwise corrupt the URL; absent
            # credentials are omitted instead of rendered as "None"
            credentials = ""
            if self.user is not None:
                credentials = quote_plus(self.user)
                if self.password is not None:
                    credentials += ":" + quote_plus(self.password)
                credentials += "@"

            return (
                f"mysql+{self.mysql_driver}://"
                f"{credentials}{self.host or ''}"
                f"/{self.database}{args_str}"
            )

        return ""


class JWTConfig(BaseModel):
    """
    Configuration for JSON Web Tokens
    """

    secret: str | None = None
    hash_algorithm: str = ALGORITHMS.HS256
    expiry_minutes: int = 30

    @validator("secret", always=True)
    @classmethod
    def ensure_secret(cls, value: str | None) -> str:
        """
        Generate a per-run secret if `None` was loaded from the config file
        or the `secret` key is missing from it altogether.

        `always=True` makes pydantic run this validator on the default
        value as well; otherwise an omitted `secret` would stay `None`.
        """
        if value is None:
            return token_urlsafe(128)

        return value

    async def create_token(
        self,
        username: str,
        expiry_minutes: int | None = None,
    ) -> str:
        """
        Build and sign a JSON Web Token

        The token's subject ("sub") is `username`; it expires after
        `expiry_minutes` minutes (default: the configured
        `expiry_minutes`).
        """
        if expiry_minutes is None:
            expiry_minutes = self.expiry_minutes

        # timezone-aware UTC, so the "exp" claim is an unambiguous timestamp
        expires_at = datetime.now(timezone.utc) + timedelta(minutes=expiry_minutes)

        return jwt.encode(
            {
                "sub": username,
                "exp": expires_at,
            },
            self.secret,
            algorithm=self.hash_algorithm,
        )

    async def decode_token(
        self,
        token: str,
    ) -> str | None:
        """
        Verify a JSON Web Token, then extract the username

        Returns `None` if the token cannot be decoded/verified, carries no
        expiry, or is expired; otherwise the "sub" claim (if present).
        """
        # decode JWT token
        try:
            payload = jwt.decode(
                token,
                self.secret,
                algorithms=[self.hash_algorithm],
            )
        except JWTError:
            return None

        # check expiry; tokens without "exp" are rejected outright
        expiry = payload.get("exp")
        if expiry is None:
            return None

        # "exp" is a UTC timestamp: compare in UTC. A naive
        # `datetime.fromtimestamp` would yield *local* time and skew the
        # check by the host's UTC offset.
        expires_at = datetime.fromtimestamp(expiry, tz=timezone.utc)
        if expires_at < datetime.now(timezone.utc):
            return None

        # get username
        return payload.get("sub")


class LockableString(BaseModel):
    """
    A string that can be (logically) locked with an attached bool
    """

    value: str
    locked: bool


class LockableCountry(LockableString):
    """
    Like `LockableString`, but with a `value` constrained to at most two
    characters
    """

    value: constr(max_length=2)  # type: ignore


class ServerDN(BaseModel):
    """
    This server's "distinguished name"
    """

    country: LockableCountry
    state: LockableString
    city: LockableString
    organization: LockableString
    organizational_unit: LockableString
    email: LockableString
    common_name: str


class KeyAlgorithm(Enum):
    """
    Supported certificate signing algorithms
    """

    rsa2048 = "rsa2048"
    rsa4096 = "rsa4096"
    secp256r1 = "secp256r1"
    secp384r1 = "secp384r1"
    ed25519 = "ed25519"


class CryptoConfig(BaseModel):
    """
    Configuration for cryptography
    """

    # password hash algorithms
    schemes: list[str] = ["bcrypt"]

    # pki settings
    key_algorithm: KeyAlgorithm | None
    ca_password: str | None
    ca_expiry_days: int | None
    cert_expiry_days: int | None

    @property
    def context(self) -> CryptContext:
        """
        Password hashing context for the configured `schemes`
        (a new context is built on every access)
        """
        return CryptContext(
            schemes=self.schemes,
            deprecated="auto",
        )


class Config(BaseModel):
    """
    Configuration for `kiwi-vpn-api`
    """

    # may include client-to-client, cipher etc.
    openvpn_extra_options: dict[str, Any] | None

    db: DBConfig
    jwt: JWTConfig
    crypto: CryptoConfig
    server_dn: ServerDN

    # cached result of load(), refreshed by save()
    __singleton: Config | None = None

    @classmethod
    def load(cls) -> Config | None:
        """
        Load configuration from config file

        The parsed configuration is cached: once loaded, the file is not
        read again. Returns `None` if the config file does not exist.
        """
        if cls.__singleton is not None:
            return cls.__singleton

        config_path = Settings._.config_file
        try:
            with open(config_path, "r", encoding="utf-8") as config_file:
                cls.__singleton = cls.parse_obj(json.load(config_file))
        except FileNotFoundError:
            return None

        return cls.__singleton

    @_classproperty
    def _(cls) -> Config:
        """
        Shorthand for load(), but config file must exist

        Raises `FileNotFoundError` (naming the config file path) if there
        is no config file.
        """
        if (config := cls.load()) is None:
            raise FileNotFoundError(Settings._.config_file)

        return config

    def save(self) -> None:
        """
        Save configuration to config file

        Also makes this instance the cached result of load(), so later
        loads do not keep returning a previously loaded configuration.
        """
        with open(Settings._.config_file, "w", encoding="utf-8") as config_file:
            config_file.write(self.json(indent=2))

        type(self).__singleton = self