dirty commit
This commit is contained in:
parent
641dfd7ba0
commit
e34e669f79
6 changed files with 101 additions and 35 deletions
|
@ -3,11 +3,12 @@ from __future__ import annotations
|
|||
import json
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Generator
|
||||
|
||||
from fastapi import Depends
|
||||
from jose.constants import ALGORITHMS
|
||||
from passlib.context import CryptContext
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
|
@ -50,6 +51,30 @@ class BaseConfig(BaseModel):
|
|||
jwt: JWTConfig = Field(default_factory=JWTConfig)
|
||||
crypto: CryptoConfig = Field(default_factory=CryptoConfig)
|
||||
|
||||
@property
|
||||
def crypt_context(self) -> CryptContext:
|
||||
return CryptContext(
|
||||
schemes=self.crypto.schemes,
|
||||
deprecated="auto",
|
||||
)
|
||||
|
||||
__session_local: sessionmaker = PrivateAttr()
|
||||
|
||||
async def connect_db(self) -> None:
|
||||
engine = create_engine(
|
||||
"sqlite:///./tmp/vpn.db",
|
||||
connect_args={"check_same_thread": False},
|
||||
)
|
||||
self.__session_local = sessionmaker(
|
||||
autocommit=False, autoflush=False, bind=engine,
|
||||
)
|
||||
ORMBaseModel.metadata.create_all(bind=engine)
|
||||
|
||||
@property
|
||||
def database(self) -> Session | None:
|
||||
if self.__session_local is not None:
|
||||
return self.__session_local()
|
||||
|
||||
|
||||
CONFIG_FILE = "tmp/config.json"
|
||||
|
||||
|
@ -67,23 +92,13 @@ async def load_config() -> BaseConfig:
|
|||
return BaseConfig()
|
||||
|
||||
|
||||
async def connect_db(config: BaseConfig = Depends(load_config)) -> None:
|
||||
global SESSION_LOCAL
|
||||
async def get_db(
|
||||
config: BaseConfig = Depends(load_config)
|
||||
) -> Generator[Session | None, None, None]:
|
||||
if db := config.database is None:
|
||||
yield None
|
||||
|
||||
engine = create_engine(
|
||||
"sqlite:///./tmp/vpn.db",
|
||||
connect_args={"check_same_thread": False},
|
||||
)
|
||||
SESSION_LOCAL = sessionmaker(
|
||||
autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
ORMBaseModel.metadata.create_all(bind=engine)
|
||||
|
||||
|
||||
async def get_db() -> Session:
|
||||
global SESSION_LOCAL
|
||||
|
||||
db = SESSION_LOCAL()
|
||||
else:
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
from sqlalchemy.orm import Session
|
||||
from passlib.context import CryptContext
|
||||
|
||||
from . import models, schemas
|
||||
|
||||
|
@ -9,10 +10,14 @@ def get_user(db: Session, name: str):
|
|||
.filter(models.User.name == name).first())
|
||||
|
||||
|
||||
def create_user(db: Session, user: schemas.UserCreate):
|
||||
def create_user(
|
||||
db: Session,
|
||||
user: schemas.UserCreate,
|
||||
crypt_context: CryptContext
|
||||
):
|
||||
db_user = models.User(
|
||||
name=user.name,
|
||||
password=user.password + "notreallyhashed",
|
||||
password=crypt_context.hash(user.password),
|
||||
)
|
||||
db.add(db_user)
|
||||
db.commit()
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
from fastapi import Depends, FastAPI
|
||||
|
||||
from .config import PRODUCTION_MODE
|
||||
from .config import (PRODUCTION_MODE, BaseConfig, has_config,
|
||||
load_config)
|
||||
from .routers import install
|
||||
|
||||
api = FastAPI(
|
||||
|
@ -29,6 +30,13 @@ app = FastAPI()
|
|||
app.mount("/api", api)
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
async def on_startup():
|
||||
if await has_config():
|
||||
config = await load_config()
|
||||
await config.connect_db()
|
||||
|
||||
|
||||
def main():
|
||||
uvicorn.run(
|
||||
"kiwi_vpn_api.main:app",
|
||||
|
|
39
api/kiwi_vpn_api/plan.md
Normal file
39
api/kiwi_vpn_api/plan.md
Normal file
|
@ -0,0 +1,39 @@
|
|||
# Startup
|
||||
|
||||
if config file present:
|
||||
|
||||
- load config file
|
||||
- connect to DB
|
||||
- mount all routers
|
||||
|
||||
else:
|
||||
|
||||
- mount admin router
|
||||
|
||||
# PUT admin/config
|
||||
|
||||
if config file present:
|
||||
|
||||
- if user is admin:
|
||||
- overwrite config
|
||||
- reload config, reconnect to DB
|
||||
|
||||
else:
|
||||
|
||||
- overwrite config
|
||||
- reload config, connect to DB
|
||||
- mount all routers
|
||||
|
||||
# POST admin/user
|
||||
|
||||
if user table is empty:
|
||||
|
||||
- create new user
|
||||
- give "admin" cap to new user
|
||||
|
||||
else:
|
||||
|
||||
- if user is admin:
|
||||
- create new user
|
||||
|
||||
...
|
|
@ -3,7 +3,7 @@ from secrets import token_hex
|
|||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ..config import (CONFIG_FILE, BaseConfig, connect_db, get_db, has_config,
|
||||
from ..config import (CONFIG_FILE, BaseConfig, get_db, has_config,
|
||||
load_config)
|
||||
from ..db import crud, schemas
|
||||
|
||||
|
@ -52,7 +52,7 @@ async def set_config(
|
|||
if config.jwt.secret is None:
|
||||
config.jwt.secret = token_hex(32)
|
||||
|
||||
await connect_db(config)
|
||||
await config.connect_db()
|
||||
|
||||
with open(CONFIG_FILE, "w") as kv:
|
||||
kv.write(config.json(indent=2))
|
||||
|
@ -82,17 +82,20 @@ async def check_db():
|
|||
async def create_db(
|
||||
admin_name: str,
|
||||
admin_password: str,
|
||||
config: BaseConfig = Depends(load_config),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
# if await has_tables(db):
|
||||
# raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
# db.create_tables([Certificate, DistinguishedName, User, UserCapability])
|
||||
|
||||
# cryptContext = await config.crypto.cryptContext
|
||||
|
||||
if db is None:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
crud.create_user(db, schemas.UserCreate(
|
||||
name=admin_name,
|
||||
password=admin_password,
|
||||
crypt_context=config.crypt_context,
|
||||
))
|
||||
crud.add_user_capability(db, user_name=admin_name, capability="admin")
|
||||
|
|
|
@ -29,6 +29,7 @@ class User(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
@router.get("/current", response_model=User)
|
||||
async def get_current_user(token: str = Depends(SCHEME)):
|
||||
credentials_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
|
@ -48,15 +49,10 @@ async def get_current_user(token: str = Depends(SCHEME)):
|
|||
return user
|
||||
|
||||
|
||||
@router.get("/current_user/get", response_model=User)
|
||||
async def get_current_user(current_user: User = Depends(get_current_user)):
|
||||
return current_user
|
||||
|
||||
|
||||
async def is_admin(current_user: User = Depends(get_current_user)):
|
||||
return ("admin" in current_user.capabilities)
|
||||
|
||||
|
||||
@router.get("/current_user/is_admin")
|
||||
@router.get("/current/is_admin")
|
||||
async def current_user_is_admin(is_admin: bool = Depends(is_admin)):
|
||||
return {"is_admin": is_admin}
|
||||
|
|
Loading…
Reference in a new issue