# kiwi-vpn/api/kiwi_vpn_api/config.py
#
# 219 lines
# 5 KiB
# Python

"""
Configuration definition.
Converts per-run (environment) variables and config files into the
"python world" using `pydantic`.
Pydantic models might have convenience methods attached.
"""
from __future__ import annotations

import functools
import json
from datetime import datetime, timedelta, timezone
from enum import Enum
from pathlib import Path
from secrets import token_urlsafe
from urllib.parse import quote_plus

from jose import JWTError, jwt
from jose.constants import ALGORITHMS
from passlib.context import CryptContext
from pydantic import BaseModel, BaseSettings, Field, validator
from sqlalchemy import create_engine
from sqlalchemy.engine import Engine
class Settings(BaseSettings):
    """
    Per-run settings, populated from (environment) variables by pydantic.
    """

    production_mode: bool = False
    data_dir: Path = Path("./tmp")
    openapi_url: str = "/openapi.json"
    docs_url: str | None = "/docs"
    redoc_url: str | None = "/redoc"

    @property
    def config_file(self) -> Path:
        """
        Path of the JSON config file, located inside `data_dir`.
        """
        return self.data_dir / "config.json"

    @staticmethod
    @functools.lru_cache()
    def get() -> Settings:
        """
        Return the shared `Settings` object.

        It is constructed on the first call and cached for every later call.
        """
        return Settings()
class DBType(Enum):
    """
    Database backends that `DBConfig` knows how to connect to.

    Each member's value is the string accepted for `DBConfig.type`.
    """

    sqlite = "sqlite"
    mysql = "mysql"
class DBConfig(BaseModel):
    """
    Database connection configuration
    """

    type: DBType = DBType.sqlite
    # MySQL connection details; the SQLite backend ignores these
    user: str | None = None
    password: str | None = None
    host: str | None = None
    # SQLite: path of the database file; MySQL: name of the database.
    # Stored as `str` so the default matches the declared field type.
    database: str | None = str(Settings.get().data_dir.joinpath("vpn.db"))

    mysql_driver: str = "pymysql"
    # extra `key=value` query parameters appended to the MySQL URL
    mysql_args: list[str] = ["charset=utf8mb4"]

    @property
    async def db_engine(self) -> Engine:
        """
        Construct an SQLAlchemy engine for the configured backend.

        This is an async property, so callers must `await` it. Each access
        builds a new engine.

        :raises ValueError: if `type` is not a supported backend
        """

        if self.type is DBType.sqlite:
            # SQLite backend
            return create_engine(
                f"sqlite:///{self.database}",
                connect_args={"check_same_thread": False},
            )

        if self.type is DBType.mysql:
            # MySQL backend
            return create_engine(
                self._mysql_url(),
                pool_recycle=3600,
            )

        raise ValueError(f"Unsupported database type: {self.type!r}")

    def _mysql_url(self) -> str:
        """
        Build the SQLAlchemy URL for the MySQL backend.

        User name and password are percent-encoded, so characters such as
        `@`, `:` or `/` cannot corrupt the URL. Unset (`None`) parts are
        left out rather than rendered as the literal text "None".
        """

        credentials = ""
        if self.user is not None:
            credentials = quote_plus(self.user)
            if self.password is not None:
                credentials += f":{quote_plus(self.password)}"
            credentials += "@"

        host = self.host or ""
        database = self.database or ""

        if self.mysql_args:
            args_str = "?" + "&".join(self.mysql_args)
        else:
            args_str = ""

        return (
            f"mysql+{self.mysql_driver}://"
            f"{credentials}{host}"
            f"/{database}{args_str}"
        )
