Compare commits

..

4 commits

7 changed files with 45 additions and 62 deletions

View file

@ -29,14 +29,16 @@ class Settings(BaseSettings):
production_mode: bool = False production_mode: bool = False
data_dir: Path = Path("./tmp") data_dir: Path = Path("./tmp")
api_v1_prefix: str = "api/v1"
openapi_url: str = "/openapi.json" openapi_url: str = "/openapi.json"
docs_url: str | None = "/docs" docs_url: str | None = "/docs"
redoc_url: str | None = "/redoc" redoc_url: str | None = "/redoc"
@staticmethod @classmethod
@property
@functools.lru_cache @functools.lru_cache
def get() -> Settings: def _(cls) -> Settings:
return Settings() return cls()
@property @property
def config_file(self) -> Path: def config_file(self) -> Path:
@ -61,7 +63,7 @@ class DBConfig(BaseModel):
user: str | None = None user: str | None = None
password: str | None = None password: str | None = None
host: str | None = None host: str | None = None
database: str | None = Settings.get().data_dir.joinpath("vpn.db") database: str | None = Settings._.data_dir.joinpath("vpn.db")
mysql_driver: str = "pymysql" mysql_driver: str = "pymysql"
mysql_args: list[str] = ["charset=utf8mb4"] mysql_args: list[str] = ["charset=utf8mb4"]
@ -201,7 +203,7 @@ class Config(BaseModel):
return cls.__singleton return cls.__singleton
try: try:
with open(Settings.get().config_file, "r") as config_file: with open(Settings._.config_file, "r") as config_file:
cls.__singleton = Config.parse_obj(json.load(config_file)) cls.__singleton = Config.parse_obj(json.load(config_file))
return cls.__singleton return cls.__singleton
@ -222,5 +224,5 @@ class Config(BaseModel):
Save configuration to config file Save configuration to config file
""" """
with open(Settings.get().config_file, "w") as config_file: with open(Settings._.config_file, "w") as config_file:
config_file.write(self.json(indent=2)) config_file.write(self.json(indent=2))

View file

@ -100,7 +100,7 @@ class User(UserBase, table=True):
db.commit() db.commit()
db.refresh(self) db.refresh(self)
def delete(self) -> bool: def delete(self) -> None:
""" """
Delete this user from the database. Delete this user from the database.
""" """

View file

@ -16,9 +16,6 @@ from .config import Config, Settings
from .db import Connection, User from .db import Connection, User
from .routers import main_router from .routers import main_router
settings = Settings.get()
app = FastAPI( app = FastAPI(
title="kiwi-vpn API", title="kiwi-vpn API",
description="This API enables the `kiwi-vpn` service.", description="This API enables the `kiwi-vpn` service.",
@ -30,12 +27,12 @@ app = FastAPI(
"name": "MIT License", "name": "MIT License",
"url": "https://opensource.org/licenses/mit-license.php", "url": "https://opensource.org/licenses/mit-license.php",
}, },
openapi_url=settings.openapi_url, openapi_url=Settings._.openapi_url,
docs_url=settings.docs_url if not settings.production_mode else None, docs_url=Settings._.docs_url if not Settings._.production_mode else None,
redoc_url=settings.redoc_url if not settings.production_mode else None, redoc_url=Settings._.redoc_url if not Settings._.production_mode else None,
) )
app.include_router(main_router) app.include_router(main_router, prefix=f"/{Settings._.api_v1_prefix}")
@app.on_event("startup") @app.on_event("startup")

View file

@ -1,12 +1,10 @@
from fastapi import APIRouter from fastapi import APIRouter
from . import admin from . import admin, user
# from . import user main_router = APIRouter()
main_router = APIRouter(prefix="/api/v1")
main_router.include_router(admin.router) main_router.include_router(admin.router)
# main_router.include_router(user.router) main_router.include_router(user.router)
__all__ = ["main_router"] __all__ = ["main_router"]

View file

@ -6,10 +6,12 @@ Common dependencies for routers.
from fastapi import Depends, HTTPException, status from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer from fastapi.security import OAuth2PasswordBearer
from ..config import Config from ..config import Config, Settings
from ..db import Capability, User from ..db import Capability, User
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="user/authenticate") oauth2_scheme = OAuth2PasswordBearer(
tokenUrl=f"{Settings._.api_v1_prefix}/user/authenticate"
)
class Responses: class Responses:

View file

@ -78,7 +78,7 @@ async def create_initial_admin(
) )
async def set_config( async def set_config(
config: Config, config: Config,
_: User | None = Depends(get_current_user_if_admin), _: User = Depends(get_current_user_if_admin),
): ):
""" """
PUT ./config: Edit `kiwi-vpn` main config. PUT ./config: Edit `kiwi-vpn` main config.

View file

