Compare commits

...

4 commits

7 changed files with 64 additions and 19 deletions

View file

@ -143,8 +143,7 @@ class User(UserBase, table=True):
# password hash mismatch # password hash mismatch
return None return None
if not (user.has_tag(TagValue.login) if not (user.has_tag(TagValue.login) or user.is_admin):
or user.has_tag(TagValue.admin)):
# no login permission # no login permission
return None return None
@ -169,7 +168,8 @@ class User(UserBase, table=True):
db.delete(self) db.delete(self)
db.commit() db.commit()
def _get_tags(self) -> Iterable[TagValue]: @property
def __tags(self) -> Iterable[TagValue]:
""" """
Return the tags of this user. Return the tags of this user.
""" """
@ -184,7 +184,15 @@ class User(UserBase, table=True):
Check if this user has a tag. Check if this user has a tag.
""" """
return tag in self._get_tags() return tag in self.__tags
@property
def is_admin(self) -> bool:
"""
Shorthand for checking if this user has the `admin` tag.
"""
return TagValue.admin in self.__tags
def add_tags( def add_tags(
self, self,
@ -196,7 +204,7 @@ class User(UserBase, table=True):
self.tags = [ self.tags = [
tag._(self) tag._(self)
for tag in (set(self._get_tags()) | set(tags)) for tag in (set(self.__tags) | set(tags))
] ]
def remove_tags( def remove_tags(
@ -209,7 +217,7 @@ class User(UserBase, table=True):
self.tags = [ self.tags = [
tag._(self) tag._(self)
for tag in (set(self._get_tags()) - set(tags)) for tag in (set(self.__tags) - set(tags))
] ]
def can_edit( def can_edit(
@ -221,7 +229,7 @@ class User(UserBase, table=True):
""" """
# admin can "edit" everything # admin can "edit" everything
if self.has_tag(TagValue.admin): if self.is_admin:
return True return True
# user can "edit" itself # user can "edit" itself
@ -240,7 +248,7 @@ class User(UserBase, table=True):
""" """
# only admin can "admin" anything # only admin can "admin" anything
if not self.has_tag(TagValue.admin): if not self.is_admin:
return False return False
# admin canot "admin itself"! # admin canot "admin itself"!
@ -264,7 +272,7 @@ class User(UserBase, table=True):
return False return False
# admin can "create" everything # admin can "create" everything
if self.has_tag(TagValue.admin): if self.is_admin:
return True return True
# user can only create devices for itself # user can only create devices for itself

View file

@ -6,6 +6,7 @@ from __future__ import annotations
import subprocess import subprocess
from datetime import datetime from datetime import datetime
from enum import Enum, auto
from pathlib import Path from pathlib import Path
from OpenSSL import crypto from OpenSSL import crypto
@ -97,6 +98,19 @@ class DistinguishedName(BaseModel):
] ]
class CertificateType(Enum):
"""
Possible types of certificates
"""
ca = auto()
client = auto()
server = auto()
def __str__(self) -> str:
return self._name_
class EasyRSA: class EasyRSA:
""" """
Represents an EasyRSA PKI. Represents an EasyRSA PKI.
@ -225,6 +239,7 @@ class EasyRSA:
Path("ca.crt"), Path("ca.crt"),
Config._.crypto.ca_expiry_days, Config._.crypto.ca_expiry_days,
"--dn-mode=cn_only",
"--req-cn=kiwi-vpn-ca", "--req-cn=kiwi-vpn-ca",
"build-ca", "build-ca",
@ -236,13 +251,17 @@ class EasyRSA:
def issue( def issue(
self, self,
cert_type: str = "client", cert_type: CertificateType = CertificateType.client,
dn: DistinguishedName = DistinguishedName.build(), dn: DistinguishedName = DistinguishedName.build(),
) -> crypto.X509: ) -> crypto.X509 | None:
""" """
Issue a client or server certificate Issue a client or server certificate
""" """
if not (cert_type is CertificateType.client
or cert_type is CertificateType.server):
return None
return self.__build_cert( return self.__build_cert(
Path(f"issued/{dn.common_name}.crt"), Path(f"issued/{dn.common_name}.crt"),
Config._.crypto.cert_expiry_days, Config._.crypto.cert_expiry_days,
@ -262,7 +281,7 @@ if __name__ == "__main__":
easy_rsa.init_pki() easy_rsa.init_pki()
ca = easy_rsa.build_ca() ca = easy_rsa.build_ca()
server = easy_rsa.issue("server") server = easy_rsa.issue(CertificateType.server)
client = None client = None
# check if configured # check if configured
@ -275,7 +294,7 @@ if __name__ == "__main__":
db.add(device) db.add(device)
dn = DistinguishedName.build(device) dn = DistinguishedName.build(device)
client = easy_rsa.issue("client", dn) client = easy_rsa.issue(dn=dn)
date_format, encoding = "%Y%m%d%H%M%SZ", "ascii" date_format, encoding = "%Y%m%d%H%M%SZ", "ascii"

View file

@ -39,6 +39,10 @@ class Responses:
"description": "Operation not permitted", "description": "Operation not permitted",
"content": None, "content": None,
} }
ENTRY_ADDED = {
"description": "Entry added to database",
"content": None,
}
ENTRY_EXISTS = { ENTRY_EXISTS = {
"description": "Entry exists in database", "description": "Entry exists in database",
"content": None, "content": None,

View file

@ -83,7 +83,7 @@ async def set_config(
""" """
# check permissions # check permissions
if not current_user.has_tag(TagValue.admin): if not current_user.is_admin:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
# update config file, reconnect to database # update config file, reconnect to database

View file

@ -14,7 +14,7 @@ router = APIRouter(prefix="/device", tags=["device"])
@router.post( @router.post(
"/{user_name}", "/{user_name}",
responses={ responses={
status.HTTP_200_OK: Responses.OK, status.HTTP_201_CREATED: Responses.ENTRY_ADDED,
status.HTTP_400_BAD_REQUEST: Responses.NOT_INSTALLED, status.HTTP_400_BAD_REQUEST: Responses.NOT_INSTALLED,
status.HTTP_401_UNAUTHORIZED: Responses.NEEDS_USER, status.HTTP_401_UNAUTHORIZED: Responses.NEEDS_USER,
status.HTTP_403_FORBIDDEN: Responses.NEEDS_PERMISSION, status.HTTP_403_FORBIDDEN: Responses.NEEDS_PERMISSION,
@ -22,6 +22,7 @@ router = APIRouter(prefix="/device", tags=["device"])
status.HTTP_409_CONFLICT: Responses.ENTRY_EXISTS, status.HTTP_409_CONFLICT: Responses.ENTRY_EXISTS,
}, },
response_model=DeviceRead, response_model=DeviceRead,
status_code=status.HTTP_201_CREATED,
) )
async def add_device( async def add_device(
device: DeviceCreate, device: DeviceCreate,

View file

@ -5,7 +5,9 @@
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException, status
from ..config import Config from ..config import Config
from ._common import Responses, get_current_config from ..db import User
from ..easyrsa import CertificateType, EasyRSA
from ._common import Responses, get_current_config, get_current_user
router = APIRouter(prefix="/service", tags=["service"]) router = APIRouter(prefix="/service", tags=["service"])
@ -20,5 +22,14 @@ router = APIRouter(prefix="/service", tags=["service"])
) )
async def init_pki( async def init_pki(
_: Config = Depends(get_current_config), _: Config = Depends(get_current_config),
current_user: User = Depends(get_current_user),
) -> None: ) -> None:
pass
if not current_user.is_admin:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
easy_rsa = EasyRSA()
easy_rsa.init_pki()
easy_rsa.build_ca()
easy_rsa.issue(CertificateType.server)

View file

@ -63,13 +63,14 @@ async def get_current_user_route(
@router.post( @router.post(
"", "",
responses={ responses={
status.HTTP_200_OK: Responses.OK, status.HTTP_201_CREATED: Responses.ENTRY_ADDED,
status.HTTP_400_BAD_REQUEST: Responses.NOT_INSTALLED, status.HTTP_400_BAD_REQUEST: Responses.NOT_INSTALLED,
status.HTTP_401_UNAUTHORIZED: Responses.NEEDS_USER, status.HTTP_401_UNAUTHORIZED: Responses.NEEDS_USER,
status.HTTP_403_FORBIDDEN: Responses.NEEDS_PERMISSION, status.HTTP_403_FORBIDDEN: Responses.NEEDS_PERMISSION,
status.HTTP_409_CONFLICT: Responses.ENTRY_EXISTS, status.HTTP_409_CONFLICT: Responses.ENTRY_EXISTS,
}, },
response_model=UserRead, response_model=UserRead,
status_code=status.HTTP_201_CREATED,
) )
async def add_user( async def add_user(
user: UserCreate, user: UserCreate,
@ -127,11 +128,12 @@ async def remove_user(
@router.post( @router.post(
"/{user_name}/tags", "/{user_name}/tags",
responses={ responses={
status.HTTP_200_OK: Responses.OK, status.HTTP_201_CREATED: Responses.ENTRY_ADDED,
status.HTTP_400_BAD_REQUEST: Responses.NOT_INSTALLED, status.HTTP_400_BAD_REQUEST: Responses.NOT_INSTALLED,
status.HTTP_401_UNAUTHORIZED: Responses.NEEDS_USER, status.HTTP_401_UNAUTHORIZED: Responses.NEEDS_USER,
status.HTTP_403_FORBIDDEN: Responses.NEEDS_PERMISSION, status.HTTP_403_FORBIDDEN: Responses.NEEDS_PERMISSION,
}, },
status_code=status.HTTP_201_CREATED,
) )
async def extend_tags( async def extend_tags(
tags: list[TagValue], tags: list[TagValue],