"""
Configuration definition.

Converts per-run (environment) variables and config files into the
"python world" using `pydantic`.

Pydantic models might have convenience methods attached.
"""

from __future__ import annotations

import functools
import json
from datetime import datetime, timedelta, timezone
from enum import Enum
from pathlib import Path
from secrets import token_urlsafe

from jose import JWTError, jwt
from jose.constants import ALGORITHMS
from passlib.context import CryptContext
from pydantic import BaseModel, BaseSettings, Field, validator
from sqlalchemy import create_engine
from sqlalchemy.engine import Engine


class Settings(BaseSettings):
    """
    Per-run settings
    """

    production_mode: bool = False
    data_dir: Path = Path("./tmp")
    openapi_url: str = "/openapi.json"
    docs_url: str | None = "/docs"
    redoc_url: str | None = "/redoc"

    @staticmethod
    @functools.lru_cache
    def get() -> Settings:
        """
        Return the shared `Settings` instance.

        The instance is built on the first call and cached by `lru_cache`,
        so every later call returns the same object.
        """
        return Settings()

    @property
    def config_file(self) -> Path:
        """
        Path of the JSON config file (`config.json` inside `data_dir`).
        """
        return self.data_dir.joinpath("config.json")


class DBType(Enum):
    """
    Supported database types
    """

    sqlite = "sqlite"
    mysql = "mysql"


class DBConfig(BaseModel):
    """
    Database connection configuration
    """

    type: DBType = DBType.sqlite
    user: str | None = None
    password: str | None = None
    host: str | None = None
    # NOTE(review): this default is a `Path` (not a `str`) and is computed
    # once, when this module is imported, from `Settings.get().data_dir`.
    database: str | None = Settings.get().data_dir.joinpath("vpn.db")

    mysql_driver: str = "pymysql"
    mysql_args: list[str] = ["charset=utf8mb4"]

    @property
    async def db_engine(self) -> Engine:
        """
        Construct an SQLAlchemy engine

        This is an *async* property: callers must `await` it to obtain the
        engine.
        """
        if self.type is DBType.sqlite:
            # SQLite backend
            return create_engine(
                f"sqlite:///{self.database}",
                # let the sqlite3 connection be used from threads other
                # than the one that created it
                connect_args={"check_same_thread": False},
            )

        elif self.type is DBType.mysql:
            # MySQL backend
            if self.mysql_args:
                # extra driver arguments become the URL query string
                args_str = "?" + "&".join(self.mysql_args)
            else:
                args_str = ""

            # NOTE(review): user/password are interpolated verbatim, so
            # characters such as '@', ':' or '/' must already be
            # percent-encoded in the config file. Confirm whether they
            # should be escaped here instead (urllib.parse.quote_plus).
            return create_engine(
                f"mysql+{self.mysql_driver}://"
                f"{self.user}:{self.password}@{self.host}"
                f"/{self.database}{args_str}",
                # recycle pooled connections after one hour (presumably to
                # stay below the MySQL server's idle timeout)
                pool_recycle=3600,
            )


class JWTConfig(BaseModel):
    """
    Configuration for JSON Web Tokens
    """

    secret: str | None = None
    hash_algorithm: str = ALGORITHMS.HS256
    expiry_minutes: int = 30

    # `always=True`: pydantic (v1) skips validators for fields that were not
    # supplied at all. Without it, a default-constructed `JWTConfig()` would
    # keep `secret=None` and token signing would fail.
    @validator("secret", always=True)
    @classmethod
    def ensure_secret(cls, value: str | None) -> str:
        """
        Generate a per-run secret if no secret was configured

        This covers both an explicit `null` in the config file and a
        missing value.
        """
        if value is None:
            return token_urlsafe(128)

        return value

    async def create_token(
        self,
        username: str,
        expiry_minutes: int | None = None,
    ) -> str:
        """
        Build and sign a JSON Web Token

        The token carries the username as `sub` and an expiry time `exp`,
        `expiry_minutes` from now (defaults to the configured value).
        """
        if expiry_minutes is None:
            expiry_minutes = self.expiry_minutes

        # timezone-aware UTC, so `exp` is an unambiguous epoch timestamp
        expires_at = datetime.now(tz=timezone.utc) + timedelta(
            minutes=expiry_minutes
        )

        return jwt.encode(
            {
                "sub": username,
                "exp": expires_at,
            },
            self.secret,
            algorithm=self.hash_algorithm,
        )

    async def decode_token(
        self,
        token: str,
    ) -> str | None:
        """
        Verify a JSON Web Token, then extract the username

        Returns `None` if the token is invalid, expired, or lacks the `exp`
        or `sub` claim.
        """
        # decode JWT token
        try:
            payload = jwt.decode(
                token,
                self.secret,
                algorithms=[self.hash_algorithm],
            )

        except JWTError:
            return None

        # check expiry
        expiry = payload.get("exp")
        if expiry is None:
            return None

        # `exp` is a POSIX timestamp; compare in UTC so the result does not
        # depend on the server's local timezone
        if datetime.fromtimestamp(expiry, tz=timezone.utc) < datetime.now(
            tz=timezone.utc
        ):
            return None

        # get username
        username = payload.get("sub")
        if username is None:
            return None

        return username


class CryptoConfig(BaseModel):
    """
    Configuration for hash algorithms
    """

    schemes: list[str] = ["bcrypt"]

    @property
    async def crypt_context(self) -> CryptContext:
        """
        Build a passlib `CryptContext` for the configured schemes

        This is an *async* property: callers must `await` it.
        """
        return CryptContext(
            schemes=self.schemes,
            deprecated="auto",
        )


class Config(BaseModel):
    """
    Configuration for `kiwi-vpn-api`
    """

    db: DBConfig = Field(default_factory=DBConfig)
    jwt: JWTConfig = Field(default_factory=JWTConfig)
    crypto: CryptoConfig = Field(default_factory=CryptoConfig)

    @staticmethod
    async def load() -> Config | None:
        """
        Load configuration from config file

        Returns `None` if the config file does not exist.
        """
        try:
            with open(
                Settings.get().config_file, "r", encoding="utf-8"
            ) as config_file:
                return Config.parse_obj(json.load(config_file))

        except FileNotFoundError:
            return None

    async def save(self) -> None:
        """
        Save configuration to config file

        Creates the data directory first if it does not exist yet.
        """
        config_path = Settings.get().config_file
        config_path.parent.mkdir(parents=True, exist_ok=True)

        with open(config_path, "w", encoding="utf-8") as config_file:
            config_file.write(self.json(indent=2))