Compare commits

..

No commits in common. "e6fe35d14eaa38249b5272657b89f4059d9b2e34" and "2d755b8e3d8b1cf7ee9cf36e9c5af6d6cd4c8477" have entirely different histories.

6 changed files with 65 additions and 35 deletions

View file

@ -9,6 +9,7 @@ Pydantic models might have convenience methods attached.
from __future__ import annotations from __future__ import annotations
import functools
import json import json
from datetime import datetime, timedelta from datetime import datetime, timedelta
from enum import Enum from enum import Enum
@ -35,14 +36,25 @@ class Settings(BaseSettings):
docs_url: str | None = "/docs" docs_url: str | None = "/docs"
redoc_url: str | None = "/redoc" redoc_url: str | None = "/redoc"
@classmethod
@functools.lru_cache
def load(cls) -> Settings:
return cls()
@classmethod
@property
def _(cls) -> Settings:
"""
Shorthand for load()
"""
return cls.load()
@property @property
def config_file(self) -> Path: def config_file(self) -> Path:
return self.data_dir.joinpath(self.config_file_name) return self.data_dir.joinpath(self.config_file_name)
SETTINGS = Settings()
class DBType(Enum): class DBType(Enum):
""" """
Supported database types Supported database types
@ -61,7 +73,7 @@ class DBConfig(BaseModel):
user: str | None = None user: str | None = None
password: str | None = None password: str | None = None
host: str | None = None host: str | None = None
database: str | Path | None = SETTINGS.data_dir.joinpath("kiwi-vpn.db") database: str | None = str(Settings._.data_dir.joinpath("kiwi-vpn.db"))
mysql_driver: str = "pymysql" mysql_driver: str = "pymysql"
mysql_args: list[str] = ["charset=utf8mb4"] mysql_args: list[str] = ["charset=utf8mb4"]
@ -241,7 +253,7 @@ class Config(BaseModel):
crypto: CryptoConfig crypto: CryptoConfig
server_dn: ServerDN server_dn: ServerDN
__instance: Config | None = None __singleton: Config | None = None
@classmethod @classmethod
def load(cls) -> Config | None: def load(cls) -> Config | None:
@ -249,15 +261,16 @@ class Config(BaseModel):
Load configuration from config file Load configuration from config file
""" """
if cls.__instance is None: if cls.__singleton is not None:
try: return cls.__singleton
with open(SETTINGS.config_file, "r") as config_file:
cls.__instance = cls.parse_obj(json.load(config_file))
except FileNotFoundError: try:
pass with open(Settings._.config_file, "r") as config_file:
cls.__singleton = Config.parse_obj(json.load(config_file))
return cls.__singleton
return cls.__instance except FileNotFoundError:
return None
@classmethod @classmethod
@property @property
@ -267,7 +280,7 @@ class Config(BaseModel):
""" """
if (config := cls.load()) is None: if (config := cls.load()) is None:
raise FileNotFoundError(SETTINGS.config_file) raise FileNotFoundError(Settings._.config_file)
return config return config
@ -276,5 +289,5 @@ class Config(BaseModel):
Save configuration to config file Save configuration to config file
""" """
with open(SETTINGS.config_file, "w") as config_file: with open(Settings._.config_file, "w") as config_file:
config_file.write(self.json(indent=2)) config_file.write(self.json(indent=2))

View file

@ -4,6 +4,7 @@ Python interface to EasyRSA CA.
from __future__ import annotations from __future__ import annotations
import functools
import subprocess import subprocess
from enum import Enum, auto from enum import Enum, auto
from pathlib import Path from pathlib import Path
@ -12,7 +13,7 @@ from cryptography import x509
from passlib import pwd from passlib import pwd
from pydantic import BaseModel from pydantic import BaseModel
from .config import SETTINGS, Config, KeyAlgorithm from .config import Config, KeyAlgorithm, Settings
from .db import Connection, Device from .db import Connection, Device
@ -139,13 +140,27 @@ class EasyRSA:
None: {}, None: {},
} }
@classmethod
@functools.lru_cache
def _load(cls) -> EasyRSA:
return cls()
@classmethod
@property
def _(cls) -> EasyRSA:
"""
Get the singleton
"""
return cls._load()
@property @property
def output_directory(self) -> Path: def output_directory(self) -> Path:
""" """
Where certificates are stored Where certificates are stored
""" """
return SETTINGS.data_dir.joinpath("pki") return Settings._.data_dir.joinpath("pki")
@property @property
def ca_password(self) -> str: def ca_password(self) -> str:
@ -285,14 +300,14 @@ class EasyRSA:
) )
EASYRSA = EasyRSA()
# some basic test # some basic test
if __name__ == "__main__": if __name__ == "__main__":
ca = EASYRSA.build_ca() easy_rsa = EasyRSA()
server = EASYRSA.issue(CertificateType.server) easy_rsa.init_pki()
ca = easy_rsa.build_ca()
server = easy_rsa.issue(CertificateType.server)
client = None client = None
# check if configured # check if configured
@ -301,7 +316,7 @@ if __name__ == "__main__":
Connection.connect(current_config.db.uri) Connection.connect(current_config.db.uri)
if (device := Device.get(1)) is not None: if (device := Device.get(1)) is not None:
client = EASYRSA.issue( client = easy_rsa.issue(
dn=DistinguishedName.build(device) dn=DistinguishedName.build(device)
) )

