Compare commits
3 commits
2d755b8e3d
...
e6fe35d14e
| Author | SHA1 | Date | |
|---|---|---|---|
| e6fe35d14e | |||
| c0388d58c1 | |||
| bca5b2b55c |
6 changed files with 35 additions and 65 deletions
|
|
@ -9,7 +9,6 @@ Pydantic models might have convenience methods attached.
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import functools
|
|
||||||
import json
|
import json
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
@ -36,25 +35,14 @@ class Settings(BaseSettings):
|
||||||
docs_url: str | None = "/docs"
|
docs_url: str | None = "/docs"
|
||||||
redoc_url: str | None = "/redoc"
|
redoc_url: str | None = "/redoc"
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@functools.lru_cache
|
|
||||||
def load(cls) -> Settings:
|
|
||||||
return cls()
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@property
|
|
||||||
def _(cls) -> Settings:
|
|
||||||
"""
|
|
||||||
Shorthand for load()
|
|
||||||
"""
|
|
||||||
|
|
||||||
return cls.load()
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def config_file(self) -> Path:
|
def config_file(self) -> Path:
|
||||||
return self.data_dir.joinpath(self.config_file_name)
|
return self.data_dir.joinpath(self.config_file_name)
|
||||||
|
|
||||||
|
|
||||||
|
SETTINGS = Settings()
|
||||||
|
|
||||||
|
|
||||||
class DBType(Enum):
|
class DBType(Enum):
|
||||||
"""
|
"""
|
||||||
Supported database types
|
Supported database types
|
||||||
|
|
@ -73,7 +61,7 @@ class DBConfig(BaseModel):
|
||||||
user: str | None = None
|
user: str | None = None
|
||||||
password: str | None = None
|
password: str | None = None
|
||||||
host: str | None = None
|
host: str | None = None
|
||||||
database: str | None = str(Settings._.data_dir.joinpath("kiwi-vpn.db"))
|
database: str | Path | None = SETTINGS.data_dir.joinpath("kiwi-vpn.db")
|
||||||
|
|
||||||
mysql_driver: str = "pymysql"
|
mysql_driver: str = "pymysql"
|
||||||
mysql_args: list[str] = ["charset=utf8mb4"]
|
mysql_args: list[str] = ["charset=utf8mb4"]
|
||||||
|
|
@ -253,7 +241,7 @@ class Config(BaseModel):
|
||||||
crypto: CryptoConfig
|
crypto: CryptoConfig
|
||||||
server_dn: ServerDN
|
server_dn: ServerDN
|
||||||
|
|
||||||
__singleton: Config | None = None
|
__instance: Config | None = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load(cls) -> Config | None:
|
def load(cls) -> Config | None:
|
||||||
|
|
@ -261,16 +249,15 @@ class Config(BaseModel):
|
||||||
Load configuration from config file
|
Load configuration from config file
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if cls.__singleton is not None:
|
if cls.__instance is None:
|
||||||
return cls.__singleton
|
try:
|
||||||
|
with open(SETTINGS.config_file, "r") as config_file:
|
||||||
|
cls.__instance = cls.parse_obj(json.load(config_file))
|
||||||
|
|
||||||
try:
|
except FileNotFoundError:
|
||||||
with open(Settings._.config_file, "r") as config_file:
|
pass
|
||||||
cls.__singleton = Config.parse_obj(json.load(config_file))
|
|
||||||
return cls.__singleton
|
|
||||||
|
|
||||||
except FileNotFoundError:
|
return cls.__instance
|
||||||
return None
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@property
|
@property
|
||||||
|
|
@ -280,7 +267,7 @@ class Config(BaseModel):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if (config := cls.load()) is None:
|
if (config := cls.load()) is None:
|
||||||
raise FileNotFoundError(Settings._.config_file)
|
raise FileNotFoundError(SETTINGS.config_file)
|
||||||
|
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
@ -289,5 +276,5 @@ class Config(BaseModel):
|
||||||
Save configuration to config file
|
Save configuration to config file
|
||||||
"""
|
"""
|
||||||
|
|
||||||
with open(Settings._.config_file, "w") as config_file:
|
with open(SETTINGS.config_file, "w") as config_file:
|
||||||
config_file.write(self.json(indent=2))
|
config_file.write(self.json(indent=2))
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,6 @@ Python interface to EasyRSA CA.
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import functools
|
|
||||||
import subprocess
|
import subprocess
|
||||||
from enum import Enum, auto
|
from enum import Enum, auto
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
@ -13,7 +12,7 @@ from cryptography import x509
|
||||||
from passlib import pwd
|
from passlib import pwd
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from .config import Config, KeyAlgorithm, Settings
|
from .config import SETTINGS, Config, KeyAlgorithm
|
||||||
from .db import Connection, Device
|
from .db import Connection, Device
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -140,27 +139,13 @@ class EasyRSA:
|
||||||
None: {},
|
None: {},
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@functools.lru_cache
|
|
||||||
def _load(cls) -> EasyRSA:
|
|
||||||
return cls()
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@property
|
|
||||||
def _(cls) -> EasyRSA:
|
|
||||||
"""
|
|
||||||
Get the singleton
|
|
||||||
"""
|
|
||||||
|
|
||||||
return cls._load()
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def output_directory(self) -> Path:
|
def output_directory(self) -> Path:
|
||||||
"""
|
"""
|
||||||
Where certificates are stored
|
Where certificates are stored
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return Settings._.data_dir.joinpath("pki")
|
return SETTINGS.data_dir.joinpath("pki")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def ca_password(self) -> str:
|
def ca_password(self) -> str:
|
||||||
|
|
@ -300,14 +285,14 @@ class EasyRSA:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
EASYRSA = EasyRSA()
|
||||||
|
|
||||||
|
|
||||||
# some basic test
|
# some basic test
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
easy_rsa = EasyRSA()
|
ca = EASYRSA.build_ca()
|
||||||
easy_rsa.init_pki()
|
server = EASYRSA.issue(CertificateType.server)
|
||||||
|
|
||||||
ca = easy_rsa.build_ca()
|
|
||||||
server = easy_rsa.issue(CertificateType.server)
|
|
||||||
client = None
|
client = None
|
||||||
|
|
||||||
# check if configured
|
# check if configured
|
||||||
|
|
@ -316,7 +301,7 @@ if __name__ == "__main__":
|
||||||
Connection.connect(current_config.db.uri)
|
Connection.connect(current_config.db.uri)
|
||||||
|
|
||||||
if (device := Device.get(1)) is not None:
|
if (device := Device.get(1)) is not None:
|
||||||
client = easy_rsa.issue(
|
client = EASYRSA.issue(
|
||||||
dn=DistinguishedName.build(device)
|
dn=DistinguishedName.build(device)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -12,12 +12,10 @@ If run directly, uses `uvicorn` to run the app.
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
|
|
||||||
from .config import Config, Settings
|
from .config import SETTINGS, Config
|
||||||
from .db import Connection, User
|
from .db import Connection, User
|
||||||
from .routers import main_router
|
from .routers import main_router
|
||||||
|
|
||||||
settings = Settings._
|
|
||||||
|
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
title="kiwi-vpn API",
|
title="kiwi-vpn API",
|
||||||
description="This API enables the `kiwi-vpn` service.",
|
description="This API enables the `kiwi-vpn` service.",
|
||||||
|
|
@ -29,9 +27,9 @@ app = FastAPI(
|
||||||
"name": "MIT License",
|
"name": "MIT License",
|
||||||
"url": "https://opensource.org/licenses/mit-license.php",
|
"url": "https://opensource.org/licenses/mit-license.php",
|
||||||
},
|
},
|
||||||
openapi_url=settings.openapi_url,
|
openapi_url=SETTINGS.openapi_url,
|
||||||
docs_url=settings.docs_url if not settings.production_mode else None,
|
docs_url=SETTINGS.docs_url if not SETTINGS.production_mode else None,
|
||||||
redoc_url=settings.redoc_url if not settings.production_mode else None,
|
redoc_url=SETTINGS.redoc_url if not SETTINGS.production_mode else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
app.include_router(main_router)
|
app.include_router(main_router)
|
||||||
|
|
|
||||||
|
|
@ -6,10 +6,10 @@ This file: Main API router definition.
|
||||||
|
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
|
|
||||||
from ..config import Settings
|
from ..config import SETTINGS
|
||||||
from . import admin, device, service, user
|
from . import admin, device, service, user
|
||||||
|
|
||||||
main_router = APIRouter(prefix=f"/{Settings._.api_v1_prefix}")
|
main_router = APIRouter(prefix=f"/{SETTINGS.api_v1_prefix}")
|
||||||
|
|
||||||
main_router.include_router(admin.router)
|
main_router.include_router(admin.router)
|
||||||
main_router.include_router(service.router)
|
main_router.include_router(service.router)
|
||||||
|
|
|
||||||
|
|
@ -5,11 +5,11 @@ Common dependencies for routers.
|
||||||
from fastapi import Depends, HTTPException, status
|
from fastapi import Depends, HTTPException, status
|
||||||
from fastapi.security import OAuth2PasswordBearer
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
|
|
||||||
from ..config import Config, Settings
|
from ..config import SETTINGS, Config
|
||||||
from ..db import Device, User
|
from ..db import Device, User
|
||||||
|
|
||||||
oauth2_scheme = OAuth2PasswordBearer(
|
oauth2_scheme = OAuth2PasswordBearer(
|
||||||
tokenUrl=f"{Settings._.api_v1_prefix}/user/authenticate"
|
tokenUrl=f"{SETTINGS.api_v1_prefix}/user/authenticate"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
|
||||||
from ..db import Device, DeviceCreate, DeviceRead, User
|
from ..db import Device, DeviceCreate, DeviceRead, User
|
||||||
from ..easyrsa import DistinguishedName, EasyRSA
|
from ..easyrsa import EASYRSA, DistinguishedName
|
||||||
from ._common import (Responses, get_current_user, get_device_by_id,
|
from ._common import (Responses, get_current_user, get_device_by_id,
|
||||||
get_user_by_name)
|
get_user_by_name)
|
||||||
|
|
||||||
|
|
@ -90,19 +90,19 @@ async def remove_device(
|
||||||
},
|
},
|
||||||
response_model=DeviceRead,
|
response_model=DeviceRead,
|
||||||
)
|
)
|
||||||
async def request_certificate(
|
async def request_certificate_issuance(
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
device: Device = Depends(get_device_by_id),
|
device: Device = Depends(get_device_by_id),
|
||||||
) -> Device:
|
) -> Device:
|
||||||
"""
|
"""
|
||||||
POST ./{device_id}/issue: Request certificate for a device.
|
POST ./{device_id}/issue: Request certificate issuance for a device.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# check permission
|
# check permission
|
||||||
if not current_user.can_edit(device):
|
if not current_user.can_edit(device):
|
||||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
|
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
|
||||||
|
|
||||||
# cannot request for a newly created device
|
# can only request for a newly created device
|
||||||
if device.approved is not None:
|
if device.approved is not None:
|
||||||
raise HTTPException(status_code=status.HTTP_409_CONFLICT)
|
raise HTTPException(status_code=status.HTTP_409_CONFLICT)
|
||||||
|
|
||||||
|
|
@ -111,7 +111,7 @@ async def request_certificate(
|
||||||
|
|
||||||
if device.approved:
|
if device.approved:
|
||||||
# issue the certificate immediately
|
# issue the certificate immediately
|
||||||
if (certificate := EasyRSA._.issue(
|
if (certificate := EASYRSA.issue(
|
||||||
dn=DistinguishedName.build(device)
|
dn=DistinguishedName.build(device)
|
||||||
)) is not None:
|
)) is not None:
|
||||||
device.expiry = certificate.not_valid_after
|
device.expiry = certificate.not_valid_after
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue