kiwi-vpn/api/kiwi_vpn_api/config.py
2022-03-18 17:36:44 +00:00

105 lines
2.5 KiB
Python

from __future__ import annotations
import json
from enum import Enum
from pathlib import Path
from typing import Generator
from fastapi import Depends
from jose.constants import ALGORITHMS
from passlib.context import CryptContext
from pydantic import BaseModel, Field, PrivateAttr
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker
from .db.models import ORMBaseModel
# NOTE(review): looks like a development/production switch; it is not read
# anywhere in the visible code — confirm where it is consumed.
PRODUCTION_MODE = False
# to get a string like this run:
# openssl rand -hex 32
# NOTE(review): this secret is hard-coded in the source file; presumably a
# development placeholder — it should not be relied on for real deployments.
SECRET_KEY = "2f7875b0d2be8a76eba8077ab4d9f8b1c749e02647e9ac9e0f909c3acbfc9856"
# NOTE(review): presumably the JWT signing algorithm and the token lifetime;
# JWTConfig in this file declares fields for the same settings — confirm
# which of the two is actually used.
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
# passlib context configured with bcrypt as its only scheme
CRYPT_CONTEXT = CryptContext(schemes=["bcrypt"], deprecated="auto")
# NOTE(review): never reassigned in the visible code; BaseConfig keeps its
# own private session factory instead.
SESSION_LOCAL = None
class DBType(Enum):
    """Database backends that a DBConfig can name."""

    sqlite = "sqlite"
    mysql = "mysql"
class DBConfig(BaseModel):
    """Database section of the configuration."""

    # which backend to use; SQLite unless configured otherwise
    db_type: DBType = Field(default=DBType.sqlite)
class JWTConfig(BaseModel):
    """JWT section of the configuration."""

    # presumably the token signing secret; unset by default — confirm
    secret: str | None = Field(default=None)
    # algorithm identifier taken from python-jose's ALGORITHMS constants
    hash_algorithm: str = Field(default=ALGORITHMS.HS256)
    # presumably the token lifetime in minutes — confirm against the issuer
    expiry_minutes: int = Field(default=30)
class CryptoConfig(BaseModel):
    """Settings used to build the passlib CryptContext."""

    # scheme names passed to CryptContext(schemes=...)
    schemes: list[str] = Field(default=["bcrypt"])
class BaseConfig(BaseModel):
    """Top-level application configuration.

    Groups the database, JWT and crypto sections and holds the SQLAlchemy
    session factory created by :meth:`connect_db`.
    """

    db: DBConfig = Field(default_factory=DBConfig)
    jwt: JWTConfig = Field(default_factory=JWTConfig)
    crypto: CryptoConfig = Field(default_factory=CryptoConfig)

    # Session factory; stays None until connect_db() has run. The explicit
    # default matters: a bare PrivateAttr() leaves the attribute unset, so
    # reading it in `database` would raise AttributeError instead of letting
    # the property return None.
    __session_local: sessionmaker | None = PrivateAttr(default=None)

    @property
    def crypt_context(self) -> CryptContext:
        """Return a new CryptContext built from the configured schemes."""
        return CryptContext(
            schemes=self.crypto.schemes,
            deprecated="auto",
        )

    async def connect_db(self) -> None:
        """Create the engine, the session factory and all ORM tables.

        NOTE(review): the SQLite URL is hard-coded and ``self.db.db_type``
        is not consulted here.
        """
        engine = create_engine(
            "sqlite:///./tmp/vpn.db",
            connect_args={"check_same_thread": False},
        )
        self.__session_local = sessionmaker(
            autocommit=False, autoflush=False, bind=engine,
        )
        ORMBaseModel.metadata.create_all(bind=engine)

    @property
    def database(self) -> Session | None:
        """Return a new session, or None when connect_db() has not run yet.

        Every access creates a fresh Session, so callers should read the
        property once and close the session they received.
        """
        session_factory = self.__session_local
        if session_factory is None:
            return None
        return session_factory()
# Location of the JSON configuration file; a relative path, so it is resolved
# against the process's current working directory.
CONFIG_FILE = "tmp/config.json"
async def has_config() -> bool:
    """Report whether CONFIG_FILE points at an existing file."""
    config_path = Path(CONFIG_FILE)
    return config_path.is_file()
async def load_config() -> BaseConfig:
    """Load the configuration from CONFIG_FILE.

    Returns:
        The configuration parsed from the file, or a BaseConfig with default
        values when the file does not exist.

    Raises:
        json.JSONDecodeError: if the file is not valid JSON.
        pydantic.ValidationError: if the JSON does not fit the model.
    """
    try:
        # JSON is UTF-8 by specification; do not depend on the locale's
        # preferred encoding when decoding the file.
        with open(CONFIG_FILE, "r", encoding="utf-8") as kv:
            raw_config = json.load(kv)
    except FileNotFoundError:
        return BaseConfig()
    return BaseConfig.parse_obj(raw_config)
async def get_db(
    config: BaseConfig = Depends(load_config)
) -> Generator[Session | None, None, None]:
    """Yield a database session for the request, or None if there is none.

    The session is read from ``config.database`` exactly once (each access
    creates a new session) and is closed after the consumer is done with it.

    NOTE(review): this is an async generator, so ``AsyncGenerator`` would be
    the accurate annotation; it is kept as-is here so the file's imports stay
    untouched.
    """
    # The parentheses are essential: `db := config.database is None` binds the
    # result of the comparison (a bool) to `db`, not the session itself.
    if (db := config.database) is None:
        yield None
        return
    try:
        yield db
    finally:
        db.close()