more SQL Alchemy

This commit is contained in:
Jörn-Michael Miehe 2022-03-17 22:47:31 +00:00
parent c778e2aa98
commit 0607e0383c
5 changed files with 57 additions and 71 deletions

View file

@ -1,12 +1,18 @@
from __future__ import annotations from __future__ import annotations
import json
from enum import Enum from enum import Enum
from pathlib import Path
from typing import Optional from typing import Optional
from fastapi import Depends
from jose.constants import ALGORITHMS from jose.constants import ALGORITHMS
from passlib.context import CryptContext from passlib.context import CryptContext
from peewee import Database, SqliteDatabase
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker
from .db.models import ORMBaseModel
PRODUCTION_MODE = False PRODUCTION_MODE = False
@ -18,6 +24,8 @@ ACCESS_TOKEN_EXPIRE_MINUTES = 30
CRYPT_CONTEXT = CryptContext(schemes=["bcrypt"], deprecated="auto") CRYPT_CONTEXT = CryptContext(schemes=["bcrypt"], deprecated="auto")
SESSION_LOCAL = None
class DBType(Enum): class DBType(Enum):
sqlite = "sqlite" sqlite = "sqlite"
@ -27,11 +35,6 @@ class DBType(Enum):
class DBConfig(BaseModel): class DBConfig(BaseModel):
db_type: DBType = DBType.sqlite db_type: DBType = DBType.sqlite
@property
async def database(self) -> Database:
if self.db_type == DBType.sqlite:
return SqliteDatabase("tmp/vpn.db")
class JWTConfig(BaseModel): class JWTConfig(BaseModel):
secret: Optional[str] = None secret: Optional[str] = None
@ -42,12 +45,47 @@ class JWTConfig(BaseModel):
class CryptoConfig(BaseModel): class CryptoConfig(BaseModel):
schemes: list[str] = ["bcrypt"] schemes: list[str] = ["bcrypt"]
@property
async def cryptContext(self) -> CryptContext:
return CryptContext(schemes=self.schemes, deprecated="auto")
class BaseConfig(BaseModel): class BaseConfig(BaseModel):
db: DBConfig = Field(default_factory=DBConfig) db: DBConfig = Field(default_factory=DBConfig)
jwt: JWTConfig = Field(default_factory=JWTConfig) jwt: JWTConfig = Field(default_factory=JWTConfig)
crypto: CryptoConfig = Field(default_factory=CryptoConfig) crypto: CryptoConfig = Field(default_factory=CryptoConfig)
CONFIG_FILE = "tmp/config.json"
async def has_config() -> bool:
return Path(CONFIG_FILE).is_file()
async def load_config() -> BaseConfig:
try:
with open(CONFIG_FILE, "r") as kv:
return BaseConfig.parse_obj(json.load(kv))
except FileNotFoundError:
return BaseConfig()
async def connect_db(config: BaseConfig = Depends(load_config)) -> None:
global SESSION_LOCAL
engine = create_engine(
"sqlite:///./tmp/vpn.db",
connect_args={"check_same_thread": False},
)
SESSION_LOCAL = sessionmaker(
autocommit=False, autoflush=False, bind=engine)
ORMBaseModel.metadata.create_all(bind=engine)
async def get_db() -> Session:
global SESSION_LOCAL
db = SESSION_LOCAL()
try:
yield db
finally:
db.close()

View file

@ -1,21 +0,0 @@
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
SQLALCHEMY_DATABASE_URL = "sqlite:///./tmp/vpn.db"
# SQLALCHEMY_DATABASE_URL = "postgresql://user:password@postgresserver/db"
engine = create_engine(
SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
ORMBaseModel = declarative_base()
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()

View file

@ -2,9 +2,10 @@ import datetime
from sqlalchemy import (Boolean, Column, DateTime, ForeignKey, Integer, String, from sqlalchemy import (Boolean, Column, DateTime, ForeignKey, Integer, String,
UniqueConstraint) UniqueConstraint)
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
from .connection import ORMBaseModel ORMBaseModel = declarative_base()
class User(ORMBaseModel): class User(ORMBaseModel):

View file

@ -4,11 +4,8 @@ import uvicorn
from fastapi import FastAPI from fastapi import FastAPI
from .config import PRODUCTION_MODE from .config import PRODUCTION_MODE
from .db import connection, models
from .routers import install from .routers import install
models.ORMBaseModel.metadata.create_all(bind=connection.engine)
api = FastAPI( api = FastAPI(
title="kiwi-vpn API", title="kiwi-vpn API",
description="This API enables the `kiwi-vpn` service.", description="This API enables the `kiwi-vpn` service.",

View file

@ -1,42 +1,15 @@
import json
from pathlib import Path
from secrets import token_hex from secrets import token_hex
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from ..config import BaseConfig from ..config import (CONFIG_FILE, BaseConfig, connect_db, get_db, has_config,
from ..db import connection, crud, schemas load_config)
from ..db import crud, schemas
router = APIRouter(prefix="/install") router = APIRouter(prefix="/install")
CONFIG_FILE = "tmp/config.json"
async def has_config() -> bool:
return Path(CONFIG_FILE).is_file()
async def load_config() -> BaseConfig:
try:
with open(CONFIG_FILE, "r") as kv:
return BaseConfig.parse_obj(json.load(kv))
except FileNotFoundError:
return BaseConfig()
# async def connect_db(config: BaseConfig = Depends(load_config)) -> Database:
# db = await config.db.database
# db.connect()
# return db
# async def has_tables(db: Database = Depends(connect_db)) -> bool:
# return db.table_exists(User)
@router.get( @router.get(
"/config", "/config",
response_model=BaseConfig, response_model=BaseConfig,
@ -79,7 +52,7 @@ async def set_config(
if config.jwt.secret is None: if config.jwt.secret is None:
config.jwt.secret = token_hex(32) config.jwt.secret = token_hex(32)
# DB.initialize(await connect_db(config)) await connect_db(config)
with open(CONFIG_FILE, "w") as kv: with open(CONFIG_FILE, "w") as kv:
kv.write(config.json(indent=2)) kv.write(config.json(indent=2))
@ -90,9 +63,7 @@ async def set_config(
"model": bool, "model": bool,
}, },
}) })
async def check_db( async def check_db():
db: Session = Depends(connection.get_db),
):
return True return True
@ -111,8 +82,7 @@ async def check_db(
async def create_db( async def create_db(
admin_name: str, admin_name: str,
admin_password: str, admin_password: str,
config: BaseConfig = Depends(load_config), db: Session = Depends(get_db),
db: Session = Depends(connection.get_db),
): ):
# if await has_tables(db): # if await has_tables(db):
# raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) # raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
@ -120,6 +90,7 @@ async def create_db(
# db.create_tables([Certificate, DistinguishedName, User, UserCapability]) # db.create_tables([Certificate, DistinguishedName, User, UserCapability])
# cryptContext = await config.crypto.cryptContext # cryptContext = await config.crypto.cryptContext
crud.create_user(db, schemas.UserCreate( crud.create_user(db, schemas.UserCreate(
name=admin_name, name=admin_name,
password=admin_password, password=admin_password,