231 lines
5.3 KiB
Python
231 lines
5.3 KiB
Python
"""
|
|
Configuration definition.
|
|
|
|
Converts per-run (environment) variables and config files into the
|
|
"python world" using `pydantic`.
|
|
|
|
Pydantic models might have convenience methods attached.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import functools
|
|
import json
|
|
from datetime import datetime, timedelta
|
|
from enum import Enum
|
|
from pathlib import Path
|
|
from secrets import token_urlsafe
|
|
|
|
from jose import JWTError, jwt
|
|
from jose.constants import ALGORITHMS
|
|
from passlib.context import CryptContext
|
|
from pydantic import BaseModel, BaseSettings, Field, validator
|
|
|
|
|
|
class Settings(BaseSettings):
    """
    Per-run settings, populated by `pydantic.BaseSettings` (environment
    variables override the defaults below).
    """

    production_mode: bool = False
    data_dir: Path = Path("./tmp")
    openapi_url: str = "/openapi.json"
    docs_url: str | None = "/docs"
    redoc_url: str | None = "/redoc"

    @staticmethod
    @functools.lru_cache()
    def get() -> Settings:
        """
        Return the shared `Settings` instance.

        It is built on the first call; later calls return the cached object.
        """

        return Settings()

    @property
    def config_file(self) -> Path:
        """
        Path of the JSON config file, `config.json` inside `data_dir`.
        """

        return self.data_dir / "config.json"
|
|
|
|
|
|
class DBType(Enum):
    """
    Database backends for which `DBConfig.uri` can build a connection string.
    """

    # SQLite: `DBConfig.database` is the database file path
    sqlite = "sqlite"
    # MySQL: reached through `DBConfig.mysql_driver` at `DBConfig.host`
    mysql = "mysql"
|
|
|
|
|
|
class DBConfig(BaseModel):
    """
    Database connection configuration
    """

    type: DBType = DBType.sqlite
    user: str | None = None
    password: str | None = None
    host: str | None = None
    # Computed per instance rather than at import time, and stored as `str`
    # to match the annotation. `Settings.get()` is cached, so every default
    # instance still sees the same data directory.
    database: str | None = Field(
        default_factory=lambda: str(Settings.get().data_dir.joinpath("vpn.db")),
    )

    mysql_driver: str = "pymysql"
    mysql_args: list[str] = ["charset=utf8mb4"]

    @property
    def uri(self) -> str:
        """
        Construct a database connection string

        SQLite:  `sqlite:///<database>`
        MySQL:   `mysql+<driver>://[user[:password]@]<host>/<database>[?<args>]`

        For MySQL, user name and password are percent-encoded, so characters
        such as `@`, `:` or `/` cannot break the URI. Fields that are `None`
        are left out instead of being rendered as the text "None".

        Raises `ValueError` if `type` is not a supported backend.
        """

        if self.type is DBType.sqlite:
            # SQLite backend: `database` is the path of the database file
            return f"sqlite:///{self.database}"

        if self.type is DBType.mysql:
            # MySQL backend
            from urllib.parse import quote_plus

            credentials = ""
            if self.user is not None:
                credentials = quote_plus(self.user)
                if self.password is not None:
                    credentials += ":" + quote_plus(self.password)
                credentials += "@"

            host = self.host if self.host is not None else ""
            database = self.database if self.database is not None else ""
            args_str = ("?" + "&".join(self.mysql_args)) if self.mysql_args else ""

            return (f"mysql+{self.mysql_driver}://"
                    f"{credentials}{host}"
                    f"/{database}{args_str}")

        # only reachable if validation was bypassed (e.g. `construct()`)
        raise ValueError(f"unsupported database type: {self.type!r}")
|
|
|
|
|
|
class JWTConfig(BaseModel):
    """
    Configuration for JSON Web Tokens
    """

    # Signing key. `None` (explicit `null` or a missing key) is replaced by a
    # random secret in `ensure_secret`.
    secret: str | None = None
    hash_algorithm: str = ALGORITHMS.HS256
    expiry_minutes: int = 30

    # `always=True` is required: pydantic (v1) does not run validators on
    # fields that fall back to their default value. Without it, a config
    # lacking a "secret" key (including the bare `JWTConfig()` used as a
    # default elsewhere) would keep `secret=None` and hand `None` to
    # `jwt.encode` / `jwt.decode` as the key.
    @validator("secret", always=True)
    @classmethod
    def ensure_secret(cls, value: str | None) -> str:
        """
        Generate a per-run secret if none was configured

        `value` is the secret as loaded from the config file, or `None`
        when it is `null` or absent. A configured secret is returned
        unchanged; otherwise a fresh `token_urlsafe(128)` string is returned.
        """

        if value is None:
            return token_urlsafe(128)

        return value

    async def create_token(
        self,
        username: str,
        expiry_minutes: int | None = None,
    ) -> str:
        """
        Build and sign a JSON Web Token

        The token carries `username` as its `sub` claim and expires
        `expiry_minutes` from now (`self.expiry_minutes` if not given).
        It is signed with `self.secret` using `self.hash_algorithm`.
        """

        if expiry_minutes is None:
            expiry_minutes = self.expiry_minutes

        return jwt.encode(
            {
                "sub": username,
                # naive datetime expressed in UTC
                "exp": datetime.utcnow() + timedelta(minutes=expiry_minutes),
            },
            self.secret,
            algorithm=self.hash_algorithm,
        )

    async def decode_token(
        self,
        token: str,
    ) -> str | None:
        """
        Verify a JSON Web Token, then extract the username

        Returns the token's `sub` claim. Returns `None` if the token fails
        verification, has no `exp` claim, has expired, or has no `sub` claim.
        """

        # decode JWT token; only the configured algorithm is accepted
        try:
            payload = jwt.decode(
                token,
                self.secret,
                algorithms=[self.hash_algorithm],
            )

        except JWTError:
            return None

        # check expiry
        expiry = payload.get("exp")
        if expiry is None:
            return None

        # `exp` is an epoch timestamp, so compare in UTC on both sides.
        # `datetime.fromtimestamp` would yield *local* time and skew the
        # check by the host's UTC offset.
        if datetime.utcfromtimestamp(expiry) < datetime.utcnow():
            return None

        # username, or None if the claim is missing
        return payload.get("sub")
|
|
|
|
|
|
class CryptoConfig(BaseModel):
    """
    Configuration for hash algorithms, turned into a passlib
    `CryptContext` on demand.
    """

    # passlib scheme names passed to `CryptContext(schemes=...)`
    schemes: list[str] = ["bcrypt"]

    @property
    def crypt_context_sync(self) -> CryptContext:
        """
        Build a new `CryptContext` for the configured `schemes`.
        """

        context = CryptContext(schemes=self.schemes, deprecated="auto")
        return context

    @property
    async def crypt_context(self) -> CryptContext:
        """
        Awaitable variant of `crypt_context_sync`
        (use as `await config.crypt_context`).
        """

        return self.crypt_context_sync
|
|
|
|
|
|
class Config(BaseModel):
    """
    Configuration for `kiwi-vpn-api`
    """

    db: DBConfig = Field(default_factory=DBConfig)
    jwt: JWTConfig = Field(default_factory=JWTConfig)
    crypto: CryptoConfig = Field(default_factory=CryptoConfig)

    @staticmethod
    def load_sync() -> Config | None:
        """
        Load configuration from config file

        Returns `None` if `Settings.get().config_file` does not exist.
        Malformed JSON or values that fail validation are not caught and
        propagate to the caller.
        """

        config_path = Settings.get().config_file

        try:
            # JSON is UTF-8 by spec; don't depend on the platform locale
            with open(config_path, "r", encoding="utf-8") as config_file:
                return Config.parse_obj(json.load(config_file))

        except FileNotFoundError:
            return None

    @staticmethod
    async def load() -> Config | None:
        """
        Load configuration from config file

        Same contract as `load_sync`, callable with `await`. The file read
        itself is still blocking.
        """

        return Config.load_sync()

    async def save(self) -> None:
        """
        Save configuration to config file

        Creates the data directory first if it does not exist yet (e.g. on
        the first run), then overwrites the config file with this
        configuration as indented JSON.
        """

        config_path = Settings.get().config_file
        config_path.parent.mkdir(parents=True, exist_ok=True)

        with open(config_path, "w", encoding="utf-8") as config_file:
            config_file.write(self.json(indent=2))
|