View file

@ -12,10 +12,12 @@ If run directly, uses `uvicorn` to run the app.
import uvicorn import uvicorn
from fastapi import FastAPI from fastapi import FastAPI
from .config import SETTINGS, Config from .config import Config, Settings
from .db import Connection, User from .db import Connection, User
from .routers import main_router from .routers import main_router
settings = Settings._
app = FastAPI( app = FastAPI(
title="kiwi-vpn API", title="kiwi-vpn API",
description="This API enables the `kiwi-vpn` service.", description="This API enables the `kiwi-vpn` service.",
@ -27,9 +29,9 @@ app = FastAPI(
"name": "MIT License", "name": "MIT License",
"url": "https://opensource.org/licenses/mit-license.php", "url": "https://opensource.org/licenses/mit-license.php",
}, },
openapi_url=SETTINGS.openapi_url, openapi_url=settings.openapi_url,
docs_url=SETTINGS.docs_url if not SETTINGS.production_mode else None, docs_url=settings.docs_url if not settings.production_mode else None,
redoc_url=SETTINGS.redoc_url if not SETTINGS.production_mode else None, redoc_url=settings.redoc_url if not settings.production_mode else None,
) )
app.include_router(main_router) app.include_router(main_router)

View file

@ -6,10 +6,10 @@ This file: Main API router definition.
from fastapi import APIRouter from fastapi import APIRouter
from ..config import SETTINGS from ..config import Settings
from . import admin, device, service, user from . import admin, device, service, user
main_router = APIRouter(prefix=f"/{SETTINGS.api_v1_prefix}") main_router = APIRouter(prefix=f"/{Settings._.api_v1_prefix}")
main_router.include_router(admin.router) main_router.include_router(admin.router)
main_router.include_router(service.router) main_router.include_router(service.router)

View file

@ -5,11 +5,11 @@ Common dependencies for routers.
from fastapi import Depends, HTTPException, status from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer from fastapi.security import OAuth2PasswordBearer
from ..config import SETTINGS, Config from ..config import Config, Settings
from ..db import Device, User from ..db import Device, User
oauth2_scheme = OAuth2PasswordBearer( oauth2_scheme = OAuth2PasswordBearer(
tokenUrl=f"{SETTINGS.api_v1_prefix}/user/authenticate" tokenUrl=f"{Settings._.api_v1_prefix}/user/authenticate"
) )

View file

@ -5,7 +5,7 @@
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException, status
from ..db import Device, DeviceCreate, DeviceRead, User from ..db import Device, DeviceCreate, DeviceRead, User
from ..easyrsa import EASYRSA, DistinguishedName from ..easyrsa import DistinguishedName, EasyRSA
from ._common import (Responses, get_current_user, get_device_by_id, from ._common import (Responses, get_current_user, get_device_by_id,
get_user_by_name) get_user_by_name)
@ -90,19 +90,19 @@ async def remove_device(
}, },
response_model=DeviceRead, response_model=DeviceRead,
) )
async def request_certificate_issuance( async def request_certificate(
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
device: Device = Depends(get_device_by_id), device: Device = Depends(get_device_by_id),
) -> Device: ) -> Device:
""" """
POST ./{device_id}/issue: Request certificate issuance for a device. POST ./{device_id}/issue: Request certificate for a device.
""" """
# check permission # check permission
if not current_user.can_edit(device): if not current_user.can_edit(device):
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
# can only request for a newly created device # cannot request for a newly created device
if device.approved is not None: if device.approved is not None:
raise HTTPException(status_code=status.HTTP_409_CONFLICT) raise HTTPException(status_code=status.HTTP_409_CONFLICT)
@ -111,7 +111,7 @@ async def request_certificate_issuance(
if device.approved: if device.approved:
# issue the certificate immediately # issue the certificate immediately
if (certificate := EASYRSA.issue( if (certificate := EasyRSA._.issue(
dn=DistinguishedName.build(device) dn=DistinguishedName.build(device)
)) is not None: )) is not None:
device.expiry = certificate.not_valid_after device.expiry = certificate.not_valid_after