Compare commits

...

5 commits

9 changed files with 296 additions and 126 deletions

View file

@ -1,10 +1,20 @@
"""
Configuration definition.
Converts per-run (environment) variables and config files into the
"python world" using `pydantic`.
Pydantic models might have convenience methods attached.
"""
from __future__ import annotations
import functools
import json
from datetime import datetime, timedelta
from enum import Enum
from secrets import token_hex
from pathlib import Path
from secrets import token_urlsafe
from jose import JWTError, jwt
from jose.constants import ALGORITHMS
@ -15,8 +25,12 @@ from sqlalchemy.engine import Engine
class Settings(BaseSettings):
"""
Per-run settings
"""
production_mode: bool = False
config_file: str = "tmp/config.json"
data_dir: Path = Path("./tmp")
openapi_url: str = "/openapi.json"
docs_url: str | None = "/docs"
redoc_url: str | None = "/redoc"
@ -26,24 +40,40 @@ class Settings(BaseSettings):
def get() -> Settings:
return Settings()
@property
def config_file(self) -> Path:
return self.data_dir.joinpath("config.json")
class DBType(Enum):
"""
Supported database types
"""
sqlite = "sqlite"
mysql = "mysql"
class DBConfig(BaseModel):
"""
Database connection configuration
"""
type: DBType = DBType.sqlite
user: str | None = None
password: str | None = None
host: str | None = None
database: str | None = "./tmp/vpn.db"
database: str | None = Settings.get().data_dir.joinpath("vpn.db")
mysql_driver: str = "pymysql"
mysql_args: list[str] = ["charset=utf8mb4"]
@property
async def db_engine(self) -> Engine:
"""
Construct an SQLAlchemy engine
"""
if self.type is DBType.sqlite:
# SQLite backend
return create_engine(
@ -67,6 +97,10 @@ class DBConfig(BaseModel):
class JWTConfig(BaseModel):
"""
Configuration for JSON Web Tokens
"""
secret: str | None = None
hash_algorithm: str = ALGORITHMS.HS256
expiry_minutes: int = 30
@ -74,8 +108,12 @@ class JWTConfig(BaseModel):
@validator("secret")
@classmethod
def ensure_secret(cls, value: str | None) -> str:
"""
Generate a per-run secret if `None` was loaded from the config file
"""
if value is None:
return token_hex(32)
return token_urlsafe(128)
return value
@ -84,6 +122,10 @@ class JWTConfig(BaseModel):
username: str,
expiry_minutes: int | None = None,
) -> str:
"""
Build and sign a JSON Web Token
"""
if expiry_minutes is None:
expiry_minutes = self.expiry_minutes
@ -100,6 +142,10 @@ class JWTConfig(BaseModel):
self,
token: str,
) -> str | None:
"""
Verify a JSON Web Token, then extract the username
"""
# decode JWT token
try:
payload = jwt.decode(
@ -128,6 +174,10 @@ class JWTConfig(BaseModel):
class CryptoConfig(BaseModel):
"""
Configuration for hash algorithms
"""
schemes: list[str] = ["bcrypt"]
@property
@ -139,12 +189,20 @@ class CryptoConfig(BaseModel):
class Config(BaseModel):
"""
Configuration for `kiwi-vpn-api`
"""
db: DBConfig = Field(default_factory=DBConfig)
jwt: JWTConfig = Field(default_factory=JWTConfig)
crypto: CryptoConfig = Field(default_factory=CryptoConfig)
@staticmethod
async def load() -> Config | None:
"""
Load configuration from config file
"""
try:
with open(Settings.get().config_file, "r") as config_file:
return Config.parse_obj(json.load(config_file))
@ -153,5 +211,9 @@ class Config(BaseModel):
return None
async def save(self) -> None:
"""
Save configuration to config file
"""
with open(Settings.get().config_file, "w") as config_file:
config_file.write(self.json(indent=2))

View file

@ -1,3 +1,7 @@
"""
Utilities for handling SQLAlchemy database connections.
"""
from typing import Generator
from sqlalchemy.engine import Engine
@ -7,6 +11,10 @@ from .models import ORMBaseModel
class SessionManager:
"""
Simple context manager for an ORM session.
"""
__session: Session
def __init__(self, session: Session) -> None:
@ -20,11 +28,19 @@ class SessionManager:
class Connection:
"""
Namespace for the database connection.
"""
engine: Engine | None = None
session_local: sessionmaker | None = None
@classmethod
def connect(cls, engine: Engine) -> None:
"""
Connect ORM to a database engine.
"""
cls.engine = engine
cls.session_local = sessionmaker(
autocommit=False, autoflush=False, bind=engine,
@ -33,6 +49,10 @@ class Connection:
@classmethod
def use(cls) -> SessionManager | None:
"""
Create an ORM session using a context manager.
"""
if cls.session_local is None:
return None
@ -40,6 +60,10 @@ class Connection:
@classmethod
async def get(cls) -> Generator[Session | None, None, None]:
"""
Create an ORM session using a FastAPI compatible async generator.
"""
if cls.session_local is None:
yield None

View file

@ -1,3 +1,7 @@
"""
SQLAlchemy representation of database contents.
"""
from __future__ import annotations
import datetime
@ -21,6 +25,10 @@ class User(ORMBaseModel):
@classmethod
def load(cls, db: Session, name: str) -> User | None:
"""
Load user from database by name.
"""
return (db
.query(User)
.filter(User.name == name)

View file

@ -1,7 +1,13 @@
"""
Pydantic representation of database contents.
"""
from __future__ import annotations
from datetime import datetime
from enum import Enum
from typing import Any
from passlib.context import CryptContext
from pydantic import BaseModel, validator
@ -10,6 +16,10 @@ from sqlalchemy.orm import Session
from . import models
##########
# table: certificates
##########
class CertificateBase(BaseModel):
expiry: datetime
@ -26,24 +36,36 @@ class Certificate(CertificateBase):
class Config:
orm_mode = True
##########
# table: user_capabilities
##########
class UserCapability(Enum):
admin = "admin"
@classmethod
def from_value(cls, value) -> UserCapability:
"""
Create UserCapability from various formats
"""
if isinstance(value, cls):
# use simple value
# value is already a UserCapability, use that
return value
elif isinstance(value, models.UserCapability):
# create from db
# create from db format
return cls(value.capability)
else:
# create from string representation
return cls(str(value))
##########
# table: users
##########
class UserBase(BaseModel):
name: str
@ -62,10 +84,11 @@ class User(UserBase):
@validator("capabilities", pre=True)
@classmethod
def unify_capabilities(
cls,
value: list[models.UserCapability | UserCapability | str]
) -> list[UserCapability]:
def unify_capabilities(cls, value: list[Any]) -> list[UserCapability]:
"""
Import the capabilities from various formats
"""
return [
UserCapability.from_value(capability)
for capability in value
@ -77,21 +100,29 @@ class User(UserBase):
db: Session,
name: str,
) -> User | None:
"""
Load user from database by name.
"""
if (db_user := models.User.load(db, name)) is None:
return None
return cls.from_orm(db_user)
@classmethod
def login(
def authenticate(
cls,
db: Session,
name: str,
password: str,
crypt_context: CryptContext,
) -> User | None:
"""
Authenticate with name/password against users in database.
"""
if (db_user := models.User.load(db, name)) is None:
# inexistent user, fake doing password verification
# nonexistent user, fake doing password verification
crypt_context.dummy_verify()
return None
@ -108,6 +139,10 @@ class User(UserBase):
user: UserCreate,
crypt_context: CryptContext,
) -> User | None:
"""
Create a new user in the database.
"""
try:
user = models.User(
name=user.name,
@ -122,6 +157,7 @@ class User(UserBase):
return cls.from_orm(user)
except IntegrityError:
# user already existed
pass
def add_capabilities(
@ -129,16 +165,23 @@ class User(UserBase):
db: Session,
capabilities: list[UserCapability],
) -> None:
"""
Add some capabilities to this user.
"""
for capability in capabilities:
if capability not in self.capabilities:
cap = models.UserCapability(
db.add(models.UserCapability(
user_name=self.name,
capability=capability.value,
)
db.add(cap)
))
db.commit()
##########
# table: distinguished_names
##########
class DistinguishedNameBase(BaseModel):
cn_only: bool

View file

@ -1,50 +1,62 @@
#!/usr/bin/env python3
"""
Main executable of `kiwi-vpn-api`.
Creates the main `FastAPI` app, mounts endpoints and connects to the
configured database.
If run directly, uses `uvicorn` to run the app.
"""
import uvicorn
from fastapi import FastAPI
from .config import Config, Settings
from .db import Connection, schemas
from .db import Connection
from .db.schemas import User
from .routers import admin, user
settings = Settings.get()
api = FastAPI(
title="kiwi-vpn API",
description="This API enables the `kiwi-vpn` service.",
contact={
"name": "Jörn-Michael Miehe",
"email": "40151420+ldericher@users.noreply.github.com",
},
license_info={
"name": "MIT License",
"url": "https://opensource.org/licenses/mit-license.php",
},
openapi_url=settings.openapi_url,
docs_url=settings.docs_url if not settings.production_mode else None,
redoc_url=settings.redoc_url if not settings.production_mode else None,
)
app = FastAPI()
app.mount("/api", api)
api.include_router(admin.router)
api.include_router(user.router)
@app.on_event("startup")
async def on_startup() -> None:
# check if configured
if (current_config := await Config.load()) is not None:
# connect to database
Connection.connect(await current_config.db.db_engine)
# some testing
with Connection.use() as db:
print(schemas.User.from_db(db, "admin"))
print(schemas.User.from_db(db, "nonexistent"))
print(User.from_db(db, "admin"))
print(User.from_db(db, "nonexistent"))
def main() -> None:
settings = Settings.get()
api = FastAPI(
title="kiwi-vpn API",
description="This API enables the `kiwi-vpn` service.",
contact={
"name": "Jörn-Michael Miehe",
"email": "40151420+ldericher@users.noreply.github.com",
},
license_info={
"name": "MIT License",
"url": "https://opensource.org/licenses/mit-license.php",
},
openapi_url=settings.openapi_url,
docs_url=settings.docs_url if not settings.production_mode else None,
redoc_url=settings.redoc_url if not settings.production_mode else None,
)
api.include_router(admin.router)
api.include_router(user.router)
app.mount("/api", api)
uvicorn.run(
"kiwi_vpn_api.main:app",
host="0.0.0.0",

View file

@ -1,43 +0,0 @@
# Startup
if config file present:
- load config file
- connect to DB
- mount all routers
else:
- mount admin router
# PUT admin/config
if config file present:
- if user is admin:
- overwrite config
- reload config, reconnect to DB
else:
- overwrite config
- reload config, connect to DB
- mount all routers
# POST admin/user
if no config file present:
- error
elif user table is empty:
- create new user
- give "admin" cap to new user
else:
- if user is admin:
- create new user
...

View file

@ -1,35 +1,46 @@
"""
Common dependencies for routers.
"""
from fastapi import Depends
from fastapi.security import OAuth2PasswordBearer
from sqlalchemy.orm import Session
from ..config import Config
from ..db import Connection, schemas
from ..db import Connection
from ..db.schemas import User
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="user/auth")
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="user/authenticate")
# just a namespace
class Responses:
ok = {
"""
Just a namespace.
Describes API response status codes.
"""
OK = {
"content": None,
}
installed = {
INSTALLED = {
"description": "kiwi-vpn already installed",
"content": None,
}
not_installed = {
NOT_INSTALLED = {
"description": "kiwi-vpn not installed",
"content": None,
}
needs_user = {
NEEDS_USER = {
"description": "Must be logged in",
"content": None,
}
needs_admin = {
NEEDS_ADMIN = {
"description": "Must be admin",
"content": None,
}
entry_exists = {
ENTRY_EXISTS = {
"description": "Entry exists in database",
"content": None,
}
@ -40,10 +51,15 @@ async def get_current_user(
db: Session | None = Depends(Connection.get),
current_config: Config | None = Depends(Config.load),
):
"""
Get the currently logged-in user from the database.
"""
# can't connect to an unconfigured database
if current_config is None:
return None
username = await current_config.jwt.decode_token(token)
user = schemas.User.from_db(db, username)
user = User.from_db(db, username)
return user

View file

@ -1,8 +1,14 @@
"""
/admin endpoints.
"""
from fastapi import APIRouter, Depends, HTTPException, status
from ..config import Config
from ..db import Connection, schemas
from . import _deps
from ..db import Connection
from ..db.schemas import User, UserCapability, UserCreate
from ._common import Responses, get_current_user
router = APIRouter(prefix="/admin")
@ -10,54 +16,66 @@ router = APIRouter(prefix="/admin")
@router.put(
"/install",
responses={
status.HTTP_200_OK: _deps.Responses.ok,
status.HTTP_400_BAD_REQUEST: _deps.Responses.installed,
status.HTTP_200_OK: Responses.OK,
status.HTTP_400_BAD_REQUEST: Responses.INSTALLED,
},
)
async def install(
config: Config,
admin_user: schemas.UserCreate,
admin_user: UserCreate,
current_config: Config | None = Depends(Config.load),
):
"""
PUT ./install: Install `kiwi-vpn`.
"""
# fail if already installed
if current_config is not None:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
# create config file, connect to database
await config.save()
Connection.connect(await config.db.db_engine)
# create an administrative user
with Connection.use() as db:
admin_user = schemas.User.create(
User.create(
db=db,
user=admin_user,
crypt_context=await config.crypto.crypt_context,
)
admin_user.add_capabilities(
).add_capabilities(
db=db,
capabilities=[schemas.UserCapability.admin],
capabilities=[UserCapability.admin],
)
@router.put(
"/config",
responses={
status.HTTP_200_OK: _deps.Responses.ok,
status.HTTP_400_BAD_REQUEST: _deps.Responses.not_installed,
status.HTTP_401_UNAUTHORIZED: _deps.Responses.needs_user,
status.HTTP_403_FORBIDDEN: _deps.Responses.needs_admin,
status.HTTP_200_OK: Responses.OK,
status.HTTP_400_BAD_REQUEST: Responses.NOT_INSTALLED,
status.HTTP_401_UNAUTHORIZED: Responses.NEEDS_USER,
status.HTTP_403_FORBIDDEN: Responses.NEEDS_ADMIN,
},
)
async def set_config(
new_config: Config,
current_config: Config | None = Depends(Config.load),
current_user: schemas.User | None = Depends(_deps.get_current_user),
current_user: User | None = Depends(get_current_user),
):
"""
PUT ./config: Edit `kiwi-vpn` main config.
"""
# fail if not installed
if current_config is None:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
# fail if not requested by an admin
if (current_user is None
or schemas.UserCapability.admin not in current_user.capabilities):
or UserCapability.admin not in current_user.capabilities):
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
# update config file, reconnect to database
await new_config.save()
Connection.connect(await new_config.db.db_engine)

View file

@ -1,36 +1,52 @@
"""
/user endpoints.
"""
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordRequestForm
from pydantic import BaseModel
from sqlalchemy.orm import Session
from ..config import Config
from ..db import Connection, schemas
from . import _deps
from ..db import Connection
from ..db.schemas import User, UserCapability, UserCreate
from ._common import Responses, get_current_user
router = APIRouter(prefix="/user")
class Token(BaseModel):
"""
Response model for issuing tokens.
"""
access_token: str
token_type: str
@router.post("/auth", response_model=Token)
@router.post("/authenticate", response_model=Token)
async def login(
form_data: OAuth2PasswordRequestForm = Depends(),
current_config: Config | None = Depends(Config.load),
db: Session | None = Depends(Connection.get),
):
"""
POST ./authenticate: Authenticate a user. Issues a bearer token.
"""
# fail if not installed
if current_config is None:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
user = schemas.User.login(
# try logging in
user = User.authenticate(
db=db,
name=form_data.username,
password=form_data.password,
crypt_context=await current_config.crypto.crypt_context,
)
# authentication failed
if user is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@ -38,48 +54,62 @@ async def login(
headers={"WWW-Authenticate": "Bearer"},
)
# authentication succeeded
access_token = await current_config.jwt.create_token(user.name)
return {"access_token": access_token, "token_type": "bearer"}
@router.get("/current", response_model=schemas.User)
@router.get("/current", response_model=User)
async def get_current_user(
current_user: schemas.User | None = Depends(_deps.get_current_user),
current_user: User | None = Depends(get_current_user),
):
"""
GET ./current: Respond with the currently logged-in user.
"""
return current_user
@router.post(
"/new",
responses={
status.HTTP_200_OK: _deps.Responses.ok,
status.HTTP_400_BAD_REQUEST: _deps.Responses.not_installed,
status.HTTP_401_UNAUTHORIZED: _deps.Responses.needs_user,
status.HTTP_403_FORBIDDEN: _deps.Responses.needs_admin,
status.HTTP_409_CONFLICT: _deps.Responses.entry_exists,
status.HTTP_200_OK: Responses.OK,
status.HTTP_400_BAD_REQUEST: Responses.NOT_INSTALLED,
status.HTTP_401_UNAUTHORIZED: Responses.NEEDS_USER,
status.HTTP_403_FORBIDDEN: Responses.NEEDS_ADMIN,
status.HTTP_409_CONFLICT: Responses.ENTRY_EXISTS,
},
response_model=schemas.User,
response_model=User,
)
async def add_user(
user: schemas.UserCreate,
user: UserCreate,
current_config: Config | None = Depends(Config.load),
current_user: schemas.User | None = Depends(_deps.get_current_user),
current_user: User | None = Depends(get_current_user),
db: Session | None = Depends(Connection.get),
):
"""
POST ./new: Create a new user in the database.
"""
# fail if not installed
if current_config is None:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
# fail if not requested by an admin
if (current_user is None
or schemas.UserCapability.admin not in current_user.capabilities):
or UserCapability.admin not in current_user.capabilities):
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
new_user = schemas.User.create(
# actually create the new user
new_user = User.create(
db=db,
user=user,
crypt_context=await current_config.crypto.crypt_context,
)
# fail if creation was unsuccessful
if new_user is None:
raise HTTPException(status_code=status.HTTP_409_CONFLICT)
# return the created user on success
return new_user