kiwi-vpn/api/kiwi_vpn_api/config.py

148 lines
3.8 KiB
Python

from __future__ import annotations

import functools
import json
from datetime import datetime, timedelta, timezone
from enum import Enum

from jose import JWTError, jwt
from jose.constants import ALGORITHMS
from passlib.context import CryptContext
from pydantic import BaseModel, BaseSettings, Field
from sqlalchemy import create_engine
from sqlalchemy.engine import Engine


class Settings(BaseSettings):
    """
    Process-wide settings, declared as a pydantic `BaseSettings` model.

    Every field has a default, so `Settings()` can be built without
    arguments. Use `Settings.get()` to obtain the shared, cached instance.
    """

    # production switch; off by default (not read anywhere in this file)
    production_mode: bool = False

    # path of the JSON file opened by `Config.load` and `Config.save`
    config_file: str = "tmp/config.json"

    # URL paths; presumably handed to the web framework for the OpenAPI
    # schema and the interactive docs pages -- TODO confirm against the app
    openapi_url: str = "/openapi.json"
    docs_url: str | None = "/docs"
    redoc_url: str | None = "/redoc"

    @staticmethod
    @functools.lru_cache
    def get() -> Settings:
        """
        Return the shared `Settings` instance.

        The result is memoized by `functools.lru_cache`, so the model is
        constructed on the first call only and reused afterwards.
        """

        instance = Settings()
        return instance
class DBType(Enum):
    """
    Database backends that `DBConfig` knows how to build an engine for.

    Each member's value is the plain string form of its name.
    """

    # file-based SQLite database
    sqlite = "sqlite"

    # MySQL server reached through a SQLAlchemy driver
    mysql = "mysql"
class DBConfig(BaseModel):
    """
    Database connection configuration.

    `type` selects the backend. The SQLite backend only uses `database`
    (as the file path); the MySQL backend uses `user`, `password`, `host`,
    `database`, `mysql_driver` and `mysql_args`.
    """

    type: DBType = DBType.sqlite
    user: str | None = None
    password: str | None = None
    host: str | None = None
    database: str | None = "./tmp/vpn.db"

    # driver name placed after "mysql+" in the connection URL
    mysql_driver: str = "pymysql"
    # "key=value" strings joined with "&" into the URL query string
    mysql_args: list[str] = ["charset=utf8mb4"]

    @property
    async def db_engine(self) -> Engine:
        """
        Build a SQLAlchemy engine for the configured backend.

        This is an async property: attribute access yields a coroutine, so
        callers must write `await config.db_engine`. Every evaluation calls
        `create_engine` again; nothing is cached here.

        Raises:
            ValueError: if `type` is not a supported `DBType` member
                (previously this case silently returned `None`).
        """

        if self.type is DBType.sqlite:
            # SQLite backend
            return create_engine(
                f"sqlite:///{self.database}",
                connect_args={"check_same_thread": False},
            )

        if self.type is DBType.mysql:
            # MySQL backend
            query = ("?" + "&".join(self.mysql_args)) if self.mysql_args else ""

            # NOTE(review): user/password/host are interpolated verbatim, so
            # values containing URL-special characters (e.g. "@", "/") must
            # already be percent-encoded in the configuration.
            return create_engine(
                f"mysql+{self.mysql_driver}://"
                f"{self.user}:{self.password}@{self.host}"
                f"/{self.database}{query}",
                pool_recycle=3600,
            )

        # fail loudly instead of implicitly returning None as an "Engine"
        raise ValueError(f"unsupported database type: {self.type!r}")
class JWTConfig(BaseModel):
    """
    JSON Web Token configuration, with helpers to issue and verify tokens.
    """

    # key passed to `jwt.encode` / `jwt.decode`; unset (None) by default
    secret: str | None = None
    # signing algorithm name passed to `jwt.encode` / `jwt.decode`
    hash_algorithm: str = ALGORITHMS.HS256
    # default token lifetime in minutes, used by `create_token`
    expiry_minutes: int = 30

    async def create_token(
        self,
        username: str,
        expiry_minutes: int | None = None,
    ) -> str:
        """
        Create a signed token for `username`.

        Args:
            username: stored in the token's "sub" claim.
            expiry_minutes: lifetime of the token in minutes; falls back to
                `self.expiry_minutes` when None.

        Returns:
            The encoded token, whose "exp" claim is the current UTC time
            plus the lifetime.
        """

        if expiry_minutes is None:
            expiry_minutes = self.expiry_minutes

        # timezone-aware UTC, matching the clock used in `decode_token`
        expires_at = datetime.now(tz=timezone.utc) + timedelta(
            minutes=expiry_minutes,
        )

        return jwt.encode(
            {
                "sub": username,
                "exp": expires_at,
            },
            self.secret,
            algorithm=self.hash_algorithm,
        )

    async def decode_token(
        self,
        token: str,
    ) -> str | None:
        """
        Verify `token` and return the username stored in it.

        Args:
            token: the encoded token to check.

        Returns:
            The "sub" claim, or None when decoding raises `JWTError`, the
            "exp" claim is missing, unusable or in the past, or the "sub"
            claim is missing.
        """

        # decode JWT token
        try:
            payload = jwt.decode(
                token,
                self.secret,
                algorithms=[self.hash_algorithm],
            )
        except JWTError:
            return None

        # check expiry
        expiry = payload.get("exp")
        if expiry is None:
            return None

        # Interpret the claim as an aware UTC instant. A bare
        # `datetime.fromtimestamp(expiry)` yields naive *local* time, which
        # must not be compared against a UTC clock reading.
        try:
            expires_at = datetime.fromtimestamp(expiry, tz=timezone.utc)
        except (TypeError, ValueError, OverflowError, OSError):
            # a claim that cannot be converted is treated as invalid
            return None

        if expires_at < datetime.now(tz=timezone.utc):
            return None

        # get username
        username = payload.get("sub")
        if username is None:
            return None

        return username
class CryptoConfig(BaseModel):
    """
    Password hashing configuration.
    """

    # scheme names handed to passlib's `CryptContext`
    schemes: list[str] = ["bcrypt"]

    @property
    async def crypt_context(self) -> CryptContext:
        """
        Build a passlib `CryptContext` from the configured schemes.

        This is an async property: attribute access yields a coroutine, so
        callers must write `await config.crypt_context`. A new context is
        constructed on every evaluation.
        """

        context = CryptContext(
            schemes=self.schemes,
            deprecated="auto",
        )
        return context
class Config(BaseModel):
    """
    Top-level configuration, persisted as JSON at `Settings.get().config_file`.

    Each section defaults to a freshly constructed sub-model.
    """

    db: DBConfig = Field(default_factory=DBConfig)
    jwt: JWTConfig = Field(default_factory=JWTConfig)
    crypto: CryptoConfig = Field(default_factory=CryptoConfig)

    @staticmethod
    async def load() -> Config | None:
        """
        Load the configuration from the configured JSON file.

        Returns:
            The parsed `Config`, or None when the file does not exist.
            Errors from JSON decoding or model validation propagate.
        """

        try:
            # read as UTF-8 explicitly rather than the locale's encoding
            with open(
                Settings.get().config_file, "r", encoding="utf-8"
            ) as config_file:
                raw = json.load(config_file)
        except FileNotFoundError:
            return None

        return Config.parse_obj(raw)

    async def save(self) -> None:
        """
        Write this configuration to the configured JSON file.

        The file is overwritten with the model serialized as JSON
        (2-space indent), encoded as UTF-8.
        """

        with open(
            Settings.get().config_file, "w", encoding="utf-8"
        ) as config_file:
            config_file.write(self.json(indent=2))