adhere to a plan for a change :)
This commit is contained in:
parent
e34e669f79
commit
746ca51bdd
7 changed files with 148 additions and 233 deletions
|
@ -1,30 +1,14 @@
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import json
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from pathlib import Path
|
import json
|
||||||
from typing import Generator
|
|
||||||
|
|
||||||
from fastapi import Depends
|
|
||||||
from jose.constants import ALGORITHMS
|
from jose.constants import ALGORITHMS
|
||||||
from passlib.context import CryptContext
|
from passlib.context import CryptContext
|
||||||
from pydantic import BaseModel, Field, PrivateAttr
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from sqlalchemy.engine import Engine
|
||||||
from sqlalchemy import create_engine
|
from sqlalchemy import create_engine
|
||||||
from sqlalchemy.orm import Session, sessionmaker
|
|
||||||
|
|
||||||
from .db.models import ORMBaseModel
|
CONFIG_FILE = "tmp/config.json"
|
||||||
|
|
||||||
PRODUCTION_MODE = False
|
|
||||||
|
|
||||||
# to get a string like this run:
|
|
||||||
# openssl rand -hex 32
|
|
||||||
SECRET_KEY = "2f7875b0d2be8a76eba8077ab4d9f8b1c749e02647e9ac9e0f909c3acbfc9856"
|
|
||||||
ALGORITHM = "HS256"
|
|
||||||
ACCESS_TOKEN_EXPIRE_MINUTES = 30
|
|
||||||
|
|
||||||
CRYPT_CONTEXT = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
|
||||||
|
|
||||||
SESSION_LOCAL = None
|
|
||||||
|
|
||||||
|
|
||||||
class DBType(Enum):
|
class DBType(Enum):
|
||||||
|
@ -35,6 +19,13 @@ class DBType(Enum):
|
||||||
class DBConfig(BaseModel):
|
class DBConfig(BaseModel):
|
||||||
db_type: DBType = DBType.sqlite
|
db_type: DBType = DBType.sqlite
|
||||||
|
|
||||||
|
@property
|
||||||
|
def db_engine(self) -> Engine:
|
||||||
|
return create_engine(
|
||||||
|
"sqlite:///./tmp/vpn.db",
|
||||||
|
connect_args={"check_same_thread": False},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class JWTConfig(BaseModel):
|
class JWTConfig(BaseModel):
|
||||||
secret: str | None = None
|
secret: str | None = None
|
||||||
|
@ -45,6 +36,13 @@ class JWTConfig(BaseModel):
|
||||||
class CryptoConfig(BaseModel):
|
class CryptoConfig(BaseModel):
|
||||||
schemes: list[str] = ["bcrypt"]
|
schemes: list[str] = ["bcrypt"]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def crypt_context(self) -> CryptContext:
|
||||||
|
return CryptContext(
|
||||||
|
schemes=self.schemes,
|
||||||
|
deprecated="auto",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class BaseConfig(BaseModel):
|
class BaseConfig(BaseModel):
|
||||||
db: DBConfig = Field(default_factory=DBConfig)
|
db: DBConfig = Field(default_factory=DBConfig)
|
||||||
|
@ -53,53 +51,22 @@ class BaseConfig(BaseModel):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def crypt_context(self) -> CryptContext:
|
def crypt_context(self) -> CryptContext:
|
||||||
return CryptContext(
|
return self.crypto.crypt_context
|
||||||
schemes=self.crypto.schemes,
|
|
||||||
deprecated="auto",
|
|
||||||
)
|
|
||||||
|
|
||||||
__session_local: sessionmaker = PrivateAttr()
|
|
||||||
|
|
||||||
async def connect_db(self) -> None:
|
|
||||||
engine = create_engine(
|
|
||||||
"sqlite:///./tmp/vpn.db",
|
|
||||||
connect_args={"check_same_thread": False},
|
|
||||||
)
|
|
||||||
self.__session_local = sessionmaker(
|
|
||||||
autocommit=False, autoflush=False, bind=engine,
|
|
||||||
)
|
|
||||||
ORMBaseModel.metadata.create_all(bind=engine)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def database(self) -> Session | None:
|
def db_engine(self) -> Engine:
|
||||||
if self.__session_local is not None:
|
return self.db.db_engine
|
||||||
return self.__session_local()
|
|
||||||
|
|
||||||
|
|
||||||
CONFIG_FILE = "tmp/config.json"
|
async def get() -> BaseConfig | None:
|
||||||
|
|
||||||
|
|
||||||
async def has_config() -> bool:
|
|
||||||
return Path(CONFIG_FILE).is_file()
|
|
||||||
|
|
||||||
|
|
||||||
async def load_config() -> BaseConfig:
|
|
||||||
try:
|
try:
|
||||||
with open(CONFIG_FILE, "r") as kv:
|
with open(CONFIG_FILE, "r") as config_file:
|
||||||
return BaseConfig.parse_obj(json.load(kv))
|
return BaseConfig.parse_obj(json.load(config_file))
|
||||||
|
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
return BaseConfig()
|
return None
|
||||||
|
|
||||||
|
|
||||||
async def get_db(
|
def set(config: BaseConfig) -> None:
|
||||||
config: BaseConfig = Depends(load_config)
|
with open(CONFIG_FILE, "w") as config_file:
|
||||||
) -> Generator[Session | None, None, None]:
|
config_file.write(config.json(indent=2))
|
||||||
if db := config.database is None:
|
|
||||||
yield None
|
|
||||||
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
yield db
|
|
||||||
finally:
|
|
||||||
db.close()
|
|
||||||
|
|
30
api/kiwi_vpn_api/db/connection.py
Normal file
30
api/kiwi_vpn_api/db/connection.py
Normal file
|
@ -0,0 +1,30 @@
|
||||||
|
from typing import Generator
|
||||||
|
from sqlalchemy.engine import Engine
|
||||||
|
from sqlalchemy.orm import sessionmaker, Session
|
||||||
|
|
||||||
|
from .models import ORMBaseModel
|
||||||
|
|
||||||
|
ENGINE: Engine | None = None
|
||||||
|
SESSION_LOCAL: sessionmaker | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def reconnect(engine: Engine) -> None:
|
||||||
|
global ENGINE, SESSION_LOCAL
|
||||||
|
|
||||||
|
ENGINE = engine
|
||||||
|
SESSION_LOCAL = sessionmaker(
|
||||||
|
autocommit=False, autoflush=False, bind=engine,
|
||||||
|
)
|
||||||
|
ORMBaseModel.metadata.create_all(bind=engine)
|
||||||
|
|
||||||
|
|
||||||
|
async def get() -> Generator[Session | None, None, None]:
|
||||||
|
if SESSION_LOCAL is None:
|
||||||
|
yield None
|
||||||
|
|
||||||
|
else:
|
||||||
|
db = SESSION_LOCAL()
|
||||||
|
try:
|
||||||
|
yield db
|
||||||
|
finally:
|
||||||
|
db.close()
|
|
@ -1,11 +1,14 @@
|
||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from fastapi import Depends, FastAPI
|
from fastapi import FastAPI
|
||||||
|
|
||||||
|
from . import config
|
||||||
|
from .routers import admin
|
||||||
|
from .db import connection
|
||||||
|
|
||||||
|
PRODUCTION_MODE = False
|
||||||
|
|
||||||
from .config import (PRODUCTION_MODE, BaseConfig, has_config,
|
|
||||||
load_config)
|
|
||||||
from .routers import install
|
|
||||||
|
|
||||||
api = FastAPI(
|
api = FastAPI(
|
||||||
title="kiwi-vpn API",
|
title="kiwi-vpn API",
|
||||||
|
@ -22,19 +25,20 @@ api = FastAPI(
|
||||||
redoc_url="/redoc" if not PRODUCTION_MODE else None,
|
redoc_url="/redoc" if not PRODUCTION_MODE else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
api.include_router(install.router)
|
|
||||||
# api.include_router(auth.router)
|
|
||||||
# api.include_router(user.router)
|
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
app.mount("/api", api)
|
app.mount("/api", api)
|
||||||
|
|
||||||
|
|
||||||
@app.on_event("startup")
|
@app.on_event("startup")
|
||||||
async def on_startup():
|
async def on_startup():
|
||||||
if await has_config():
|
# always include admin router
|
||||||
config = await load_config()
|
api.include_router(admin.router)
|
||||||
await config.connect_db()
|
|
||||||
|
if (current_config := await config.get()) is not None:
|
||||||
|
connection.reconnect(current_config.db_engine)
|
||||||
|
|
||||||
|
# include other routers
|
||||||
|
# api.include_router(auth.router)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
|
|
@ -26,7 +26,11 @@ else:
|
||||||
|
|
||||||
# POST admin/user
|
# POST admin/user
|
||||||
|
|
||||||
if user table is empty:
|
if no config file present:
|
||||||
|
|
||||||
|
- error
|
||||||
|
|
||||||
|
elif user table is empty:
|
||||||
|
|
||||||
- create new user
|
- create new user
|
||||||
- give "admin" cap to new user
|
- give "admin" cap to new user
|
||||||
|
|
69
api/kiwi_vpn_api/routers/admin.py
Normal file
69
api/kiwi_vpn_api/routers/admin.py
Normal file
|
@ -0,0 +1,69 @@
|
||||||
|
from secrets import token_hex
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from .. import config
|
||||||
|
|
||||||
|
from ..db import crud, schemas, connection
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/admin")
|
||||||
|
|
||||||
|
|
||||||
|
@router.put(
|
||||||
|
"/config",
|
||||||
|
responses={
|
||||||
|
status.HTTP_200_OK: {
|
||||||
|
"content": None,
|
||||||
|
},
|
||||||
|
status.HTTP_403_FORBIDDEN: {
|
||||||
|
"description": "Must be admin",
|
||||||
|
"content": None,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
async def set_config(
|
||||||
|
new_config: config.BaseConfig,
|
||||||
|
current_config: config.BaseConfig | None = Depends(config.get),
|
||||||
|
):
|
||||||
|
if current_config is not None:
|
||||||
|
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
|
||||||
|
|
||||||
|
if new_config.jwt.secret is None:
|
||||||
|
new_config.jwt.secret = token_hex(32)
|
||||||
|
|
||||||
|
connection.reconnect(new_config.db_engine)
|
||||||
|
|
||||||
|
config.set(new_config)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/user",
|
||||||
|
responses={
|
||||||
|
status.HTTP_200_OK: {
|
||||||
|
"content": None,
|
||||||
|
},
|
||||||
|
status.HTTP_400_BAD_REQUEST: {
|
||||||
|
"description": "Database doesn't exist",
|
||||||
|
"content": None,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
async def add_user(
|
||||||
|
user_name: str,
|
||||||
|
user_password: str,
|
||||||
|
current_config: config.BaseConfig | None = Depends(config.get),
|
||||||
|
db: Session | None = Depends(connection.get),
|
||||||
|
):
|
||||||
|
if current_config is None:
|
||||||
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
|
||||||
|
|
||||||
|
crud.create_user(
|
||||||
|
db=db,
|
||||||
|
user=schemas.UserCreate(
|
||||||
|
name=user_name,
|
||||||
|
password=user_password,
|
||||||
|
),
|
||||||
|
crypt_context=current_config.crypt_context,
|
||||||
|
)
|
||||||
|
crud.add_user_capability(db, user_name=user_name, capability="admin")
|
|
@ -1,101 +0,0 @@
|
||||||
from secrets import token_hex
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
|
|
||||||
from ..config import (CONFIG_FILE, BaseConfig, get_db, has_config,
|
|
||||||
load_config)
|
|
||||||
from ..db import crud, schemas
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/install")
|
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
|
||||||
"/config",
|
|
||||||
response_model=BaseConfig,
|
|
||||||
responses={
|
|
||||||
status.HTTP_403_FORBIDDEN: {
|
|
||||||
"description": "Must be admin",
|
|
||||||
"content": None,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
async def get_config(
|
|
||||||
config: BaseConfig = Depends(load_config),
|
|
||||||
has_config: bool = Depends(has_config),
|
|
||||||
):
|
|
||||||
if has_config:
|
|
||||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
|
|
||||||
|
|
||||||
return config
|
|
||||||
|
|
||||||
|
|
||||||
@router.put(
|
|
||||||
"/config",
|
|
||||||
responses={
|
|
||||||
status.HTTP_200_OK: {
|
|
||||||
"content": None,
|
|
||||||
},
|
|
||||||
status.HTTP_403_FORBIDDEN: {
|
|
||||||
"description": "Must be admin",
|
|
||||||
"content": None,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
async def set_config(
|
|
||||||
config: BaseConfig,
|
|
||||||
has_config: bool = Depends(has_config),
|
|
||||||
):
|
|
||||||
if has_config:
|
|
||||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
|
|
||||||
|
|
||||||
if config.jwt.secret is None:
|
|
||||||
config.jwt.secret = token_hex(32)
|
|
||||||
|
|
||||||
await config.connect_db()
|
|
||||||
|
|
||||||
with open(CONFIG_FILE, "w") as kv:
|
|
||||||
kv.write(config.json(indent=2))
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/db", responses={
|
|
||||||
status.HTTP_200_OK: {
|
|
||||||
"model": bool,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
async def check_db():
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
@router.put(
|
|
||||||
"/db",
|
|
||||||
responses={
|
|
||||||
status.HTTP_200_OK: {
|
|
||||||
"content": None,
|
|
||||||
},
|
|
||||||
status.HTTP_400_BAD_REQUEST: {
|
|
||||||
"description": "Database exists",
|
|
||||||
"content": None,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
async def create_db(
|
|
||||||
admin_name: str,
|
|
||||||
admin_password: str,
|
|
||||||
config: BaseConfig = Depends(load_config),
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
):
|
|
||||||
# if await has_tables(db):
|
|
||||||
# raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
|
|
||||||
|
|
||||||
# cryptContext = await config.crypto.cryptContext
|
|
||||||
|
|
||||||
if db is None:
|
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
|
|
||||||
|
|
||||||
crud.create_user(db, schemas.UserCreate(
|
|
||||||
name=admin_name,
|
|
||||||
password=admin_password,
|
|
||||||
crypt_context=config.crypt_context,
|
|
||||||
))
|
|
||||||
crud.add_user_capability(db, user_name=admin_name, capability="admin")
|
|
|
@ -1,58 +0,0 @@
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
|
||||||
from jose import JWTError, jwt
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from ..config import ALGORITHM, SECRET_KEY
|
|
||||||
from ..db import User as db_User
|
|
||||||
from .auth import SCHEME
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/user")
|
|
||||||
|
|
||||||
|
|
||||||
class User(BaseModel):
|
|
||||||
name: str
|
|
||||||
capabilities: list[str]
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_db(cls, username: str) -> User | None:
|
|
||||||
user = db_User.get_by_name(username)
|
|
||||||
|
|
||||||
if not user:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return cls(
|
|
||||||
name=user.name,
|
|
||||||
capabilities=[cap.capability for cap in user.capabilities],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/current", response_model=User)
|
|
||||||
async def get_current_user(token: str = Depends(SCHEME)):
|
|
||||||
credentials_exception = HTTPException(
|
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
||||||
detail="Could not validate credentials",
|
|
||||||
headers={"WWW-Authenticate": "Bearer"},
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
|
||||||
username: str = payload.get("sub")
|
|
||||||
if username is None:
|
|
||||||
raise credentials_exception
|
|
||||||
except JWTError:
|
|
||||||
raise credentials_exception
|
|
||||||
user = User.from_db(username)
|
|
||||||
if user is None:
|
|
||||||
raise credentials_exception
|
|
||||||
return user
|
|
||||||
|
|
||||||
|
|
||||||
async def is_admin(current_user: User = Depends(get_current_user)):
|
|
||||||
return ("admin" in current_user.capabilities)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/current/is_admin")
|
|
||||||
async def current_user_is_admin(is_admin: bool = Depends(is_admin)):
|
|
||||||
return {"is_admin": is_admin}
|
|
Loading…
Reference in a new issue