"""Application settings and persisted configuration models.

``Settings`` is a pydantic ``BaseSettings`` model (overridable through the
environment).  ``Config`` is a JSON document stored at
``Settings.config_file`` that groups the database, JWT and password-hashing
configuration.
"""

from __future__ import annotations

import functools
import json
from datetime import datetime, timedelta, timezone
from enum import Enum
from pathlib import Path
from secrets import token_hex
from urllib.parse import quote

from jose import JWTError, jwt
from jose.constants import ALGORITHMS
from passlib.context import CryptContext
from pydantic import BaseModel, BaseSettings, Field, validator
from sqlalchemy import create_engine
from sqlalchemy.engine import Engine


class Settings(BaseSettings):
    """Process-level settings, read by pydantic ``BaseSettings``."""

    production_mode: bool = False
    config_file: str = "tmp/config.json"
    openapi_url: str = "/openapi.json"
    docs_url: str | None = "/docs"
    redoc_url: str | None = "/redoc"

    @staticmethod
    @functools.lru_cache
    def get() -> Settings:
        """Return the shared ``Settings`` instance.

        The instance is built on first call and memoised by ``lru_cache``;
        later calls return the same object.
        """
        return Settings()


class DBType(Enum):
    """Supported database backends."""

    sqlite = "sqlite"
    mysql = "mysql"


class DBConfig(BaseModel):
    """Database connection parameters."""

    type: DBType = DBType.sqlite
    user: str | None = None
    password: str | None = None
    host: str | None = None
    # SQLite: path of the database file.  MySQL: name of the database.
    database: str | None = "./tmp/vpn.db"
    mysql_driver: str = "pymysql"
    # Joined with "&" and appended to the MySQL URL as its query string.
    mysql_args: list[str] = ["charset=utf8mb4"]

    @property
    async def db_engine(self) -> Engine:
        """Create a new SQLAlchemy engine for the configured backend.

        This is an async property, so callers write ``await cfg.db_engine``.
        A fresh engine is created on every access.

        Raises:
            ValueError: if ``type`` is not a supported ``DBType``.
        """
        if self.type is DBType.sqlite:
            # Allow the SQLite connection to be used from threads other than
            # the one that opened it.
            return create_engine(
                f"sqlite:///{self.database}",
                connect_args={"check_same_thread": False},
            )

        if self.type is DBType.mysql:
            # Percent-encode credentials: "@", ":" or "/" in a user name or
            # password would otherwise be parsed as URL delimiters.
            credentials = ""
            if self.user is not None:
                credentials = quote(self.user, safe="")
                if self.password is not None:
                    credentials += ":" + quote(self.password, safe="")
                credentials += "@"
            query = ("?" + "&".join(self.mysql_args)) if self.mysql_args else ""
            return create_engine(
                f"mysql+{self.mysql_driver}://"
                f"{credentials}{self.host or ''}"
                f"/{self.database or ''}{query}",
                # Recycle pooled connections after an hour; presumably to stay
                # below the server's idle-connection timeout -- confirm.
                pool_recycle=3600,
            )

        raise ValueError(f"unsupported database type: {self.type!r}")


class JWTConfig(BaseModel):
    """Parameters for issuing and validating JWT access tokens."""

    # Signing key; replaced by a random key when missing (ensure_secret).
    secret: str | None = None
    hash_algorithm: str = ALGORITHMS.HS256
    expiry_minutes: int = 30

    # always=True: pydantic v1 does not run validators on default values
    # otherwise, so an omitted secret would silently stay None.
    @validator("secret", always=True)
    @classmethod
    def ensure_secret(cls, value: str | None) -> str:
        """Return *value*, or a random 64-hex-character key if it is None."""
        if value is None:
            return token_hex(32)
        return value

    async def create_token(
        self,
        username: str,
        expiry_minutes: int | None = None,
    ) -> str:
        """Issue a signed JWT whose ``sub`` claim is *username*.

        Args:
            username: value stored in the ``sub`` claim.
            expiry_minutes: token lifetime; defaults to ``self.expiry_minutes``.

        Returns:
            The encoded token string.
        """
        if expiry_minutes is None:
            expiry_minutes = self.expiry_minutes

        expires_at = datetime.now(timezone.utc) + timedelta(minutes=expiry_minutes)
        return jwt.encode(
            {"sub": username, "exp": expires_at},
            self.secret,
            algorithm=self.hash_algorithm,
        )

    async def decode_token(
        self,
        token: str,
    ) -> str | None:
        """Validate *token* and return its ``sub`` claim.

        Returns:
            The username, or None if the token fails to decode/verify, has
            no ``exp`` claim, is expired, or has no ``sub`` claim.
        """
        try:
            payload = jwt.decode(
                token,
                self.secret,
                algorithms=[self.hash_algorithm],
            )
        except JWTError:
            return None

        expiry = payload.get("exp")
        if expiry is None:
            return None
        # "exp" is seconds since the epoch (UTC); compare against the current
        # epoch time so the host's local UTC offset cannot skew the check.
        if expiry < datetime.now(timezone.utc).timestamp():
            return None

        return payload.get("sub")


class CryptoConfig(BaseModel):
    """Password-hashing configuration."""

    schemes: list[str] = ["bcrypt"]

    @property
    async def crypt_context(self) -> CryptContext:
        """Create a passlib ``CryptContext`` for ``schemes``.

        This is an async property (``await cfg.crypt_context``).  With
        ``deprecated="auto"`` passlib treats every scheme except the first as
        deprecated.
        """
        return CryptContext(
            schemes=self.schemes,
            deprecated="auto",
        )


class Config(BaseModel):
    """Top-level configuration persisted as JSON at ``Settings.config_file``."""

    db: DBConfig = Field(default_factory=DBConfig)
    jwt: JWTConfig = Field(default_factory=JWTConfig)
    crypto: CryptoConfig = Field(default_factory=CryptoConfig)

    @staticmethod
    async def load() -> Config | None:
        """Read and parse the config file.

        Returns:
            The parsed ``Config``, or None if the file does not exist.
        """
        try:
            with open(
                Settings.get().config_file, "r", encoding="utf-8"
            ) as config_file:
                return Config.parse_obj(json.load(config_file))
        except FileNotFoundError:
            return None

    async def save(self) -> None:
        """Write this configuration to the config file as indented JSON.

        The file's parent directory is created if it does not exist yet.
        """
        path = Path(Settings.get().config_file)
        path.parent.mkdir(parents=True, exist_ok=True)
        with open(path, "w", encoding="utf-8") as config_file:
            config_file.write(self.json(indent=2))