from __future__ import annotations import functools import json from datetime import datetime, timedelta from enum import Enum from jose import JWTError, jwt from jose.constants import ALGORITHMS from passlib.context import CryptContext from pydantic import BaseModel, BaseSettings, Field from sqlalchemy import create_engine from sqlalchemy.engine import Engine class Settings(BaseSettings): production_mode: bool = False config_file: str = "tmp/config.json" openapi_url: str = "/openapi.json" docs_url: str | None = "/docs" redoc_url: str | None = "/redoc" @staticmethod @functools.lru_cache def get() -> Settings: return Settings() class DBType(Enum): sqlite = "sqlite" mysql = "mysql" class DBConfig(BaseModel): type: DBType = DBType.sqlite user: str | None = None password: str | None = None host: str | None = None database: str | None = "./tmp/vpn.db" mysql_driver: str = "pymysql" mysql_args: list[str] = ["charset=utf8mb4"] @property async def db_engine(self) -> Engine: if self.type is DBType.sqlite: # SQLite backend return create_engine( f"sqlite:///{self.database}", connect_args={"check_same_thread": False}, ) elif self.type is DBType.mysql: # MySQL backend if self.mysql_args: args_str = "?" + "&".join(self.mysql_args) else: args_str = "" return create_engine( f"mysql+{self.mysql_driver}://" f"{self.user}:{self.password}@{self.host}" f"/{self.database}{args_str}", pool_recycle=3600, ) class JWTConfig(BaseModel): secret: str | None = None hash_algorithm: str = ALGORITHMS.HS256 expiry_minutes: int = 30 async def create_token( self, username: str, expiry_minutes: int | None = None, ) -> str: if expiry_minutes is None: expiry_minutes = self.expiry_minutes return jwt.encode( { "sub": username, "exp": datetime.utcnow() + timedelta(minutes=expiry_minutes), }, self.secret, algorithm=self.hash_algorithm, ) async def decode_token( self, token: str, ) -> str | None: # decode JWT token try: payload = jwt.decode( token, self.secret, algorithms=[self.hash_algorithm], ) except JWTError: return None # check expiry expiry = payload.get("exp") if expiry is None: return None if datetime.fromtimestamp(expiry) < datetime.utcnow(): return None # get username username = payload.get("sub") if username is None: return None return username class CryptoConfig(BaseModel): schemes: list[str] = ["bcrypt"] @property async def crypt_context(self) -> CryptContext: return CryptContext( schemes=self.schemes, deprecated="auto", ) class Config(BaseModel): db: DBConfig = Field(default_factory=DBConfig) jwt: JWTConfig = Field(default_factory=JWTConfig) crypto: CryptoConfig = Field(default_factory=CryptoConfig) @staticmethod async def load() -> Config | None: try: with open(Settings.get().config_file, "r") as config_file: return Config.parse_obj(json.load(config_file)) except FileNotFoundError: return None async def save(self) -> None: with open(Settings.get().config_file, "w") as config_file: config_file.write(self.json(indent=2))