class JWTConfig(BaseModel):
    """
    Configuration for JSON Web Tokens
    """

    # signing key; a random one is generated when unset (see `ensure_secret`)
    secret: str | None = None
    hash_algorithm: str = ALGORITHMS.HS256
    # default token lifetime, used when `create_token` gets no override
    expiry_minutes: int = 30

    @validator("secret", always=True)
    @classmethod
    def ensure_secret(cls, value: str | None) -> str:
        """
        Generate a per-run secret if no secret was configured.

        This handles an explicit `null` in the config file and also a
        missing `secret` key. With `always=True`, pydantic runs this
        validator on the default value too, so the model never keeps
        `None` as its signing key.
        """

        if value is None:
            return token_urlsafe(128)

        return value

    async def create_token(
        self,
        username: str,
        expiry_minutes: int | None = None,
    ) -> str:
        """
        Build and sign a JSON Web Token.

        :param username: value stored in the `sub` (subject) claim
        :param expiry_minutes: token lifetime in minutes; falls back to
            `self.expiry_minutes` when `None`
        :returns: the encoded, signed token
        """

        if expiry_minutes is None:
            expiry_minutes = self.expiry_minutes

        # timezone-aware "now" keeps the `exp` claim unambiguously in UTC
        expires_at = datetime.now(tz=timezone.utc) + timedelta(
            minutes=expiry_minutes,
        )

        return jwt.encode(
            {
                "sub": username,
                "exp": expires_at,
            },
            self.secret,
            algorithm=self.hash_algorithm,
        )

    async def decode_token(
        self,
        token: str,
    ) -> str | None:
        """
        Verify a JSON Web Token, then extract the username.

        :param token: encoded token, as produced by `create_token`
        :returns: the `sub` claim, or `None` if the token fails to decode,
            has expired, or lacks the `exp` or `sub` claim
        """

        # decode and verify JWT token
        try:
            payload = jwt.decode(
                token,
                self.secret,
                algorithms=[self.hash_algorithm],
            )

        except JWTError:
            return None

        # check expiry: `exp` is a NumericDate (seconds since the epoch,
        # UTC), so compare aware UTC datetimes. Mixing a local-time
        # `fromtimestamp` with a naive `utcnow` would shift the deadline
        # by the host's UTC offset.
        expiry = payload.get("exp")
        if expiry is None:
            return None

        expires_at = datetime.fromtimestamp(expiry, tz=timezone.utc)
        if expires_at < datetime.now(tz=timezone.utc):
            return None

        # get username
        username = payload.get("sub")
        if username is None:
            return None

        return username
class CryptoConfig(BaseModel):
    """
    Password hashing settings, exposed as a passlib `CryptContext`.
    """

    schemes: list[str] = ["bcrypt"]

    @property
    async def crypt_context(self) -> CryptContext:
        """
        Build a `CryptContext` for the configured `schemes`.

        This is an async property, so callers `await` it. A fresh context
        object is created on every access.
        """
        context = CryptContext(
            deprecated="auto",
            schemes=self.schemes,
        )
        return context
class Config(BaseModel):
    """
    Configuration for `kiwi-vpn-api`
    """

    db: DBConfig = Field(default_factory=DBConfig)
    jwt: JWTConfig = Field(default_factory=JWTConfig)
    crypto: CryptoConfig = Field(default_factory=CryptoConfig)

    @staticmethod
    async def load() -> Config | None:
        """
        Load configuration from the config file.

        :returns: the parsed `Config`, or `None` if the file does not exist
        """

        try:
            # JSON is UTF-8; do not depend on the platform's locale encoding
            with open(
                Settings.get().config_file, "r", encoding="utf-8"
            ) as config_file:
                return Config.parse_obj(json.load(config_file))

        except FileNotFoundError:
            return None

    async def save(self) -> None:
        """
        Save configuration to the config file.

        Creates the parent directory (`data_dir`) first if it is missing.
        """

        config_path = Settings.get().config_file
        config_path.parent.mkdir(parents=True, exist_ok=True)

        # NOTE(review): the file may contain the DB password and JWT secret
        # but is created with default permissions; consider restricting it.
        with open(config_path, "w", encoding="utf-8") as config_file:
            config_file.write(self.json(indent=2))