dirty commit
This commit is contained in:
parent
641dfd7ba0
commit
e34e669f79
6 changed files with 101 additions and 35 deletions
|
@ -3,11 +3,12 @@ from __future__ import annotations
|
||||||
import json
|
import json
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Generator
|
||||||
|
|
||||||
from fastapi import Depends
|
from fastapi import Depends
|
||||||
from jose.constants import ALGORITHMS
|
from jose.constants import ALGORITHMS
|
||||||
from passlib.context import CryptContext
|
from passlib.context import CryptContext
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field, PrivateAttr
|
||||||
from sqlalchemy import create_engine
|
from sqlalchemy import create_engine
|
||||||
from sqlalchemy.orm import Session, sessionmaker
|
from sqlalchemy.orm import Session, sessionmaker
|
||||||
|
|
||||||
|
@ -50,6 +51,30 @@ class BaseConfig(BaseModel):
|
||||||
jwt: JWTConfig = Field(default_factory=JWTConfig)
|
jwt: JWTConfig = Field(default_factory=JWTConfig)
|
||||||
crypto: CryptoConfig = Field(default_factory=CryptoConfig)
|
crypto: CryptoConfig = Field(default_factory=CryptoConfig)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def crypt_context(self) -> CryptContext:
|
||||||
|
return CryptContext(
|
||||||
|
schemes=self.crypto.schemes,
|
||||||
|
deprecated="auto",
|
||||||
|
)
|
||||||
|
|
||||||
|
__session_local: sessionmaker = PrivateAttr()
|
||||||
|
|
||||||
|
async def connect_db(self) -> None:
|
||||||
|
engine = create_engine(
|
||||||
|
"sqlite:///./tmp/vpn.db",
|
||||||
|
connect_args={"check_same_thread": False},
|
||||||
|
)
|
||||||
|
self.__session_local = sessionmaker(
|
||||||
|
autocommit=False, autoflush=False, bind=engine,
|
||||||
|
)
|
||||||
|
ORMBaseModel.metadata.create_all(bind=engine)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def database(self) -> Session | None:
|
||||||
|
if self.__session_local is not None:
|
||||||
|
return self.__session_local()
|
||||||
|
|
||||||
|
|
||||||
CONFIG_FILE = "tmp/config.json"
|
CONFIG_FILE = "tmp/config.json"
|
||||||
|
|
||||||
|
@ -67,24 +92,14 @@ async def load_config() -> BaseConfig:
|
||||||
return BaseConfig()
|
return BaseConfig()
|
||||||
|
|
||||||
|
|
||||||
async def connect_db(config: BaseConfig = Depends(load_config)) -> None:
|
async def get_db(
|
||||||
global SESSION_LOCAL
|
config: BaseConfig = Depends(load_config)
|
||||||
|
) -> Generator[Session | None, None, None]:
|
||||||
|
if db := config.database is None:
|
||||||
|
yield None
|
||||||
|
|
||||||
engine = create_engine(
|
else:
|
||||||
"sqlite:///./tmp/vpn.db",
|
try:
|
||||||
connect_args={"check_same_thread": False},
|
yield db
|
||||||
)
|
finally:
|
||||||
SESSION_LOCAL = sessionmaker(
|
db.close()
|
||||||
autocommit=False, autoflush=False, bind=engine)
|
|
||||||
|
|
||||||
ORMBaseModel.metadata.create_all(bind=engine)
|
|
||||||
|
|
||||||
|
|
||||||
async def get_db() -> Session:
|
|
||||||
global SESSION_LOCAL
|
|
||||||
|
|
||||||
db = SESSION_LOCAL()
|
|
||||||
try:
|
|
||||||
yield db
|
|
||||||
finally:
|
|
||||||
db.close()
|
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
from passlib.context import CryptContext
|
||||||
|
|
||||||
from . import models, schemas
|
from . import models, schemas
|
||||||
|
|
||||||
|
@ -9,10 +10,14 @@ def get_user(db: Session, name: str):
|
||||||
.filter(models.User.name == name).first())
|
.filter(models.User.name == name).first())
|
||||||
|
|
||||||
|
|
||||||
def create_user(db: Session, user: schemas.UserCreate):
|
def create_user(
|
||||||
|
db: Session,
|
||||||
|
user: schemas.UserCreate,
|
||||||
|
crypt_context: CryptContext
|
||||||
|
):
|
||||||
db_user = models.User(
|
db_user = models.User(
|
||||||
name=user.name,
|
name=user.name,
|
||||||
password=user.password + "notreallyhashed",
|
password=crypt_context.hash(user.password),
|
||||||
)
|
)
|
||||||
db.add(db_user)
|
db.add(db_user)
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
|
@ -1,9 +1,10 @@
|
||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from fastapi import FastAPI
|
from fastapi import Depends, FastAPI
|
||||||
|
|
||||||
from .config import PRODUCTION_MODE
|
from .config import (PRODUCTION_MODE, BaseConfig, has_config,
|
||||||
|
load_config)
|
||||||
from .routers import install
|
from .routers import install
|
||||||
|
|
||||||
api = FastAPI(
|
api = FastAPI(
|
||||||
|
@ -29,6 +30,13 @@ app = FastAPI()
|
||||||
app.mount("/api", api)
|
app.mount("/api", api)
|
||||||
|
|
||||||
|
|
||||||
|
@app.on_event("startup")
|
||||||
|
async def on_startup():
|
||||||
|
if await has_config():
|
||||||
|
config = await load_config()
|
||||||
|
await config.connect_db()
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
uvicorn.run(
|
uvicorn.run(
|
||||||
"kiwi_vpn_api.main:app",
|
"kiwi_vpn_api.main:app",
|
||||||
|
|
39
api/kiwi_vpn_api/plan.md
Normal file
39
api/kiwi_vpn_api/plan.md
Normal file
|
@ -0,0 +1,39 @@
|
||||||
|
# Startup
|
||||||
|
|
||||||
|
if config file present:
|
||||||
|
|
||||||
|
- load config file
|
||||||
|
- connect to DB
|
||||||
|
- mount all routers
|
||||||
|
|
||||||
|
else:
|
||||||
|
|
||||||
|
- mount admin router
|
||||||
|
|
||||||
|
# PUT admin/config
|
||||||
|
|
||||||
|
if config file present:
|
||||||
|
|
||||||
|
- if user is admin:
|
||||||
|
- overwrite config
|
||||||
|
- reload config, reconnect to DB
|
||||||
|
|
||||||
|
else:
|
||||||
|
|
||||||
|
- overwrite config
|
||||||
|
- reload config, connect to DB
|
||||||
|
- mount all routers
|
||||||
|
|
||||||
|
# POST admin/user
|
||||||
|
|
||||||
|
if user table is empty:
|
||||||
|
|
||||||
|
- create new user
|
||||||
|
- give "admin" cap to new user
|
||||||
|
|
||||||
|
else:
|
||||||
|
|
||||||
|
- if user is admin:
|
||||||
|
- create new user
|
||||||
|
|
||||||
|
...
|
|
@ -3,7 +3,7 @@ from secrets import token_hex
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from ..config import (CONFIG_FILE, BaseConfig, connect_db, get_db, has_config,
|
from ..config import (CONFIG_FILE, BaseConfig, get_db, has_config,
|
||||||
load_config)
|
load_config)
|
||||||
from ..db import crud, schemas
|
from ..db import crud, schemas
|
||||||
|
|
||||||
|
@ -52,7 +52,7 @@ async def set_config(
|
||||||
if config.jwt.secret is None:
|
if config.jwt.secret is None:
|
||||||
config.jwt.secret = token_hex(32)
|
config.jwt.secret = token_hex(32)
|
||||||
|
|
||||||
await connect_db(config)
|
await config.connect_db()
|
||||||
|
|
||||||
with open(CONFIG_FILE, "w") as kv:
|
with open(CONFIG_FILE, "w") as kv:
|
||||||
kv.write(config.json(indent=2))
|
kv.write(config.json(indent=2))
|
||||||
|
@ -82,17 +82,20 @@ async def check_db():
|
||||||
async def create_db(
|
async def create_db(
|
||||||
admin_name: str,
|
admin_name: str,
|
||||||
admin_password: str,
|
admin_password: str,
|
||||||
|
config: BaseConfig = Depends(load_config),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
):
|
):
|
||||||
# if await has_tables(db):
|
# if await has_tables(db):
|
||||||
# raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
|
# raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
|
||||||
|
|
||||||
# db.create_tables([Certificate, DistinguishedName, User, UserCapability])
|
|
||||||
|
|
||||||
# cryptContext = await config.crypto.cryptContext
|
# cryptContext = await config.crypto.cryptContext
|
||||||
|
|
||||||
|
if db is None:
|
||||||
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
|
||||||
|
|
||||||
crud.create_user(db, schemas.UserCreate(
|
crud.create_user(db, schemas.UserCreate(
|
||||||
name=admin_name,
|
name=admin_name,
|
||||||
password=admin_password,
|
password=admin_password,
|
||||||
|
crypt_context=config.crypt_context,
|
||||||
))
|
))
|
||||||
crud.add_user_capability(db, user_name=admin_name, capability="admin")
|
crud.add_user_capability(db, user_name=admin_name, capability="admin")
|
||||||
|
|
|
@ -29,6 +29,7 @@ class User(BaseModel):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/current", response_model=User)
|
||||||
async def get_current_user(token: str = Depends(SCHEME)):
|
async def get_current_user(token: str = Depends(SCHEME)):
|
||||||
credentials_exception = HTTPException(
|
credentials_exception = HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
@ -48,15 +49,10 @@ async def get_current_user(token: str = Depends(SCHEME)):
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
@router.get("/current_user/get", response_model=User)
|
|
||||||
async def get_current_user(current_user: User = Depends(get_current_user)):
|
|
||||||
return current_user
|
|
||||||
|
|
||||||
|
|
||||||
async def is_admin(current_user: User = Depends(get_current_user)):
|
async def is_admin(current_user: User = Depends(get_current_user)):
|
||||||
return ("admin" in current_user.capabilities)
|
return ("admin" in current_user.capabilities)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/current_user/is_admin")
|
@router.get("/current/is_admin")
|
||||||
async def current_user_is_admin(is_admin: bool = Depends(is_admin)):
|
async def current_user_is_admin(is_admin: bool = Depends(is_admin)):
|
||||||
return {"is_admin": is_admin}
|
return {"is_admin": is_admin}
|
||||||
|
|
Loading…
Reference in a new issue