diff --git a/api/kiwi_vpn_api/config.py b/api/kiwi_vpn_api/config.py index 8fc4da6..a8a2232 100644 --- a/api/kiwi_vpn_api/config.py +++ b/api/kiwi_vpn_api/config.py @@ -1,12 +1,18 @@ from __future__ import annotations +import json from enum import Enum +from pathlib import Path from typing import Optional +from fastapi import Depends from jose.constants import ALGORITHMS from passlib.context import CryptContext -from peewee import Database, SqliteDatabase from pydantic import BaseModel, Field +from sqlalchemy import create_engine +from sqlalchemy.orm import Session, sessionmaker + +from .db.models import ORMBaseModel PRODUCTION_MODE = False @@ -18,6 +24,8 @@ ACCESS_TOKEN_EXPIRE_MINUTES = 30 CRYPT_CONTEXT = CryptContext(schemes=["bcrypt"], deprecated="auto") +SESSION_LOCAL = None + class DBType(Enum): sqlite = "sqlite" @@ -27,11 +35,6 @@ class DBType(Enum): class DBConfig(BaseModel): db_type: DBType = DBType.sqlite - @property - async def database(self) -> Database: - if self.db_type == DBType.sqlite: - return SqliteDatabase("tmp/vpn.db") - class JWTConfig(BaseModel): secret: Optional[str] = None @@ -42,12 +45,47 @@ class JWTConfig(BaseModel): class CryptoConfig(BaseModel): schemes: list[str] = ["bcrypt"] - @property - async def cryptContext(self) -> CryptContext: - return CryptContext(schemes=self.schemes, deprecated="auto") - class BaseConfig(BaseModel): db: DBConfig = Field(default_factory=DBConfig) jwt: JWTConfig = Field(default_factory=JWTConfig) crypto: CryptoConfig = Field(default_factory=CryptoConfig) + + +CONFIG_FILE = "tmp/config.json" + + +async def has_config() -> bool: + return Path(CONFIG_FILE).is_file() + + +async def load_config() -> BaseConfig: + try: + with open(CONFIG_FILE, "r") as kv: + return BaseConfig.parse_obj(json.load(kv)) + + except FileNotFoundError: + return BaseConfig() + + +async def connect_db(config: BaseConfig = Depends(load_config)) -> None: + global SESSION_LOCAL + + engine = create_engine( + "sqlite:///./tmp/vpn.db", + connect_args={"check_same_thread": False}, + ) + SESSION_LOCAL = sessionmaker( + autocommit=False, autoflush=False, bind=engine) + + ORMBaseModel.metadata.create_all(bind=engine) + + +async def get_db() -> Session: + global SESSION_LOCAL + + db = SESSION_LOCAL() + try: + yield db + finally: + db.close() diff --git a/api/kiwi_vpn_api/db/connection.py b/api/kiwi_vpn_api/db/connection.py deleted file mode 100644 index 32afe34..0000000 --- a/api/kiwi_vpn_api/db/connection.py +++ /dev/null @@ -1,21 +0,0 @@ -from sqlalchemy import create_engine -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import sessionmaker - -SQLALCHEMY_DATABASE_URL = "sqlite:///./tmp/vpn.db" -# SQLALCHEMY_DATABASE_URL = "postgresql://user:password@postgresserver/db" - -engine = create_engine( - SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False} -) -SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) - -ORMBaseModel = declarative_base() - - -def get_db(): - db = SessionLocal() - try: - yield db - finally: - db.close() diff --git a/api/kiwi_vpn_api/db/models.py b/api/kiwi_vpn_api/db/models.py index db1be18..dc54cbd 100644 --- a/api/kiwi_vpn_api/db/models.py +++ b/api/kiwi_vpn_api/db/models.py @@ -2,9 +2,10 @@ import datetime from sqlalchemy import (Boolean, Column, DateTime, ForeignKey, Integer, String, UniqueConstraint) +from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import relationship -from .connection import ORMBaseModel +ORMBaseModel = declarative_base() class User(ORMBaseModel): diff --git a/api/kiwi_vpn_api/main.py b/api/kiwi_vpn_api/main.py index 16c46a8..d1c3626 100755 --- a/api/kiwi_vpn_api/main.py +++ b/api/kiwi_vpn_api/main.py @@ -4,11 +4,8 @@ import uvicorn from fastapi import FastAPI from .config import PRODUCTION_MODE -from .db import connection, models from .routers import install -models.ORMBaseModel.metadata.create_all(bind=connection.engine) - api = FastAPI( title="kiwi-vpn API", description="This API enables the `kiwi-vpn` service.", diff --git a/api/kiwi_vpn_api/routers/install.py b/api/kiwi_vpn_api/routers/install.py index 2c9778c..7f9c22f 100644 --- a/api/kiwi_vpn_api/routers/install.py +++ b/api/kiwi_vpn_api/routers/install.py @@ -1,42 +1,15 @@ -import json -from pathlib import Path from secrets import token_hex from fastapi import APIRouter, Depends, HTTPException, status from sqlalchemy.orm import Session -from ..config import BaseConfig -from ..db import connection, crud, schemas +from ..config import (CONFIG_FILE, BaseConfig, connect_db, get_db, has_config, + load_config) +from ..db import crud, schemas router = APIRouter(prefix="/install") -CONFIG_FILE = "tmp/config.json" - - -async def has_config() -> bool: - return Path(CONFIG_FILE).is_file() - - -async def load_config() -> BaseConfig: - try: - with open(CONFIG_FILE, "r") as kv: - return BaseConfig.parse_obj(json.load(kv)) - - except FileNotFoundError: - return BaseConfig() - - -# async def connect_db(config: BaseConfig = Depends(load_config)) -> Database: -# db = await config.db.database -# db.connect() -# return db - - -# async def has_tables(db: Database = Depends(connect_db)) -> bool: -# return db.table_exists(User) - - @router.get( "/config", response_model=BaseConfig, @@ -79,7 +52,7 @@ async def set_config( if config.jwt.secret is None: config.jwt.secret = token_hex(32) - # DB.initialize(await connect_db(config)) + await connect_db(config) with open(CONFIG_FILE, "w") as kv: kv.write(config.json(indent=2)) @@ -90,9 +63,7 @@ async def set_config( "model": bool, }, }) -async def check_db( - db: Session = Depends(connection.get_db), -): +async def check_db(): return True @@ -111,8 +82,7 @@ async def check_db( async def create_db( admin_name: str, admin_password: str, - config: BaseConfig = Depends(load_config), - db: Session = Depends(connection.get_db), + db: Session = Depends(get_db), ): # if await has_tables(db): # raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) @@ -120,6 +90,7 @@ async def create_db( # db.create_tables([Certificate, DistinguishedName, User, UserCapability]) # cryptContext = await config.crypto.cryptContext + crud.create_user(db, schemas.UserCreate( name=admin_name, password=admin_password,