# Application configuration: runtime settings plus the DB, JWT and crypto config persisted as JSON.
from __future__ import annotations

import functools
import json
from datetime import datetime, timedelta, timezone
from enum import Enum

from jose import JWTError, jwt
from jose.constants import ALGORITHMS
from passlib.context import CryptContext
from pydantic import BaseModel, BaseSettings, Field
from sqlalchemy import create_engine
from sqlalchemy.engine import Engine


class Settings(BaseSettings):
    """Process-level settings, populated by pydantic's ``BaseSettings``.

    Use ``Settings.get()`` rather than constructing instances directly so that
    every caller shares one memoized instance.
    """

    production_mode: bool = False
    # Path of the JSON file that ``Config.get`` reads and ``Config.set`` writes.
    config_file: str = "tmp/config.json"
    # Paths for the API schema and docs pages; presumably handed to the web
    # framework, where ``None`` likely disables a page — TODO confirm.
    openapi_url: str = "/openapi.json"
    docs_url: str | None = "/docs"
    redoc_url: str | None = None

    @staticmethod
    @functools.lru_cache()
    def get() -> Settings:
        """Return the shared ``Settings`` instance, creating it on first call."""
        return Settings()


class DBType(Enum):
    """Database backends that ``DBConfig.db_type`` can name."""

    # Member names double as their string values (as stored in config JSON).
    sqlite = "sqlite"
    mysql = "mysql"


@functools.lru_cache(maxsize=None)
def _sqlite_engine(url: str) -> Engine:
    """Create, once per *url*, a SQLAlchemy engine for a SQLite database.

    ``check_same_thread=False`` disables the sqlite3 driver's default check
    that a connection is only used by the thread that created it.
    """
    return create_engine(url, connect_args={"check_same_thread": False})


class DBConfig(BaseModel):
    """Database section of the persisted application config."""

    # NOTE(review): ``db_engine`` ignores this value and always returns the
    # SQLite engine, even when ``DBType.mysql`` is selected — TODO wire up.
    db_type: DBType = DBType.sqlite

    @property
    def db_engine(self) -> Engine:
        """Return the shared engine for the application's SQLite database.

        The engine (and its connection pool) is built on first access and
        reused afterwards. The previous behavior created a new engine and pool
        on every property access.
        """
        return _sqlite_engine("sqlite:///./tmp/vpn.db")


class JWTConfig(BaseModel):
    """Settings plus helpers for issuing and validating JWT access tokens."""

    # Signing key passed to python-jose.
    # NOTE(review): defaults to None; signing presumably fails until it is
    # configured — confirm.
    secret: str | None = None
    # Algorithm used both to sign (encode) and to verify (decode).
    hash_algorithm: str = ALGORITHMS.HS256
    # Default token lifetime used by ``encode`` when none is given.
    expiry_minutes: int = 30

    async def encode(
        self,
        username: str,
        expiry_minutes: int | None = None,
    ) -> str:
        """Create a signed token whose ``sub`` claim is *username*.

        Args:
            username: Value stored in the ``sub`` claim.
            expiry_minutes: Token lifetime in minutes. Falls back to
                ``self.expiry_minutes`` when ``None``.

        Returns:
            The encoded JWT string.
        """
        if expiry_minutes is None:
            expiry_minutes = self.expiry_minutes

        # Aware UTC datetime: unambiguous, and avoids the deprecated utcnow().
        expires_at = datetime.now(timezone.utc) + timedelta(minutes=expiry_minutes)
        return jwt.encode(
            {
                "sub": username,
                "exp": expires_at,
            },
            self.secret,
            algorithm=self.hash_algorithm,
        )

    async def decode(
        self,
        token: str,
    ) -> str | None:
        """Validate *token* and return its ``sub`` claim.

        Returns ``None`` in any of these cases:

        - python-jose rejects the token (raises ``JWTError``).
        - The ``exp`` claim is missing or not a number.
        - The token has expired.
        - The ``sub`` claim is absent.
        """
        # Verify the signature (and any claims python-jose checks).
        try:
            payload = jwt.decode(
                token,
                self.secret,
                algorithms=[self.hash_algorithm],
            )
        except JWTError:
            return None

        # Require an expiry claim.
        expiry = payload.get("exp")
        if not isinstance(expiry, (int, float)):
            return None

        # ``exp`` is seconds since the Unix epoch (UTC). Compare it with the
        # current UTC timestamp so the result does not depend on the host's
        # local time zone. The old code compared local time with UTC time.
        if expiry < datetime.now(timezone.utc).timestamp():
            return None

        # The subject (username); None if the claim is absent.
        return payload.get("sub")


class CryptoConfig(BaseModel):
    """Password-hashing section of the persisted application config."""

    # passlib hash scheme names handed to ``CryptContext``.
    schemes: list[str] = ["bcrypt"]

    @property
    def crypt_context(self) -> CryptContext:
        """Build a fresh passlib ``CryptContext`` from ``schemes``.

        Per passlib's docs, ``deprecated="auto"`` marks every scheme except the
        default (first) one as deprecated.
        """
        context_options = {"schemes": self.schemes, "deprecated": "auto"}
        return CryptContext(**context_options)


class Config(BaseModel):
    """Top-level application config, persisted as JSON at ``Settings.config_file``."""

    db: DBConfig = Field(default_factory=DBConfig)
    jwt: JWTConfig = Field(default_factory=JWTConfig)
    crypto: CryptoConfig = Field(default_factory=CryptoConfig)

    @staticmethod
    async def get() -> Config | None:
        """Load the config from ``Settings.get().config_file``.

        Returns:
            The parsed ``Config``, or ``None`` if the file does not exist.

        Only ``FileNotFoundError`` is handled. Errors from ``json.load`` or
        ``Config.parse_obj`` (malformed JSON, schema mismatch) propagate to the
        caller. Note that the file is read synchronously despite the async
        signature.
        """
        try:
            # JSON text is UTF-8; don't depend on the platform's locale encoding.
            with open(Settings.get().config_file, "r", encoding="utf-8") as config_file:
                raw_config = json.load(config_file)
        except FileNotFoundError:
            return None

        return Config.parse_obj(raw_config)

    @staticmethod
    def set(config: Config) -> None:
        """Write *config* as indented JSON to ``Settings.get().config_file``.

        Any existing file is overwritten.
        """
        with open(Settings.get().config_file, "w", encoding="utf-8") as config_file:
            config_file.write(config.json(indent=2))

    @property
    def crypt_context(self) -> CryptContext:
        """Shortcut for ``self.crypto.crypt_context``."""
        return self.crypto.crypt_context

    @property
    def db_engine(self) -> Engine:
        """Shortcut for ``self.db.db_engine``."""
        return self.db.db_engine