Compare commits

...

5 commits

6 changed files with 97 additions and 12 deletions

View file

@ -85,6 +85,15 @@ class Device(DeviceBase, table=True):
# device already existed # device already existed
return None return None
@classmethod
def get(cls, id: int) -> Device | None:
"""
Load device from database by id.
"""
with Connection.session as db:
return db.get(cls, id)
def update(self) -> None: def update(self) -> None:
""" """
Update this device in the database. Update this device in the database.

View file

@ -184,6 +184,7 @@ class User(UserBase, table=True):
return ( return (
capability in self.get_capabilities() capability in self.get_capabilities()
# admin can do everything
or UserCapabilityType.admin in self.get_capabilities() or UserCapabilityType.admin in self.get_capabilities()
) )
@ -201,3 +202,17 @@ class User(UserBase, table=True):
capability_name=capability.value, capability_name=capability.value,
) for capability in capabilities ) for capability in capabilities
] ]
def owns(
self,
device: Device,
) -> bool:
"""
Check if this user owns a device.
"""
return (
device.owner_name == self.name
# admin owns everything
or self.can(UserCapabilityType.admin)
)

View file

@ -7,7 +7,7 @@ from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer from fastapi.security import OAuth2PasswordBearer
from ..config import Config, Settings from ..config import Config, Settings
from ..db import User, UserCapabilityType from ..db import Device, User, UserCapabilityType
oauth2_scheme = OAuth2PasswordBearer( oauth2_scheme = OAuth2PasswordBearer(
tokenUrl=f"{Settings._.api_v1_prefix}/user/authenticate" tokenUrl=f"{Settings._.api_v1_prefix}/user/authenticate"
@ -52,6 +52,10 @@ class Responses:
"description": "Entry does not exist in database", "description": "Entry does not exist in database",
"content": None, "content": None,
} }
CANT_TARGET_SELF = {
"description": "Operation can't target yourself",
"content": None,
}
async def get_current_user( async def get_current_user(
@ -80,12 +84,12 @@ async def get_current_user_if_exists(
# fail if not requested by a user # fail if not requested by a user
if current_user is None: if current_user is None:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
return current_user return current_user
async def current_user_is_admin( async def get_current_user_if_admin(
current_user: User = Depends(get_current_user_if_exists), current_user: User = Depends(get_current_user_if_exists),
) -> User: ) -> User:
""" """
@ -95,6 +99,8 @@ async def current_user_is_admin(
if not current_user.can(UserCapabilityType.admin): if not current_user.can(UserCapabilityType.admin):
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
return current_user
async def get_user_by_name( async def get_user_by_name(
user_name: str, user_name: str,
@ -122,3 +128,31 @@ async def get_user_by_name(
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
return user return user
async def get_device_by_id(
device_id: int,
current_config: Config | None = Depends(Config.load),
) -> Device | None:
# can't connect to an unconfigured database
if current_config is None:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
return Device.get(device_id)
async def get_device_by_id_if_editable(
device: Device | None = Depends(get_device_by_id),
current_user: User = Depends(get_current_user_if_exists),
) -> Device:
# fail if device doesn't exist
if device is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
# fail if device is not owned by current user
if not current_user.owns(device):
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
return device

View file

@ -8,7 +8,7 @@ from sqlmodel import select
from ..config import Config from ..config import Config
from ..db import Connection, User, UserCapabilityType, UserCreate from ..db import Connection, User, UserCapabilityType, UserCreate
from ._common import Responses, current_user_is_admin from ._common import Responses, get_current_user_if_admin
router = APIRouter(prefix="/admin", tags=["admin"]) router = APIRouter(prefix="/admin", tags=["admin"])
@ -63,7 +63,7 @@ async def create_initial_admin(
raise HTTPException(status_code=status.HTTP_409_CONFLICT) raise HTTPException(status_code=status.HTTP_409_CONFLICT)
# create an administrative user # create an administrative user
new_user = User.create(admin_user) new_user = User.create(user=admin_user)
new_user.set_capabilities([UserCapabilityType.admin]) new_user.set_capabilities([UserCapabilityType.admin])
new_user.update() new_user.update()
@ -79,7 +79,7 @@ async def create_initial_admin(
) )
async def set_config( async def set_config(
config: Config, config: Config,
_: User = Depends(current_user_is_admin), _: User = Depends(get_current_user_if_admin),
): ):
""" """
PUT ./config: Edit `kiwi-vpn` main config. PUT ./config: Edit `kiwi-vpn` main config.

View file

@ -5,7 +5,7 @@
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException, status
from ..db import Device, DeviceCreate, DeviceRead, User from ..db import Device, DeviceCreate, DeviceRead, User
from ._common import Responses, get_user_by_name from ._common import Responses, get_device_by_id_if_editable, get_user_by_name
router = APIRouter(prefix="/device", tags=["device"]) router = APIRouter(prefix="/device", tags=["device"])
@ -42,3 +42,25 @@ async def add_device(
# return the created device on success # return the created device on success
return new_device return new_device
@router.delete(
"/{device_id}",
responses={
status.HTTP_200_OK: Responses.OK,
status.HTTP_400_BAD_REQUEST: Responses.NOT_INSTALLED,
status.HTTP_401_UNAUTHORIZED: Responses.NEEDS_USER,
status.HTTP_403_FORBIDDEN: Responses.NEEDS_ADMIN,
status.HTTP_404_NOT_FOUND: Responses.ENTRY_DOESNT_EXIST,
},
response_model=User,
)
async def remove_device(
device: Device = Depends(get_device_by_id_if_editable),
):
"""
DELETE ./{device_id}: Remove a device from the database.
"""
# delete device
device.delete()

View file

@ -8,7 +8,7 @@ from pydantic import BaseModel
from ..config import Config from ..config import Config
from ..db import User, UserCapabilityType, UserCreate, UserRead from ..db import User, UserCapabilityType, UserCreate, UserRead
from ._common import (Responses, current_user_is_admin, from ._common import (Responses, get_current_user_if_admin,
get_current_user_if_exists, get_user_by_name) get_current_user_if_exists, get_user_by_name)
router = APIRouter(prefix="/user", tags=["user"]) router = APIRouter(prefix="/user", tags=["user"])
@ -81,7 +81,7 @@ async def get_current_user(
) )
async def add_user( async def add_user(
user: UserCreate, user: UserCreate,
_: User = Depends(current_user_is_admin), _: User = Depends(get_current_user_if_admin),
) -> User: ) -> User:
""" """
POST ./: Create a new user in the database. POST ./: Create a new user in the database.
@ -109,17 +109,22 @@ async def add_user(
status.HTTP_401_UNAUTHORIZED: Responses.NEEDS_USER, status.HTTP_401_UNAUTHORIZED: Responses.NEEDS_USER,
status.HTTP_403_FORBIDDEN: Responses.NEEDS_ADMIN, status.HTTP_403_FORBIDDEN: Responses.NEEDS_ADMIN,
status.HTTP_404_NOT_FOUND: Responses.ENTRY_DOESNT_EXIST, status.HTTP_404_NOT_FOUND: Responses.ENTRY_DOESNT_EXIST,
status.HTTP_406_NOT_ACCEPTABLE: Responses.CANT_TARGET_SELF,
}, },
response_model=User, response_model=User,
) )
async def remove_user( async def remove_user(
_: User = Depends(current_user_is_admin), current_user: User = Depends(get_current_user_if_admin),
user: User = Depends(get_user_by_name), user: User = Depends(get_user_by_name),
): ):
""" """
DELETE ./{user_name}: Remove a user from the database. DELETE ./{user_name}: Remove a user from the database.
""" """
# stop inting
if current_user.name == user.name:
raise HTTPException(status_code=status.HTTP_406_NOT_ACCEPTABLE)
# delete user # delete user
user.delete() user.delete()
@ -135,7 +140,7 @@ async def remove_user(
) )
async def extend_capabilities( async def extend_capabilities(
capabilities: list[UserCapabilityType], capabilities: list[UserCapabilityType],
_: User = Depends(current_user_is_admin), _: User = Depends(get_current_user_if_admin),
user: User = Depends(get_user_by_name), user: User = Depends(get_user_by_name),
): ):
""" """
@ -158,7 +163,7 @@ async def extend_capabilities(
) )
async def remove_capabilities( async def remove_capabilities(
capabilities: list[UserCapabilityType], capabilities: list[UserCapabilityType],
_: User = Depends(current_user_is_admin), _: User = Depends(get_current_user_if_admin),
user: User = Depends(get_user_by_name), user: User = Depends(get_user_by_name),
): ):
""" """