documentation & some minor refactoring

This commit is contained in:
Jörn-Michael Miehe 2022-03-20 03:45:40 +00:00
parent bf8cb86cee
commit e49a993676
9 changed files with 270 additions and 105 deletions

View file

@ -1,3 +1,12 @@
"""
Configuration definition.
Converts per-run (environment) variables and config files into the
"python world" using `pydantic`.
Pydantic models might have convenience methods attached.
"""
from __future__ import annotations from __future__ import annotations
import functools import functools
@ -16,9 +25,12 @@ from sqlalchemy.engine import Engine
class Settings(BaseSettings): class Settings(BaseSettings):
"""
Per-run settings
"""
production_mode: bool = False production_mode: bool = False
data_dir: Path = Path("./tmp") data_dir: Path = Path("./tmp")
config_file: Path = Path("tmp/config.json")
openapi_url: str = "/openapi.json" openapi_url: str = "/openapi.json"
docs_url: str | None = "/docs" docs_url: str | None = "/docs"
redoc_url: str | None = "/redoc" redoc_url: str | None = "/redoc"
@ -28,13 +40,25 @@ class Settings(BaseSettings):
def get() -> Settings: def get() -> Settings:
return Settings() return Settings()
@property
def config_file(self) -> Path:
return self.data_dir.joinpath("config.json")
class DBType(Enum): class DBType(Enum):
"""
Supported database types
"""
sqlite = "sqlite" sqlite = "sqlite"
mysql = "mysql" mysql = "mysql"
class DBConfig(BaseModel): class DBConfig(BaseModel):
"""
Database connection configuration
"""
type: DBType = DBType.sqlite type: DBType = DBType.sqlite
user: str | None = None user: str | None = None
password: str | None = None password: str | None = None
@ -46,6 +70,10 @@ class DBConfig(BaseModel):
@property @property
async def db_engine(self) -> Engine: async def db_engine(self) -> Engine:
"""
Construct an SQLAlchemy engine
"""
if self.type is DBType.sqlite: if self.type is DBType.sqlite:
# SQLite backend # SQLite backend
return create_engine( return create_engine(
@ -69,6 +97,10 @@ class DBConfig(BaseModel):
class JWTConfig(BaseModel): class JWTConfig(BaseModel):
"""
Configuration for JSON Web Tokens
"""
secret: str | None = None secret: str | None = None
hash_algorithm: str = ALGORITHMS.HS256 hash_algorithm: str = ALGORITHMS.HS256
expiry_minutes: int = 30 expiry_minutes: int = 30
@ -76,6 +108,10 @@ class JWTConfig(BaseModel):
@validator("secret") @validator("secret")
@classmethod @classmethod
def ensure_secret(cls, value: str | None) -> str: def ensure_secret(cls, value: str | None) -> str:
"""
Generate a per-run secret if `None` was loaded from the config file
"""
if value is None: if value is None:
return token_urlsafe(128) return token_urlsafe(128)
@ -86,6 +122,10 @@ class JWTConfig(BaseModel):
username: str, username: str,
expiry_minutes: int | None = None, expiry_minutes: int | None = None,
) -> str: ) -> str:
"""
Build and sign a JSON Web Token
"""
if expiry_minutes is None: if expiry_minutes is None:
expiry_minutes = self.expiry_minutes expiry_minutes = self.expiry_minutes
@ -102,6 +142,10 @@ class JWTConfig(BaseModel):
self, self,
token: str, token: str,
) -> str | None: ) -> str | None:
"""
Verify a JSON Web Token, then extract the username
"""
# decode JWT token # decode JWT token
try: try:
payload = jwt.decode( payload = jwt.decode(
@ -130,6 +174,10 @@ class JWTConfig(BaseModel):
class CryptoConfig(BaseModel): class CryptoConfig(BaseModel):
"""
Configuration for hash algorithms
"""
schemes: list[str] = ["bcrypt"] schemes: list[str] = ["bcrypt"]
@property @property
@ -141,12 +189,20 @@ class CryptoConfig(BaseModel):
class Config(BaseModel): class Config(BaseModel):
"""
Configuration for `kiwi-vpn-api`
"""
db: DBConfig = Field(default_factory=DBConfig) db: DBConfig = Field(default_factory=DBConfig)
jwt: JWTConfig = Field(default_factory=JWTConfig) jwt: JWTConfig = Field(default_factory=JWTConfig)
crypto: CryptoConfig = Field(default_factory=CryptoConfig) crypto: CryptoConfig = Field(default_factory=CryptoConfig)
@staticmethod @staticmethod
async def load() -> Config | None: async def load() -> Config | None:
"""
Load configuration from config file
"""
try: try:
with open(Settings.get().config_file, "r") as config_file: with open(Settings.get().config_file, "r") as config_file:
return Config.parse_obj(json.load(config_file)) return Config.parse_obj(json.load(config_file))
@ -155,5 +211,9 @@ class Config(BaseModel):
return None return None
async def save(self) -> None: async def save(self) -> None:
"""
Save configuration to config file
"""
with open(Settings.get().config_file, "w") as config_file: with open(Settings.get().config_file, "w") as config_file:
config_file.write(self.json(indent=2)) config_file.write(self.json(indent=2))

View file

@ -1,3 +1,7 @@
"""
Utilities for handling SQLAlchemy database connections.
"""
from typing import Generator from typing import Generator
from sqlalchemy.engine import Engine from sqlalchemy.engine import Engine
@ -7,6 +11,10 @@ from .models import ORMBaseModel
class SessionManager: class SessionManager:
"""
Simple context manager for an ORM session.
"""
__session: Session __session: Session
def __init__(self, session: Session) -> None: def __init__(self, session: Session) -> None:
@ -20,11 +28,19 @@ class SessionManager:
class Connection: class Connection:
"""
Namespace for the database connection.
"""
engine: Engine | None = None engine: Engine | None = None
session_local: sessionmaker | None = None session_local: sessionmaker | None = None
@classmethod @classmethod
def connect(cls, engine: Engine) -> None: def connect(cls, engine: Engine) -> None:
"""
Connect ORM to a database engine.
"""
cls.engine = engine cls.engine = engine
cls.session_local = sessionmaker( cls.session_local = sessionmaker(
autocommit=False, autoflush=False, bind=engine, autocommit=False, autoflush=False, bind=engine,
@ -33,6 +49,10 @@ class Connection:
@classmethod @classmethod
def use(cls) -> SessionManager | None: def use(cls) -> SessionManager | None:
"""
Create an ORM session using a context manager.
"""
if cls.session_local is None: if cls.session_local is None:
return None return None
@ -40,6 +60,10 @@ class Connection:
@classmethod @classmethod
async def get(cls) -> Generator[Session | None, None, None]: async def get(cls) -> Generator[Session | None, None, None]:
"""
Create an ORM session using a FastAPI compatible async generator.
"""
if cls.session_local is None: if cls.session_local is None:
yield None yield None

View file

@ -1,3 +1,7 @@
"""
SQLAlchemy representation of database contents.
"""
from __future__ import annotations from __future__ import annotations
import datetime import datetime
@ -21,6 +25,10 @@ class User(ORMBaseModel):
@classmethod @classmethod
def load(cls, db: Session, name: str) -> User | None: def load(cls, db: Session, name: str) -> User | None:
"""
Load user from database by name.
"""
return (db return (db
.query(User) .query(User)
.filter(User.name == name) .filter(User.name == name)

View file

@ -1,7 +1,13 @@
"""
Pydantic representation of database contents.
"""
from __future__ import annotations from __future__ import annotations
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
from typing import Any
from passlib.context import CryptContext from passlib.context import CryptContext
from pydantic import BaseModel, validator from pydantic import BaseModel, validator
@ -10,6 +16,10 @@ from sqlalchemy.orm import Session
from . import models from . import models
##########
# table: certificates
##########
class CertificateBase(BaseModel): class CertificateBase(BaseModel):
expiry: datetime expiry: datetime
@ -26,24 +36,36 @@ class Certificate(CertificateBase):
class Config: class Config:
orm_mode = True orm_mode = True
##########
# table: user_capabilities
##########
class UserCapability(Enum): class UserCapability(Enum):
admin = "admin" admin = "admin"
@classmethod @classmethod
def from_value(cls, value) -> UserCapability: def from_value(cls, value) -> UserCapability:
"""
Create UserCapability from various formats
"""
if isinstance(value, cls): if isinstance(value, cls):
# use simple value # value is already a UserCapability, use that
return value return value
elif isinstance(value, models.UserCapability): elif isinstance(value, models.UserCapability):
# create from db # create from db format
return cls(value.capability) return cls(value.capability)
else: else:
# create from string representation # create from string representation
return cls(str(value)) return cls(str(value))
##########
# table: users
##########
class UserBase(BaseModel): class UserBase(BaseModel):
name: str name: str
@ -62,10 +84,11 @@ class User(UserBase):
@validator("capabilities", pre=True) @validator("capabilities", pre=True)
@classmethod @classmethod
def unify_capabilities( def unify_capabilities(cls, value: list[Any]) -> list[UserCapability]:
cls, """
value: list[models.UserCapability | UserCapability | str] Import the capabilities from various formats
) -> list[UserCapability]: """
return [ return [
UserCapability.from_value(capability) UserCapability.from_value(capability)
for capability in value for capability in value
@ -77,21 +100,29 @@ class User(UserBase):
db: Session, db: Session,
name: str, name: str,
) -> User | None: ) -> User | None:
"""
Load user from database by name.
"""
if (db_user := models.User.load(db, name)) is None: if (db_user := models.User.load(db, name)) is None:
return None return None
return cls.from_orm(db_user) return cls.from_orm(db_user)
@classmethod @classmethod
def login( def authenticate(
cls, cls,
db: Session, db: Session,
name: str, name: str,
password: str, password: str,
crypt_context: CryptContext, crypt_context: CryptContext,
) -> User | None: ) -> User | None:
"""
Authenticate with name/password against users in database.
"""
if (db_user := models.User.load(db, name)) is None: if (db_user := models.User.load(db, name)) is None:
# inexistent user, fake doing password verification # nonexistent user, fake doing password verification
crypt_context.dummy_verify() crypt_context.dummy_verify()
return None return None
@ -108,6 +139,10 @@ class User(UserBase):
user: UserCreate, user: UserCreate,
crypt_context: CryptContext, crypt_context: CryptContext,
) -> User | None: ) -> User | None:
"""
Create a new user in the database.
"""
try: try:
user = models.User( user = models.User(
name=user.name, name=user.name,
@ -122,6 +157,7 @@ class User(UserBase):
return cls.from_orm(user) return cls.from_orm(user)
except IntegrityError: except IntegrityError:
# user already existed
pass pass
def add_capabilities( def add_capabilities(
@ -129,6 +165,10 @@ class User(UserBase):
db: Session, db: Session,
capabilities: list[UserCapability], capabilities: list[UserCapability],
) -> None: ) -> None:
"""
Add some capabilities to this user.
"""
for capability in capabilities: for capability in capabilities:
if capability not in self.capabilities: if capability not in self.capabilities:
db.add(models.UserCapability( db.add(models.UserCapability(
@ -138,6 +178,10 @@ class User(UserBase):
db.commit() db.commit()
##########
# table: distinguished_names
##########
class DistinguishedNameBase(BaseModel): class DistinguishedNameBase(BaseModel):
cn_only: bool cn_only: bool

View file

@ -1,5 +1,14 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
"""
Main executable of `kiwi-vpn-api`.
Creates the main `FastAPI` app, mounts endpoints and connects to the
configured database.
If run directly, uses `uvicorn` to run the app.
"""
import uvicorn import uvicorn
from fastapi import FastAPI from fastapi import FastAPI
@ -8,9 +17,25 @@ from .db import Connection
from .db.schemas import User from .db.schemas import User
from .routers import admin, user from .routers import admin, user
settings = Settings.get() app = FastAPI()
@app.on_event("startup")
async def on_startup() -> None:
# check if configured
if (current_config := await Config.load()) is not None:
# connect to database
Connection.connect(await current_config.db.db_engine)
# some testing
with Connection.use() as db:
print(User.from_db(db, "admin"))
print(User.from_db(db, "nonexistent"))
def main() -> None:
settings = Settings.get()
api = FastAPI( api = FastAPI(
title="kiwi-vpn API", title="kiwi-vpn API",
description="This API enables the `kiwi-vpn` service.", description="This API enables the `kiwi-vpn` service.",
@ -27,25 +52,11 @@ api = FastAPI(
redoc_url=settings.redoc_url if not settings.production_mode else None, redoc_url=settings.redoc_url if not settings.production_mode else None,
) )
app = FastAPI()
app.mount("/api", api)
api.include_router(admin.router) api.include_router(admin.router)
api.include_router(user.router) api.include_router(user.router)
app.mount("/api", api)
@app.on_event("startup")
async def on_startup() -> None:
if (current_config := await Config.load()) is not None:
Connection.connect(await current_config.db.db_engine)
# some testing
with Connection.use() as db:
print(User.from_db(db, "admin"))
print(User.from_db(db, "nonexistent"))
def main() -> None:
uvicorn.run( uvicorn.run(
"kiwi_vpn_api.main:app", "kiwi_vpn_api.main:app",
host="0.0.0.0", host="0.0.0.0",

View file

@ -1,43 +0,0 @@
# Startup
if config file present:
- load config file
- connect to DB
- mount all routers
else:
- mount admin router
# PUT admin/config
if config file present:
- if user is admin:
- overwrite config
- reload config, reconnect to DB
else:
- overwrite config
- reload config, connect to DB
- mount all routers
# POST admin/user
if no config file present:
- error
elif user table is empty:
- create new user
- give "admin" cap to new user
else:
- if user is admin:
- create new user
...

View file

@ -1,3 +1,8 @@
"""
Common dependencies for routers.
"""
from fastapi import Depends from fastapi import Depends
from fastapi.security import OAuth2PasswordBearer from fastapi.security import OAuth2PasswordBearer
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@ -6,31 +11,36 @@ from ..config import Config
from ..db import Connection from ..db import Connection
from ..db.schemas import User from ..db.schemas import User
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="user/auth") oauth2_scheme = OAuth2PasswordBearer(tokenUrl="user/authenticate")
# just a namespace
class Responses: class Responses:
ok = { """
Just a namespace.
Describes API response status codes.
"""
OK = {
"content": None, "content": None,
} }
installed = { INSTALLED = {
"description": "kiwi-vpn already installed", "description": "kiwi-vpn already installed",
"content": None, "content": None,
} }
not_installed = { NOT_INSTALLED = {
"description": "kiwi-vpn not installed", "description": "kiwi-vpn not installed",
"content": None, "content": None,
} }
needs_user = { NEEDS_USER = {
"description": "Must be logged in", "description": "Must be logged in",
"content": None, "content": None,
} }
needs_admin = { NEEDS_ADMIN = {
"description": "Must be admin", "description": "Must be admin",
"content": None, "content": None,
} }
entry_exists = { ENTRY_EXISTS = {
"description": "Entry exists in database", "description": "Entry exists in database",
"content": None, "content": None,
} }
@ -41,6 +51,11 @@ async def get_current_user(
db: Session | None = Depends(Connection.get), db: Session | None = Depends(Connection.get),
current_config: Config | None = Depends(Config.load), current_config: Config | None = Depends(Config.load),
): ):
"""
Get the currently logged-in user from the database.
"""
# can't connect to an unconfigured database
if current_config is None: if current_config is None:
return None return None

View file

@ -1,9 +1,14 @@
"""
/admin endpoints.
"""
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException, status
from ..config import Config from ..config import Config
from ..db import Connection from ..db import Connection
from ..db.schemas import User, UserCapability, UserCreate from ..db.schemas import User, UserCapability, UserCreate
from . import _deps from ._common import Responses, get_current_user
router = APIRouter(prefix="/admin") router = APIRouter(prefix="/admin")
@ -11,8 +16,8 @@ router = APIRouter(prefix="/admin")
@router.put( @router.put(
"/install", "/install",
responses={ responses={
status.HTTP_200_OK: _deps.Responses.ok, status.HTTP_200_OK: Responses.OK,
status.HTTP_400_BAD_REQUEST: _deps.Responses.installed, status.HTTP_400_BAD_REQUEST: Responses.INSTALLED,
}, },
) )
async def install( async def install(
@ -20,20 +25,25 @@ async def install(
admin_user: UserCreate, admin_user: UserCreate,
current_config: Config | None = Depends(Config.load), current_config: Config | None = Depends(Config.load),
): ):
"""
PUT ./install: Install `kiwi-vpn`.
"""
# fail if already installed
if current_config is not None: if current_config is not None:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
# create config file, connect to database
await config.save() await config.save()
Connection.connect(await config.db.db_engine) Connection.connect(await config.db.db_engine)
# create an administrative user
with Connection.use() as db: with Connection.use() as db:
admin_user = User.create( User.create(
db=db, db=db,
user=admin_user, user=admin_user,
crypt_context=await config.crypto.crypt_context, crypt_context=await config.crypto.crypt_context,
) ).add_capabilities(
admin_user.add_capabilities(
db=db, db=db,
capabilities=[UserCapability.admin], capabilities=[UserCapability.admin],
) )
@ -42,23 +52,30 @@ async def install(
@router.put( @router.put(
"/config", "/config",
responses={ responses={
status.HTTP_200_OK: _deps.Responses.ok, status.HTTP_200_OK: Responses.OK,
status.HTTP_400_BAD_REQUEST: _deps.Responses.not_installed, status.HTTP_400_BAD_REQUEST: Responses.NOT_INSTALLED,
status.HTTP_401_UNAUTHORIZED: _deps.Responses.needs_user, status.HTTP_401_UNAUTHORIZED: Responses.NEEDS_USER,
status.HTTP_403_FORBIDDEN: _deps.Responses.needs_admin, status.HTTP_403_FORBIDDEN: Responses.NEEDS_ADMIN,
}, },
) )
async def set_config( async def set_config(
new_config: Config, new_config: Config,
current_config: Config | None = Depends(Config.load), current_config: Config | None = Depends(Config.load),
current_user: User | None = Depends(_deps.get_current_user), current_user: User | None = Depends(get_current_user),
): ):
"""
PUT ./config: Edit `kiwi-vpn` main config.
"""
# fail if not installed
if current_config is None: if current_config is None:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
# fail if not requested by an admin
if (current_user is None if (current_user is None
or UserCapability.admin not in current_user.capabilities): or UserCapability.admin not in current_user.capabilities):
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
# update config file, reconnect to database
await new_config.save() await new_config.save()
Connection.connect(await new_config.db.db_engine) Connection.connect(await new_config.db.db_engine)

View file

@ -1,3 +1,7 @@
"""
/user endpoints.
"""
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordRequestForm from fastapi.security import OAuth2PasswordRequestForm
from pydantic import BaseModel from pydantic import BaseModel
@ -6,32 +10,43 @@ from sqlalchemy.orm import Session
from ..config import Config from ..config import Config
from ..db import Connection from ..db import Connection
from ..db.schemas import User, UserCapability, UserCreate from ..db.schemas import User, UserCapability, UserCreate
from . import _deps from ._common import Responses, get_current_user
router = APIRouter(prefix="/user") router = APIRouter(prefix="/user")
class Token(BaseModel): class Token(BaseModel):
"""
Response model for issuing tokens.
"""
access_token: str access_token: str
token_type: str token_type: str
@router.post("/auth", response_model=Token) @router.post("/authenticate", response_model=Token)
async def login( async def login(
form_data: OAuth2PasswordRequestForm = Depends(), form_data: OAuth2PasswordRequestForm = Depends(),
current_config: Config | None = Depends(Config.load), current_config: Config | None = Depends(Config.load),
db: Session | None = Depends(Connection.get), db: Session | None = Depends(Connection.get),
): ):
"""
POST ./authenticate: Authenticate a user. Issues a bearer token.
"""
# fail if not installed
if current_config is None: if current_config is None:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
user = User.login( # try logging in
user = User.authenticate(
db=db, db=db,
name=form_data.username, name=form_data.username,
password=form_data.password, password=form_data.password,
crypt_context=await current_config.crypto.crypt_context, crypt_context=await current_config.crypto.crypt_context,
) )
# authentication failed
if user is None: if user is None:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
@ -39,48 +54,62 @@ async def login(
headers={"WWW-Authenticate": "Bearer"}, headers={"WWW-Authenticate": "Bearer"},
) )
# authentication succeeded
access_token = await current_config.jwt.create_token(user.name) access_token = await current_config.jwt.create_token(user.name)
return {"access_token": access_token, "token_type": "bearer"} return {"access_token": access_token, "token_type": "bearer"}
@router.get("/current", response_model=User) @router.get("/current", response_model=User)
async def get_current_user( async def get_current_user(
current_user: User | None = Depends(_deps.get_current_user), current_user: User | None = Depends(get_current_user),
): ):
"""
GET ./current: Respond with the currently logged-in user.
"""
return current_user return current_user
@router.post( @router.post(
"/new", "/new",
responses={ responses={
status.HTTP_200_OK: _deps.Responses.ok, status.HTTP_200_OK: Responses.OK,
status.HTTP_400_BAD_REQUEST: _deps.Responses.not_installed, status.HTTP_400_BAD_REQUEST: Responses.NOT_INSTALLED,
status.HTTP_401_UNAUTHORIZED: _deps.Responses.needs_user, status.HTTP_401_UNAUTHORIZED: Responses.NEEDS_USER,
status.HTTP_403_FORBIDDEN: _deps.Responses.needs_admin, status.HTTP_403_FORBIDDEN: Responses.NEEDS_ADMIN,
status.HTTP_409_CONFLICT: _deps.Responses.entry_exists, status.HTTP_409_CONFLICT: Responses.ENTRY_EXISTS,
}, },
response_model=User, response_model=User,
) )
async def add_user( async def add_user(
user: UserCreate, user: UserCreate,
current_config: Config | None = Depends(Config.load), current_config: Config | None = Depends(Config.load),
current_user: User | None = Depends(_deps.get_current_user), current_user: User | None = Depends(get_current_user),
db: Session | None = Depends(Connection.get), db: Session | None = Depends(Connection.get),
): ):
"""
POST ./new: Create a new user in the database.
"""
# fail if not installed
if current_config is None: if current_config is None:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
# fail if not requested by an admin
if (current_user is None if (current_user is None
or UserCapability.admin not in current_user.capabilities): or UserCapability.admin not in current_user.capabilities):
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
# actually create the new user
new_user = User.create( new_user = User.create(
db=db, db=db,
user=user, user=user,
crypt_context=await current_config.crypto.crypt_context, crypt_context=await current_config.crypto.crypt_context,
) )
# fail if creation was unsuccessful
if new_user is None: if new_user is None:
raise HTTPException(status_code=status.HTTP_409_CONFLICT) raise HTTPException(status_code=status.HTTP_409_CONFLICT)
# return the created user on success
return new_user return new_user