kiwi-vpn/api/kiwi_vpn_api/config.py

135 lines
3.2 KiB
Python
Raw Normal View History

2022-03-18 22:43:02 +00:00
from __future__ import annotations
import functools
2022-03-18 18:22:17 +00:00
import json
2022-03-19 02:22:49 +00:00
from datetime import datetime, timedelta
2022-03-18 18:24:09 +00:00
from enum import Enum
2022-03-16 00:23:57 +00:00
2022-03-19 02:22:49 +00:00
from jose import JWTError, jwt
2022-03-16 00:23:57 +00:00
from jose.constants import ALGORITHMS
from passlib.context import CryptContext
2022-03-18 22:43:02 +00:00
from pydantic import BaseModel, BaseSettings, Field
2022-03-18 18:22:17 +00:00
from sqlalchemy import create_engine
2022-03-18 18:24:09 +00:00
from sqlalchemy.engine import Engine
2022-03-16 00:23:57 +00:00
2022-03-18 22:43:02 +00:00
class Settings(BaseSettings):
    """
    Process-wide settings, declared as a pydantic `BaseSettings` model.

    Use `Settings.get()` rather than instantiating the class directly so
    that every caller shares the same instance.
    """

    # NOTE(review): presumably toggles production behaviour elsewhere in the
    # application; nothing in this module reads it -- confirm at call sites
    production_mode: bool = False

    # path of the JSON file read by `Config.get` and written by `Config.set`
    config_file: str = "tmp/config.json"

    # NOTE(review): the three URLs below look like the web framework's
    # schema / documentation routes (`None` disabling one) -- confirm
    openapi_url: str = "/openapi.json"
    docs_url: str | None = "/docs"
    redoc_url: str | None = None

    @staticmethod
    @functools.lru_cache
    def get() -> Settings:
        """
        Return the shared `Settings` instance.

        The call takes no arguments, so `functools.lru_cache` builds the
        instance on the first call and hands back that same object afterwards.
        """

        shared_settings = Settings()
        return shared_settings
2022-03-17 22:47:31 +00:00
2022-03-16 00:23:57 +00:00
class DBType(Enum):
    """
    Database backends that can be named in the configuration.

    Each member's value is the string stored in the config file.
    """

    sqlite = "sqlite"
    mysql = "mysql"
class DBConfig(BaseModel):
    """
    Database section of the configuration.
    """

    # selected backend; defaults to SQLite
    db_type: DBType = DBType.sqlite

    @property
    def db_engine(self) -> Engine:
        """
        Build a SQLAlchemy engine for the application database.

        NOTE(review): `db_type` is not consulted here -- the engine always
        points at the SQLite file below, and a new `Engine` is created on
        every access.
        """

        sqlite_url = "sqlite:///./tmp/vpn.db"
        # let the SQLite connection be used from threads other than its creator
        sqlite_connect_args = {"check_same_thread": False}

        return create_engine(sqlite_url, connect_args=sqlite_connect_args)
2022-03-16 00:23:57 +00:00
class JWTConfig(BaseModel):
    """
    JSON Web Token section of the configuration.

    Besides holding the signing parameters, this model issues (`encode`)
    and validates (`decode`) the tokens itself.
    """

    # key used to sign and verify tokens; `None` until one has been set
    secret: str | None = None
    # signing algorithm handed to `jose.jwt`
    hash_algorithm: str = ALGORITHMS.HS256
    # default token lifetime, in minutes
    expiry_minutes: int = 30

    async def encode(
        self,
        username: str,
        expiry_minutes: int | None = None,
    ) -> str:
        """
        Issue a signed token for a user.

        :param username: stored in the token's "sub" claim
        :param expiry_minutes: token lifetime in minutes; falls back to
            `self.expiry_minutes` when `None`
        :return: the encoded token
        """

        if expiry_minutes is None:
            expiry_minutes = self.expiry_minutes

        # the expiry is expressed as a naive UTC datetime; `decode` must
        # interpret the "exp" claim in UTC as well
        expires_at = datetime.utcnow() + timedelta(minutes=expiry_minutes)

        return jwt.encode(
            {
                "sub": username,
                "exp": expires_at,
            },
            self.secret,
            algorithm=self.hash_algorithm,
        )

    async def decode(
        self,
        token: str,
    ) -> str | None:
        """
        Validate a token and extract the user name it was issued for.

        :param token: the encoded token
        :return: the value of the "sub" claim, or `None` if the token cannot
            be decoded, carries no "exp" or "sub" claim, or has expired
        """

        # decode JWT token
        try:
            payload = jwt.decode(
                token,
                self.secret,
                algorithms=[self.hash_algorithm],
            )

        except JWTError:
            return None

        # check expiry
        expiry = payload.get("exp")
        if expiry is None:
            return None

        # "exp" is a POSIX timestamp: convert it to naive *UTC* so it is
        # comparable with `utcnow()`. `datetime.fromtimestamp` would yield
        # local time and skew the comparison by the host's UTC offset.
        if datetime.utcfromtimestamp(expiry) < datetime.utcnow():
            return None

        # get username
        username = payload.get("sub")
        if username is None:
            return None

        return username
2022-03-16 00:23:57 +00:00
class CryptoConfig(BaseModel):
    """
    Password hashing section of the configuration.
    """

    # hashing scheme names passed on to passlib
    schemes: list[str] = ["bcrypt"]

    @property
    def crypt_context(self) -> CryptContext:
        """
        Build a passlib `CryptContext` from the configured schemes.

        A new context is constructed on every access.
        """

        context_options = {
            "schemes": self.schemes,
            "deprecated": "auto",
        }

        return CryptContext(**context_options)
2022-03-18 22:43:02 +00:00
class Config(BaseModel):
    """
    Root configuration object, persisted as JSON at `Settings.config_file`.
    """

    # each section falls back to its own defaults when not given
    db: DBConfig = Field(default_factory=DBConfig)
    jwt: JWTConfig = Field(default_factory=JWTConfig)
    crypto: CryptoConfig = Field(default_factory=CryptoConfig)

    @staticmethod
    async def get() -> Config | None:
        """
        Load the configuration from the config file.

        :return: the parsed configuration, or `None` if the file does not
            exist
        """

        try:
            with open(Settings.get().config_file, "r") as handle:
                raw_config = json.load(handle)
                return Config.parse_obj(raw_config)

        except FileNotFoundError:
            return None

    @staticmethod
    def set(config: Config) -> None:
        """
        Write a configuration to the config file, replacing its contents.

        :param config: the configuration to serialize as indented JSON
        """

        target_path = Settings.get().config_file

        with open(target_path, "w") as handle:
            handle.write(config.json(indent=2))

    @property
    def crypt_context(self) -> CryptContext:
        """
        Shortcut for `self.crypto.crypt_context`.
        """

        crypto_section = self.crypto
        return crypto_section.crypt_context

    @property
    def db_engine(self) -> Engine:
        """
        Shortcut for `self.db.db_engine`.
        """

        db_section = self.db
        return db_section.db_engine