diff --git a/api/kiwi_vpn_api/routers/_deps.py b/api/kiwi_vpn_api/routers/_deps.py new file mode 100644 index 0000000..b82cc0e --- /dev/null +++ b/api/kiwi_vpn_api/routers/_deps.py @@ -0,0 +1,22 @@ +from fastapi import Depends +from fastapi.security import OAuth2PasswordBearer +from sqlalchemy.orm import Session + +from ..config import Config +from ..db import Connection, schemas + +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="user/auth") + + +async def get_current_user( + token: str = Depends(oauth2_scheme), + db: Session | None = Depends(Connection.get), + current_config: Config | None = Depends(Config.load), +): + if current_config is None: + return None + + username = await current_config.jwt.decode_token(token) + user = schemas.User.from_db(db, username) + + return user diff --git a/api/kiwi_vpn_api/routers/admin.py b/api/kiwi_vpn_api/routers/admin.py index 7f98bc9..5d47246 100644 --- a/api/kiwi_vpn_api/routers/admin.py +++ b/api/kiwi_vpn_api/routers/admin.py @@ -5,6 +5,7 @@ from sqlalchemy.orm import Session from ..config import Config from ..db import Connection, schemas +from . import _deps router = APIRouter(prefix="/admin") @@ -24,9 +25,14 @@ router = APIRouter(prefix="/admin") async def set_config( new_config: Config, current_config: Config | None = Depends(Config.load), + current_user: schemas.User | None = Depends(_deps.get_current_user), ): + print(current_config, current_user) + if current_config is not None: - raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) + # server is configured, needs authorization + if current_user is None or "admin" not in current_user.capabilities: + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) if new_config.jwt.secret is None: new_config.jwt.secret = token_hex(32) @@ -42,7 +48,7 @@ async def set_config( "content": None, }, status.HTTP_400_BAD_REQUEST: { - "description": "Database doesn't exist", + "description": "Server is not configured", "content": None, }, }, diff --git a/api/kiwi_vpn_api/routers/user.py b/api/kiwi_vpn_api/routers/user.py index 3acd3dc..ce3397d 100644 --- a/api/kiwi_vpn_api/routers/user.py +++ b/api/kiwi_vpn_api/routers/user.py @@ -1,16 +1,13 @@ - from fastapi import APIRouter, Depends, HTTPException, status -from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm +from fastapi.security import OAuth2PasswordRequestForm from pydantic import BaseModel from sqlalchemy.orm import Session from ..config import Config from ..db import Connection, schemas +from . import _deps router = APIRouter(prefix="/user") -SCHEME = OAuth2PasswordBearer( - tokenUrl=f".{router.prefix}/auth", -) class Token(BaseModel): @@ -21,14 +18,17 @@ class Token(BaseModel): @router.post("/auth", response_model=Token) async def login( form_data: OAuth2PasswordRequestForm = Depends(), - config: Config = Depends(Config.load), - db: Session = Depends(Connection.get), + current_config: Config | None = Depends(Config.load), + db: Session | None = Depends(Connection.get), ): + if current_config is None: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) + user = schemas.User.login( db=db, name=form_data.username, password=form_data.password, - crypt_context=await config.crypto.crypt_context, + crypt_context=await current_config.crypto.crypt_context, ) if user is None: @@ -38,30 +38,12 @@ async def login( headers={"WWW-Authenticate": "Bearer"}, ) - access_token = await config.jwt.create_token(user.name) + access_token = await current_config.jwt.create_token(user.name) return {"access_token": access_token, "token_type": "bearer"} -async def dep_get_current_user( - token: str = Depends(SCHEME), - db: Session = Depends(Connection.get), - config: Config = Depends(Config.load), -): - username = await config.jwt.decode_token(token) - user = schemas.User.from_db(db, username) - - return user - - @router.get("/current", response_model=schemas.User) async def get_current_user( - current_user: schemas.User = Depends(dep_get_current_user), + current_user: schemas.User | None = Depends(_deps.get_current_user), ): - if current_user is None: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Could not validate credentials", - headers={"WWW-Authenticate": "Bearer"}, - ) - return current_user