kiwi-vpn/api/kiwi_vpn_api/config.py

106 lines
2.5 KiB
Python
Raw Normal View History

2022-03-16 00:23:57 +00:00
from __future__ import annotations

import json
from enum import Enum
from pathlib import Path
from typing import AsyncGenerator, Generator

from fastapi import Depends
from jose.constants import ALGORITHMS
from passlib.context import CryptContext
from pydantic import BaseModel, Field, PrivateAttr
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker

from .db.models import ORMBaseModel
# Module-level defaults. NOTE(review): these overlap with the JWTConfig and
# CryptoConfig models in this module; confirm which source callers rely on.
PRODUCTION_MODE = False

# Token signing key. To generate a replacement, run:
#     openssl rand -hex 32
# NOTE(review): a secret committed to source control should be treated as
# public; it ought to come from deployment configuration instead.
SECRET_KEY = "2f7875b0d2be8a76eba8077ab4d9f8b1c749e02647e9ac9e0f909c3acbfc9856"
ALGORITHM = ALGORITHMS.HS256  # the string "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30

# passlib hashing context using bcrypt.
CRYPT_CONTEXT = CryptContext(
    schemes=["bcrypt"],
    deprecated="auto",
)

# NOTE(review): not referenced anywhere else in this file (BaseConfig keeps
# its own session factory); presumably a leftover — verify before removing.
SESSION_LOCAL = None
class DBType(Enum):
    """Database backends selectable through ``DBConfig.db_type``."""

    sqlite = "sqlite"  # file-backed SQLite database
    mysql = "mysql"  # MySQL server
class DBConfig(BaseModel):
    """Database section of the application configuration."""

    # Which backend to use; SQLite unless the config file says otherwise.
    db_type: DBType = Field(default=DBType.sqlite)
class JWTConfig(BaseModel):
    """JSON Web Token section of the application configuration."""

    # Signing secret; None when the config file does not provide one.
    # NOTE(review): confirm how token consumers behave when this is None.
    secret: str | None = Field(default=None)
    # Algorithm name in python-jose's vocabulary; HS256 = HMAC-SHA256.
    hash_algorithm: str = Field(default=ALGORITHMS.HS256)
    # Presumably the token lifetime in minutes — confirm against the issuer.
    expiry_minutes: int = Field(default=30)
class CryptoConfig(BaseModel):
    """Password-hashing section of the application configuration."""

    # passlib scheme names; BaseConfig.crypt_context hands them to CryptContext.
    schemes: list[str] = Field(default=["bcrypt"])
class BaseConfig(BaseModel):
    """Top-level application configuration, loadable from JSON.

    Groups the database, JWT and password-hashing sections and, once
    :meth:`connect_db` has been awaited, holds the SQLAlchemy session
    factory used by :attr:`database`.
    """

    db: DBConfig = Field(default_factory=DBConfig)
    jwt: JWTConfig = Field(default_factory=JWTConfig)
    crypto: CryptoConfig = Field(default_factory=CryptoConfig)

    # Session factory created by connect_db(); None until then. The explicit
    # default matters: a bare PrivateAttr() leaves the attribute unset, so
    # reading it before connect_db() raised AttributeError instead of letting
    # `database` report "not connected" by returning None.
    __session_local: sessionmaker | None = PrivateAttr(default=None)

    @property
    def crypt_context(self) -> CryptContext:
        """Build a passlib CryptContext from the configured hash schemes."""
        return CryptContext(
            schemes=self.crypto.schemes,
            deprecated="auto",
        )

    async def connect_db(self) -> None:
        """Create the engine and session factory, and any missing ORM tables."""
        # NOTE(review): the SQLite URL is hard-coded; self.db.db_type is not
        # consulted, so selecting "mysql" currently has no effect here.
        engine = create_engine(
            "sqlite:///./tmp/vpn.db",
            # Let SQLite connections be used from threads other than the
            # one that opened them.
            connect_args={"check_same_thread": False},
        )
        self.__session_local = sessionmaker(
            autocommit=False, autoflush=False, bind=engine,
        )
        ORMBaseModel.metadata.create_all(bind=engine)

    @property
    def database(self) -> Session | None:
        """Open a new Session, or return None if connect_db() was not run.

        The caller owns the returned session and must close it.
        """
        if self.__session_local is None:
            return None
        return self.__session_local()
# JSON settings file; a relative path, so it resolves against the working
# directory of the running process.
CONFIG_FILE = "tmp/config.json"


async def has_config() -> bool:
    """Report whether a regular file exists at ``CONFIG_FILE``."""
    config_path = Path(CONFIG_FILE)
    return config_path.is_file()
async def load_config() -> BaseConfig:
    """Load the application configuration from ``CONFIG_FILE``.

    Returns:
        The parsed ``BaseConfig``, or a default ``BaseConfig()`` when the file
        does not exist.

    Raises:
        json.JSONDecodeError: if the file is not valid JSON.
        Any validation error raised by ``BaseConfig.parse_obj`` for
        settings that do not match the model.
    """
    try:
        # JSON text is UTF-8 by specification (RFC 8259); do not depend on
        # the platform's locale-specific default encoding.
        with open(CONFIG_FILE, "r", encoding="utf-8") as config_file:
            raw_config = json.load(config_file)
    except FileNotFoundError:
        return BaseConfig()

    return BaseConfig.parse_obj(raw_config)
async def get_db(
    config: BaseConfig = Depends(load_config),
) -> AsyncGenerator[Session | None, None]:
    """FastAPI dependency yielding a database session.

    Yields None when ``config.database`` is None (no session factory on this
    config). Otherwise yields a new Session and closes it when the generator
    is resumed or closed after the consumer is done with it.

    NOTE(review): load_config() builds a fresh BaseConfig on every call, and
    only connect_db() installs a session factory, so unless connect_db() is
    awaited on that same instance this presumably always yields None —
    verify the intended wiring.
    """
    # The parentheses are essential: `db := config.database is None` binds
    # the *bool* of the comparison to `db` (walrus has the lowest
    # precedence), so the old else-branch yielded False and then crashed on
    # False.close(), leaking the Session it had just opened.
    if (db := config.database) is None:
        yield None
        return

    try:
        yield db
    finally:
        # Close the session even if the consumer raised.
        db.close()