kiwi-vpn/api/kiwi_vpn_api/config.py
2022-03-19 02:28:18 +00:00

125 lines
3 KiB
Python

from __future__ import annotations
import functools
import json
from datetime import datetime, timedelta
from enum import Enum
from jose import JWTError, jwt
from jose.constants import ALGORITHMS
from passlib.context import CryptContext
from pydantic import BaseModel, BaseSettings, Field
from sqlalchemy import create_engine
from sqlalchemy.engine import Engine
class Settings(BaseSettings):
    """
    Per-run settings, modelled on pydantic's `BaseSettings`.

    Use `Settings.get()` instead of instantiating this class directly, so
    that all callers work with the same cached object.
    """

    # NOTE(review): presumably toggles production behaviour elsewhere in the
    # application; it is not read anywhere in this module
    production_mode: bool = False

    # path of the JSON file that `Config.get` reads and `Config.set` writes
    config_file: str = "tmp/config.json"

    # NOTE(review): these look like the web framework's schema/documentation
    # endpoints (`None` presumably disables one) -- confirm against the
    # place where the app object is created
    openapi_url: str = "/openapi.json"
    docs_url: str | None = "/docs"
    redoc_url: str | None = "/redoc"

    @staticmethod
    @functools.lru_cache
    def get() -> Settings:
        """
        Return the `Settings` instance.

        The result is memoized by `functools.lru_cache`: the first call
        constructs the object, later calls hand back that same object.
        """

        settings = Settings()
        return settings
class DBType(Enum):
    """
    Database backends that can be named in the configuration.

    Every member's value is the same string as its name.
    """

    sqlite = "sqlite"
    mysql = "mysql"
class DBConfig(BaseModel):
    """
    Database section of the configuration.
    """

    # selected backend; note that `db_engine` below does not consult it yet
    db_type: DBType = DBType.sqlite

    @property
    async def db_engine(self) -> Engine:
        """
        Create a new SQLAlchemy engine.

        This is an `async` property: reading the attribute yields a
        coroutine, so callers write `engine = await config.db_engine`.
        Every evaluation calls `create_engine` again. The engine always
        targets the SQLite file `./tmp/vpn.db`, whatever `db_type` says.
        """

        url = "sqlite:///./tmp/vpn.db"
        connect_options = {"check_same_thread": False}

        return create_engine(url, connect_args=connect_options)
class JWTConfig(BaseModel):
    """
    JWT section of the configuration: signing key, algorithm and default
    token lifetime, plus helpers to issue and validate tokens.
    """

    # signing key handed to `jose.jwt`; defaults to `None` (no key set)
    secret: str | None = None
    # algorithm name used for both signing and verification
    hash_algorithm: str = ALGORITHMS.HS256
    # default token lifetime, used when `encode` gets no explicit value
    expiry_minutes: int = 30

    async def encode(
        self,
        username: str,
        expiry_minutes: int | None = None,
    ) -> str:
        """
        Issue a signed JWT for a user.

        :param username: stored in the token's "sub" claim
        :param expiry_minutes: token lifetime in minutes; `None` selects
            this config's `expiry_minutes`
        :return: the encoded token
        """

        if expiry_minutes is None:
            expiry_minutes = self.expiry_minutes

        return jwt.encode(
            {
                "sub": username,
                # naive UTC timestamp; `decode` compares in naive UTC too
                "exp": datetime.utcnow() + timedelta(minutes=expiry_minutes),
            },
            self.secret,
            algorithm=self.hash_algorithm,
        )

    async def decode(
        self,
        token: str,
    ) -> str | None:
        """
        Validate a JWT and extract the user name it was issued for.

        :param token: the encoded token
        :return: the "sub" claim, or `None` if the token cannot be decoded,
            has no "exp" claim, is expired, or has no "sub" claim
        """

        # decode JWT token
        try:
            payload = jwt.decode(
                token,
                self.secret,
                algorithms=[self.hash_algorithm],
            )

        except JWTError:
            return None

        # check expiry
        expiry = payload.get("exp")
        if expiry is None:
            return None

        # "exp" is a POSIX timestamp. It must be converted to naive *UTC*
        # before comparing against `utcnow()`; `datetime.fromtimestamp`
        # would yield local time and skew the check by the host's UTC
        # offset (rejecting fresh tokens on hosts west of UTC).
        if datetime.utcfromtimestamp(expiry) < datetime.utcnow():
            return None

        # get username
        username = payload.get("sub")
        if username is None:
            return None

        return username
class CryptoConfig(BaseModel):
    """
    Password hashing section of the configuration.
    """

    # scheme names forwarded to passlib's `CryptContext`
    schemes: list[str] = ["bcrypt"]

    @property
    async def crypt_context(self) -> CryptContext:
        """
        Create a passlib `CryptContext` for the configured schemes.

        This is an `async` property: reading the attribute yields a
        coroutine, so callers write `ctx = await config.crypt_context`.
        A fresh context is constructed on every evaluation.
        """

        context = CryptContext(schemes=self.schemes, deprecated="auto")
        return context
class Config(BaseModel):
    """
    Root configuration object, persisted as JSON in the file named by
    `Settings.get().config_file`.
    """

    db: DBConfig = Field(default_factory=DBConfig)
    jwt: JWTConfig = Field(default_factory=JWTConfig)
    crypto: CryptoConfig = Field(default_factory=CryptoConfig)

    @staticmethod
    async def get() -> Config | None:
        """
        Load the configuration from the config file.

        :return: the parsed configuration, or `None` if the config file
            does not exist
        """

        try:
            # JSON is read as UTF-8 explicitly, so the result does not
            # depend on the host's locale encoding
            with open(
                Settings.get().config_file, "r", encoding="utf-8"
            ) as config_file:
                raw_config = json.load(config_file)

        except FileNotFoundError:
            return None

        return Config.parse_obj(raw_config)

    async def set(self) -> None:
        """
        Write this configuration to the config file, replacing its content.
        """

        # written as UTF-8 to match the encoding `get` reads with
        with open(
            Settings.get().config_file, "w", encoding="utf-8"
        ) as config_file:
            config_file.write(self.json(indent=2))