more SQL Alchemy
This commit is contained in:
parent
c778e2aa98
commit
0607e0383c
5 changed files with 57 additions and 71 deletions
|
@ -1,12 +1,18 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
from fastapi import Depends
|
||||||
from jose.constants import ALGORITHMS
|
from jose.constants import ALGORITHMS
|
||||||
from passlib.context import CryptContext
|
from passlib.context import CryptContext
|
||||||
from peewee import Database, SqliteDatabase
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
from sqlalchemy import create_engine
|
||||||
|
from sqlalchemy.orm import Session, sessionmaker
|
||||||
|
|
||||||
|
from .db.models import ORMBaseModel
|
||||||
|
|
||||||
PRODUCTION_MODE = False
|
PRODUCTION_MODE = False
|
||||||
|
|
||||||
|
@ -18,6 +24,8 @@ ACCESS_TOKEN_EXPIRE_MINUTES = 30
|
||||||
|
|
||||||
CRYPT_CONTEXT = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
CRYPT_CONTEXT = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||||
|
|
||||||
|
SESSION_LOCAL = None
|
||||||
|
|
||||||
|
|
||||||
class DBType(Enum):
|
class DBType(Enum):
|
||||||
sqlite = "sqlite"
|
sqlite = "sqlite"
|
||||||
|
@ -27,11 +35,6 @@ class DBType(Enum):
|
||||||
class DBConfig(BaseModel):
|
class DBConfig(BaseModel):
|
||||||
db_type: DBType = DBType.sqlite
|
db_type: DBType = DBType.sqlite
|
||||||
|
|
||||||
@property
|
|
||||||
async def database(self) -> Database:
|
|
||||||
if self.db_type == DBType.sqlite:
|
|
||||||
return SqliteDatabase("tmp/vpn.db")
|
|
||||||
|
|
||||||
|
|
||||||
class JWTConfig(BaseModel):
|
class JWTConfig(BaseModel):
|
||||||
secret: Optional[str] = None
|
secret: Optional[str] = None
|
||||||
|
@ -42,12 +45,47 @@ class JWTConfig(BaseModel):
|
||||||
class CryptoConfig(BaseModel):
|
class CryptoConfig(BaseModel):
|
||||||
schemes: list[str] = ["bcrypt"]
|
schemes: list[str] = ["bcrypt"]
|
||||||
|
|
||||||
@property
|
|
||||||
async def cryptContext(self) -> CryptContext:
|
|
||||||
return CryptContext(schemes=self.schemes, deprecated="auto")
|
|
||||||
|
|
||||||
|
|
||||||
class BaseConfig(BaseModel):
|
class BaseConfig(BaseModel):
|
||||||
db: DBConfig = Field(default_factory=DBConfig)
|
db: DBConfig = Field(default_factory=DBConfig)
|
||||||
jwt: JWTConfig = Field(default_factory=JWTConfig)
|
jwt: JWTConfig = Field(default_factory=JWTConfig)
|
||||||
crypto: CryptoConfig = Field(default_factory=CryptoConfig)
|
crypto: CryptoConfig = Field(default_factory=CryptoConfig)
|
||||||
|
|
||||||
|
|
||||||
|
CONFIG_FILE = "tmp/config.json"
|
||||||
|
|
||||||
|
|
||||||
|
async def has_config() -> bool:
|
||||||
|
return Path(CONFIG_FILE).is_file()
|
||||||
|
|
||||||
|
|
||||||
|
async def load_config() -> BaseConfig:
|
||||||
|
try:
|
||||||
|
with open(CONFIG_FILE, "r") as kv:
|
||||||
|
return BaseConfig.parse_obj(json.load(kv))
|
||||||
|
|
||||||
|
except FileNotFoundError:
|
||||||
|
return BaseConfig()
|
||||||
|
|
||||||
|
|
||||||
|
async def connect_db(config: BaseConfig = Depends(load_config)) -> None:
|
||||||
|
global SESSION_LOCAL
|
||||||
|
|
||||||
|
engine = create_engine(
|
||||||
|
"sqlite:///./tmp/vpn.db",
|
||||||
|
connect_args={"check_same_thread": False},
|
||||||
|
)
|
||||||
|
SESSION_LOCAL = sessionmaker(
|
||||||
|
autocommit=False, autoflush=False, bind=engine)
|
||||||
|
|
||||||
|
ORMBaseModel.metadata.create_all(bind=engine)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_db() -> Session:
|
||||||
|
global SESSION_LOCAL
|
||||||
|
|
||||||
|
db = SESSION_LOCAL()
|
||||||
|
try:
|
||||||
|
yield db
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
|
@ -1,21 +0,0 @@
|
||||||
from sqlalchemy import create_engine
|
|
||||||
from sqlalchemy.ext.declarative import declarative_base
|
|
||||||
from sqlalchemy.orm import sessionmaker
|
|
||||||
|
|
||||||
SQLALCHEMY_DATABASE_URL = "sqlite:///./tmp/vpn.db"
|
|
||||||
# SQLALCHEMY_DATABASE_URL = "postgresql://user:password@postgresserver/db"
|
|
||||||
|
|
||||||
engine = create_engine(
|
|
||||||
SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
|
|
||||||
)
|
|
||||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
|
||||||
|
|
||||||
ORMBaseModel = declarative_base()
|
|
||||||
|
|
||||||
|
|
||||||
def get_db():
|
|
||||||
db = SessionLocal()
|
|
||||||
try:
|
|
||||||
yield db
|
|
||||||
finally:
|
|
||||||
db.close()
|
|
|
@ -2,9 +2,10 @@ import datetime
|
||||||
|
|
||||||
from sqlalchemy import (Boolean, Column, DateTime, ForeignKey, Integer, String,
|
from sqlalchemy import (Boolean, Column, DateTime, ForeignKey, Integer, String,
|
||||||
UniqueConstraint)
|
UniqueConstraint)
|
||||||
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
from sqlalchemy.orm import relationship
|
from sqlalchemy.orm import relationship
|
||||||
|
|
||||||
from .connection import ORMBaseModel
|
ORMBaseModel = declarative_base()
|
||||||
|
|
||||||
|
|
||||||
class User(ORMBaseModel):
|
class User(ORMBaseModel):
|
||||||
|
|
|
@ -4,11 +4,8 @@ import uvicorn
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
|
|
||||||
from .config import PRODUCTION_MODE
|
from .config import PRODUCTION_MODE
|
||||||
from .db import connection, models
|
|
||||||
from .routers import install
|
from .routers import install
|
||||||
|
|
||||||
models.ORMBaseModel.metadata.create_all(bind=connection.engine)
|
|
||||||
|
|
||||||
api = FastAPI(
|
api = FastAPI(
|
||||||
title="kiwi-vpn API",
|
title="kiwi-vpn API",
|
||||||
description="This API enables the `kiwi-vpn` service.",
|
description="This API enables the `kiwi-vpn` service.",
|
||||||
|
|
|
@ -1,42 +1,15 @@
|
||||||
import json
|
|
||||||
from pathlib import Path
|
|
||||||
from secrets import token_hex
|
from secrets import token_hex
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from ..config import BaseConfig
|
from ..config import (CONFIG_FILE, BaseConfig, connect_db, get_db, has_config,
|
||||||
from ..db import connection, crud, schemas
|
load_config)
|
||||||
|
from ..db import crud, schemas
|
||||||
|
|
||||||
router = APIRouter(prefix="/install")
|
router = APIRouter(prefix="/install")
|
||||||
|
|
||||||
|
|
||||||
CONFIG_FILE = "tmp/config.json"
|
|
||||||
|
|
||||||
|
|
||||||
async def has_config() -> bool:
|
|
||||||
return Path(CONFIG_FILE).is_file()
|
|
||||||
|
|
||||||
|
|
||||||
async def load_config() -> BaseConfig:
|
|
||||||
try:
|
|
||||||
with open(CONFIG_FILE, "r") as kv:
|
|
||||||
return BaseConfig.parse_obj(json.load(kv))
|
|
||||||
|
|
||||||
except FileNotFoundError:
|
|
||||||
return BaseConfig()
|
|
||||||
|
|
||||||
|
|
||||||
# async def connect_db(config: BaseConfig = Depends(load_config)) -> Database:
|
|
||||||
# db = await config.db.database
|
|
||||||
# db.connect()
|
|
||||||
# return db
|
|
||||||
|
|
||||||
|
|
||||||
# async def has_tables(db: Database = Depends(connect_db)) -> bool:
|
|
||||||
# return db.table_exists(User)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/config",
|
"/config",
|
||||||
response_model=BaseConfig,
|
response_model=BaseConfig,
|
||||||
|
@ -79,7 +52,7 @@ async def set_config(
|
||||||
if config.jwt.secret is None:
|
if config.jwt.secret is None:
|
||||||
config.jwt.secret = token_hex(32)
|
config.jwt.secret = token_hex(32)
|
||||||
|
|
||||||
# DB.initialize(await connect_db(config))
|
await connect_db(config)
|
||||||
|
|
||||||
with open(CONFIG_FILE, "w") as kv:
|
with open(CONFIG_FILE, "w") as kv:
|
||||||
kv.write(config.json(indent=2))
|
kv.write(config.json(indent=2))
|
||||||
|
@ -90,9 +63,7 @@ async def set_config(
|
||||||
"model": bool,
|
"model": bool,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
async def check_db(
|
async def check_db():
|
||||||
db: Session = Depends(connection.get_db),
|
|
||||||
):
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
@ -111,8 +82,7 @@ async def check_db(
|
||||||
async def create_db(
|
async def create_db(
|
||||||
admin_name: str,
|
admin_name: str,
|
||||||
admin_password: str,
|
admin_password: str,
|
||||||
config: BaseConfig = Depends(load_config),
|
db: Session = Depends(get_db),
|
||||||
db: Session = Depends(connection.get_db),
|
|
||||||
):
|
):
|
||||||
# if await has_tables(db):
|
# if await has_tables(db):
|
||||||
# raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
|
# raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
|
||||||
|
@ -120,6 +90,7 @@ async def create_db(
|
||||||
# db.create_tables([Certificate, DistinguishedName, User, UserCapability])
|
# db.create_tables([Certificate, DistinguishedName, User, UserCapability])
|
||||||
|
|
||||||
# cryptContext = await config.crypto.cryptContext
|
# cryptContext = await config.crypto.cryptContext
|
||||||
|
|
||||||
crud.create_user(db, schemas.UserCreate(
|
crud.create_user(db, schemas.UserCreate(
|
||||||
name=admin_name,
|
name=admin_name,
|
||||||
password=admin_password,
|
password=admin_password,
|
||||||
|
|
Loading…
Reference in a new issue