from __future__ import annotations import functools import json from datetime import datetime, timedelta from enum import Enum from jose import JWTError, jwt from jose.constants import ALGORITHMS from passlib.context import CryptContext from pydantic import BaseModel, BaseSettings, Field from sqlalchemy import create_engine from sqlalchemy.engine import Engine class Settings(BaseSettings): production_mode: bool = False config_file: str = "tmp/config.json" openapi_url: str = "/openapi.json" docs_url: str | None = "/docs" redoc_url: str | None = None @staticmethod @functools.lru_cache def get() -> Settings: return Settings() class DBType(Enum): sqlite = "sqlite" mysql = "mysql" class DBConfig(BaseModel): db_type: DBType = DBType.sqlite @property def db_engine(self) -> Engine: return create_engine( "sqlite:///./tmp/vpn.db", connect_args={"check_same_thread": False}, ) class JWTConfig(BaseModel): secret: str | None = None hash_algorithm: str = ALGORITHMS.HS256 expiry_minutes: int = 30 async def encode( self, username: str, expiry_minutes: int | None = None, ) -> str: if expiry_minutes is None: expiry_minutes = self.expiry_minutes return jwt.encode( { "sub": username, "exp": datetime.utcnow() + timedelta(minutes=expiry_minutes), }, self.secret, algorithm=self.hash_algorithm, ) async def decode( self, token: str, ) -> str | None: # decode JWT token try: payload = jwt.decode( token, self.secret, algorithms=[self.hash_algorithm], ) except JWTError: return None # check expiry expiry = payload.get("exp") if expiry is None: return None if datetime.fromtimestamp(expiry) < datetime.utcnow(): return None # get username username = payload.get("sub") if username is None: return None return username class CryptoConfig(BaseModel): schemes: list[str] = ["bcrypt"] @property def crypt_context(self) -> CryptContext: return CryptContext( schemes=self.schemes, deprecated="auto", ) class Config(BaseModel): db: DBConfig = Field(default_factory=DBConfig) jwt: JWTConfig = Field(default_factory=JWTConfig) crypto: CryptoConfig = Field(default_factory=CryptoConfig) @staticmethod async def get() -> Config | None: try: with open(Settings.get().config_file, "r") as config_file: return Config.parse_obj(json.load(config_file)) except FileNotFoundError: return None @staticmethod def set(config: Config) -> None: with open(Settings.get().config_file, "w") as config_file: config_file.write(config.json(indent=2)) @property def crypt_context(self) -> CryptContext: return self.crypto.crypt_context @property def db_engine(self) -> Engine: return self.db.db_engine