@ -5,11 +5,9 @@
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordRequestForm from fastapi.security import OAuth2PasswordRequestForm
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy.orm import Session
from ..config import Config from ..config import Config
from ..db import Connection from ..db import Capability, User, UserCreate, UserRead
from ..db.schemata import User, UserCapability, UserCreate
from ._common import Responses, get_current_user, get_current_user_if_admin from ._common import Responses, get_current_user, get_current_user_if_admin
router = APIRouter(prefix="/user", tags=["user"]) router = APIRouter(prefix="/user", tags=["user"])
@ -28,7 +26,6 @@ class Token(BaseModel):
async def login( async def login(
form_data: OAuth2PasswordRequestForm = Depends(), form_data: OAuth2PasswordRequestForm = Depends(),
current_config: Config | None = Depends(Config.load), current_config: Config | None = Depends(Config.load),
db: Session | None = Depends(Connection.get),
): ):
""" """
POST ./authenticate: Authenticate a user. Issues a bearer token. POST ./authenticate: Authenticate a user. Issues a bearer token.
@ -39,12 +36,10 @@ async def login(
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
# try logging in # try logging in
user = User(name=form_data.username) if not (user := User.authenticate(
if not user.authenticate( name=form_data.username,
db=db,
password=form_data.password, password=form_data.password,
crypt_context=current_config.crypto.crypt_context, )):
):
# authentication failed # authentication failed
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
@ -57,7 +52,7 @@ async def login(
return {"access_token": access_token, "token_type": "bearer"} return {"access_token": access_token, "token_type": "bearer"}
@router.get("/current", response_model=User) @router.get("/current", response_model=UserRead)
async def get_current_user( async def get_current_user(
current_user: User | None = Depends(get_current_user), current_user: User | None = Depends(get_current_user),
): ):
@ -81,20 +76,14 @@ async def get_current_user(
) )
async def add_user( async def add_user(
user: UserCreate, user: UserCreate,
current_config: Config | None = Depends(Config.load),
_: User = Depends(get_current_user_if_admin), _: User = Depends(get_current_user_if_admin),
db: Session | None = Depends(Connection.get),
): ):
""" """
POST ./: Create a new user in the database. POST ./: Create a new user in the database.
""" """
# actually create the new user # actually create the new user
new_user = User.create( new_user = User.create(**user.dict())
db=db,
user=user,
crypt_context=current_config.crypto.crypt_context,
)
# fail if creation was unsuccessful # fail if creation was unsuccessful
if new_user is None: if new_user is None:
@ -118,22 +107,21 @@ async def add_user(
async def remove_user( async def remove_user(
user_name: str, user_name: str,
_: User = Depends(get_current_user_if_admin), _: User = Depends(get_current_user_if_admin),
db: Session | None = Depends(Connection.get),
): ):
""" """
DELETE ./{user_name}: Remove a user from the database. DELETE ./{user_name}: Remove a user from the database.
""" """
# get the user # get the user
user = User.from_db( user = User.get(user_name)
db=db,
name=user_name,
)
# fail if deletion was unsuccessful # fail if user not found
if user is None or not user.delete(db): if user is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
# delete user
user.delete()
@router.post( @router.post(
"/{user_name}/capabilities", "/{user_name}/capabilities",
@ -146,22 +134,21 @@ async def remove_user(
) )
async def extend_capabilities( async def extend_capabilities(
user_name: str, user_name: str,
capabilities: list[UserCapability], capabilities: list[Capability],
_: User = Depends(get_current_user_if_admin), _: User = Depends(get_current_user_if_admin),
db: Session | None = Depends(Connection.get),
): ):
""" """
POST ./{user_name}/capabilities: Add capabilities to a user. POST ./{user_name}/capabilities: Add capabilities to a user.
""" """
# get and change the user # get and change the user
user = User.from_db( user = User.get(user_name)
db=db,
name=user_name, user.set_capabilities(
user.get_capabilities() | set(capabilities)
) )
user.capabilities.extend(capabilities) user.update()
user.update(db)
@router.delete( @router.delete(
@ -175,21 +162,18 @@ async def extend_capabilities(
) )
async def remove_capabilities( async def remove_capabilities(
user_name: str, user_name: str,
capabilities: list[UserCapability], capabilities: list[Capability],
_: User = Depends(get_current_user_if_admin), _: User = Depends(get_current_user_if_admin),
db: Session | None = Depends(Connection.get),
): ):
""" """
DELETE ./{user_name}/capabilities: Remove capabilities from a user. DELETE ./{user_name}/capabilities: Remove capabilities from a user.
""" """
# get and change the user # get and change the user
user = User.from_db( user = User.get(user_name)
db=db,
name=user_name, user.set_capabilities(
user.get_capabilities() - set(capabilities)
) )
for capability in capabilities: user.update()
user.capabilities.remove(capability)
user.update(db)