90 lines
2 KiB
Python
90 lines
2 KiB
Python
from __future__ import annotations

import json
import os
from collections.abc import AsyncIterator
from enum import Enum
from pathlib import Path

from fastapi import Depends
from jose.constants import ALGORITHMS
from passlib.context import CryptContext
from pydantic import BaseModel, Field
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker

from .db.models import ORMBaseModel
|
|
|
|
# NOTE(review): PRODUCTION_MODE is not read anywhere in this module; what it
# toggles depends on code outside this file.
PRODUCTION_MODE = False

# Signing key, presumably the HMAC key for JWTs (ALGORITHM below is HS256).
# Set the SECRET_KEY environment variable to supply a real key. The literal
# below is only a fallback; it is committed to source control, so treat it
# as a development-only value.
# To get a string like this run:
#     openssl rand -hex 32
SECRET_KEY = (
    os.environ.get("SECRET_KEY")
    or "2f7875b0d2be8a76eba8077ab4d9f8b1c749e02647e9ac9e0f909c3acbfc9856"
)
# JWT signing algorithm name; same value as JWTConfig.hash_algorithm's default.
ALGORITHM = "HS256"
# Presumably the access-token lifetime in minutes; same value as
# JWTConfig.expiry_minutes' default.
ACCESS_TOKEN_EXPIRE_MINUTES = 30

# passlib password-hashing context using bcrypt.
# NOTE(review): this hard-codes ["bcrypt"] rather than reading
# CryptoConfig.schemes.
CRYPT_CONTEXT = CryptContext(schemes=["bcrypt"], deprecated="auto")

# Session factory installed by connect_db(); None until that has run.
SESSION_LOCAL: sessionmaker | None = None
|
|
|
|
|
|
class DBType(Enum):
    """Database backends that can be named in the configuration file."""

    # Member values are the strings matched against ``db.db_type`` in the
    # JSON config. NOTE(review): connect_db() does not consult this and
    # always opens SQLite.
    sqlite = "sqlite"
    mysql = "mysql"
|
|
|
|
|
|
class DBConfig(BaseModel):
    """The ``db`` section of the application configuration."""

    # Selected backend; SQLite when the key is absent from the config file.
    db_type: DBType = Field(default=DBType.sqlite)
|
|
|
|
|
|
class JWTConfig(BaseModel):
    """The ``jwt`` section of the application configuration."""

    # Signing secret; None when not set in the config file.
    # NOTE(review): what is used when this is None is not visible here --
    # presumably the module-level SECRET_KEY; confirm with callers.
    secret: str | None = Field(default=None)
    # Algorithm name from python-jose's ALGORITHMS constants ("HS256").
    hash_algorithm: str = Field(default=ALGORITHMS.HS256)
    # Presumably the token lifetime in minutes.
    expiry_minutes: int = Field(default=30)
|
|
|
|
|
|
class CryptoConfig(BaseModel):
    """The ``crypto`` section of the application configuration.

    ``schemes`` presumably lists passlib ``CryptContext`` scheme names.
    NOTE(review): the module-level CRYPT_CONTEXT does not read this field.
    """

    # Built by a factory, like BaseConfig's sections, so the default is a
    # fresh list per instance rather than a class-level mutable literal.
    schemes: list[str] = Field(default_factory=lambda: ["bcrypt"])
|
|
|
|
|
|
class BaseConfig(BaseModel):
    """Root of the application configuration loaded from ``CONFIG_FILE``.

    Each section may be omitted from the JSON file; an omitted section is
    filled in with that section model's own defaults.
    """

    # default_factory gives every instance its own section objects.
    db: DBConfig = Field(default_factory=DBConfig)
    jwt: JWTConfig = Field(default_factory=JWTConfig)
    crypto: CryptoConfig = Field(default_factory=CryptoConfig)
|
|
|
|
|
|
# Location of the JSON configuration file. Relative, so it is resolved
# against the process's current working directory.
CONFIG_FILE: str = "tmp/config.json"
|
|
|
|
|
|
async def has_config() -> bool:
    """Tell whether a configuration file is present at ``CONFIG_FILE``.

    Returns:
        True only if the path exists and is a regular file (after following
        symlinks); a directory or a dangling symlink yields False.
    """
    config_path = Path(CONFIG_FILE)
    return config_path.is_file()
|
|
|
|
|
|
async def load_config() -> BaseConfig:
    """Load the application configuration from ``CONFIG_FILE``.

    A missing file is not an error: a ``BaseConfig`` built entirely from
    defaults is returned instead.

    Returns:
        The parsed configuration.

    Raises:
        json.JSONDecodeError: if the file exists but is not valid JSON.
        pydantic.ValidationError: if the JSON does not match ``BaseConfig``.
        OSError: for other failures opening the file (e.g. permissions).
    """
    try:
        # JSON text is UTF-8 (RFC 8259); don't depend on the platform's
        # locale encoding, which is not UTF-8 everywhere (e.g. Windows).
        kv = open(CONFIG_FILE, "r", encoding="utf-8")
    except FileNotFoundError:
        return BaseConfig()

    # NOTE(review): blocking file I/O inside a coroutine; acceptable for a
    # small config file, but it briefly stalls the event loop.
    with kv:
        return BaseConfig.parse_obj(json.load(kv))
|
|
|
|
|
|
async def connect_db(config: BaseConfig = Depends(load_config)) -> None:
    """Open the SQLite database and install the module-level session factory.

    Builds an engine for ``./tmp/vpn.db``, stores a ``sessionmaker`` bound to
    it in ``SESSION_LOCAL`` (read by ``get_db``), and creates any ORM tables
    that do not exist yet.

    Args:
        config: Application configuration, injected via
            ``Depends(load_config)``. NOTE(review): currently unused --
            ``config.db.db_type`` is not consulted and SQLite is always used,
            even when it is set to ``mysql``.
    """
    global SESSION_LOCAL

    db_path = "./tmp/vpn.db"
    # SQLite creates the database file on first connect but not its parent
    # directory; without this the first connect fails if ./tmp is missing.
    Path(db_path).parent.mkdir(parents=True, exist_ok=True)

    engine = create_engine(
        f"sqlite:///{db_path}",
        # sqlite3 connections refuse use from any thread other than their
        # creator unless this is disabled; sessions may be used from other
        # threads (presumably FastAPI's threadpool -- confirm).
        connect_args={"check_same_thread": False},
    )
    # NOTE(review): each call builds a new engine and replaces the factory
    # without disposing the previous engine; fine if this runs once at startup.
    SESSION_LOCAL = sessionmaker(autocommit=False, autoflush=False, bind=engine)

    ORMBaseModel.metadata.create_all(bind=engine)
|
|
|
|
|
|
async def get_db() -> AsyncIterator[Session]:
    """Yield a new database session and close it when the consumer is done.

    Written as an async generator so it can act as a FastAPI dependency with
    cleanup (``Depends(get_db)``) -- presumably how callers use it.

    Yields:
        A ``Session`` created by the factory that ``connect_db`` installed.

    Raises:
        RuntimeError: if ``connect_db`` has not run yet, so ``SESSION_LOCAL``
            is still ``None``.
    """
    session_factory = SESSION_LOCAL
    if session_factory is None:
        # Clear message instead of "'NoneType' object is not callable".
        raise RuntimeError("database not initialised: call connect_db() first")

    db = session_factory()
    try:
        yield db
    finally:
        # Always release the session, even if the consumer raised.
        db.close()
|