more SQL Alchemy

This commit is contained in:
Jörn-Michael Miehe 2022-03-17 22:47:31 +00:00
parent c778e2aa98
commit 0607e0383c
5 changed files with 57 additions and 71 deletions

View file

@ -1,12 +1,18 @@
from __future__ import annotations
import json
from enum import Enum
from pathlib import Path
from typing import Optional
from fastapi import Depends
from jose.constants import ALGORITHMS
from passlib.context import CryptContext
from peewee import Database, SqliteDatabase
from pydantic import BaseModel, Field
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker
from .db.models import ORMBaseModel
PRODUCTION_MODE = False
@ -18,6 +24,8 @@ ACCESS_TOKEN_EXPIRE_MINUTES = 30
CRYPT_CONTEXT = CryptContext(schemes=["bcrypt"], deprecated="auto")
SESSION_LOCAL = None
class DBType(Enum):
sqlite = "sqlite"
@ -27,11 +35,6 @@ class DBType(Enum):
class DBConfig(BaseModel):
db_type: DBType = DBType.sqlite
@property
async def database(self) -> Database:
if self.db_type == DBType.sqlite:
return SqliteDatabase("tmp/vpn.db")
class JWTConfig(BaseModel):
secret: Optional[str] = None
@ -42,12 +45,47 @@ class JWTConfig(BaseModel):
class CryptoConfig(BaseModel):
schemes: list[str] = ["bcrypt"]
@property
async def cryptContext(self) -> CryptContext:
return CryptContext(schemes=self.schemes, deprecated="auto")
class BaseConfig(BaseModel):
db: DBConfig = Field(default_factory=DBConfig)
jwt: JWTConfig = Field(default_factory=JWTConfig)
crypto: CryptoConfig = Field(default_factory=CryptoConfig)
CONFIG_FILE = "tmp/config.json"
async def has_config() -> bool:
return Path(CONFIG_FILE).is_file()
async def load_config() -> BaseConfig:
try:
with open(CONFIG_FILE, "r") as kv:
return BaseConfig.parse_obj(json.load(kv))
except FileNotFoundError:
return BaseConfig()
async def connect_db(config: BaseConfig = Depends(load_config)) -> None:
global SESSION_LOCAL
engine = create_engine(
"sqlite:///./tmp/vpn.db",
connect_args={"check_same_thread": False},
)
SESSION_LOCAL = sessionmaker(
autocommit=False, autoflush=False, bind=engine)
ORMBaseModel.metadata.create_all(bind=engine)
async def get_db() -> Session:
global SESSION_LOCAL
db = SESSION_LOCAL()
try:
yield db
finally:
db.close()

View file

@ -1,21 +0,0 @@
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
SQLALCHEMY_DATABASE_URL = "sqlite:///./tmp/vpn.db"
# SQLALCHEMY_DATABASE_URL = "postgresql://user:password@postgresserver/db"
engine = create_engine(
SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
ORMBaseModel = declarative_base()
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()

View file

@ -2,9 +2,10 @@ import datetime
from sqlalchemy import (Boolean, Column, DateTime, ForeignKey, Integer, String,
UniqueConstraint)
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship
from .connection import ORMBaseModel
ORMBaseModel = declarative_base()
class User(ORMBaseModel):

View file

@ -4,11 +4,8 @@ import uvicorn
from fastapi import FastAPI
from .config import PRODUCTION_MODE
from .db import connection, models
from .routers import install
models.ORMBaseModel.metadata.create_all(bind=connection.engine)
api = FastAPI(
title="kiwi-vpn API",
description="This API enables the `kiwi-vpn` service.",

View file

@ -1,42 +1,15 @@
import json
from pathlib import Path
from secrets import token_hex
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from ..config import BaseConfig
from ..db import connection, crud, schemas
from ..config import (CONFIG_FILE, BaseConfig, connect_db, get_db, has_config,
load_config)
from ..db import crud, schemas
router = APIRouter(prefix="/install")
CONFIG_FILE = "tmp/config.json"
async def has_config() -> bool:
return Path(CONFIG_FILE).is_file()
async def load_config() -> BaseConfig:
try:
with open(CONFIG_FILE, "r") as kv:
return BaseConfig.parse_obj(json.load(kv))
except FileNotFoundError:
return BaseConfig()
# async def connect_db(config: BaseConfig = Depends(load_config)) -> Database:
# db = await config.db.database
# db.connect()
# return db
# async def has_tables(db: Database = Depends(connect_db)) -> bool:
# return db.table_exists(User)
@router.get(
"/config",
response_model=BaseConfig,
@ -79,7 +52,7 @@ async def set_config(
if config.jwt.secret is None:
config.jwt.secret = token_hex(32)
# DB.initialize(await connect_db(config))
await connect_db(config)
with open(CONFIG_FILE, "w") as kv:
kv.write(config.json(indent=2))
@ -90,9 +63,7 @@ async def set_config(
"model": bool,
},
})
async def check_db(
db: Session = Depends(connection.get_db),
):
async def check_db():
return True
@ -111,8 +82,7 @@ async def check_db(
async def create_db(
admin_name: str,
admin_password: str,
config: BaseConfig = Depends(load_config),
db: Session = Depends(connection.get_db),
db: Session = Depends(get_db),
):
# if await has_tables(db):
# raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
@ -120,6 +90,7 @@ async def create_db(
# db.create_tables([Certificate, DistinguishedName, User, UserCapability])
# cryptContext = await config.crypto.cryptContext
crud.create_user(db, schemas.UserCreate(
name=admin_name,
password=admin_password,