Compare commits
No commits in common. "e49a99367645ad239615b82b65d9cc7f7b43cefd" and "00bdf88b6e234fad2fac38896843013b0e008d06" have entirely different histories.
e49a993676
...
00bdf88b6e
9 changed files with 126 additions and 296 deletions
|
|
@ -1,20 +1,10 @@
|
||||||
"""
|
|
||||||
Configuration definition.
|
|
||||||
|
|
||||||
Converts per-run (environment) variables and config files into the
|
|
||||||
"python world" using `pydantic`.
|
|
||||||
|
|
||||||
Pydantic models might have convenience methods attached.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
import json
|
import json
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from pathlib import Path
|
from secrets import token_hex
|
||||||
from secrets import token_urlsafe
|
|
||||||
|
|
||||||
from jose import JWTError, jwt
|
from jose import JWTError, jwt
|
||||||
from jose.constants import ALGORITHMS
|
from jose.constants import ALGORITHMS
|
||||||
|
|
@ -25,12 +15,8 @@ from sqlalchemy.engine import Engine
|
||||||
|
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
class Settings(BaseSettings):
|
||||||
"""
|
|
||||||
Per-run settings
|
|
||||||
"""
|
|
||||||
|
|
||||||
production_mode: bool = False
|
production_mode: bool = False
|
||||||
data_dir: Path = Path("./tmp")
|
config_file: str = "tmp/config.json"
|
||||||
openapi_url: str = "/openapi.json"
|
openapi_url: str = "/openapi.json"
|
||||||
docs_url: str | None = "/docs"
|
docs_url: str | None = "/docs"
|
||||||
redoc_url: str | None = "/redoc"
|
redoc_url: str | None = "/redoc"
|
||||||
|
|
@ -40,40 +26,24 @@ class Settings(BaseSettings):
|
||||||
def get() -> Settings:
|
def get() -> Settings:
|
||||||
return Settings()
|
return Settings()
|
||||||
|
|
||||||
@property
|
|
||||||
def config_file(self) -> Path:
|
|
||||||
return self.data_dir.joinpath("config.json")
|
|
||||||
|
|
||||||
|
|
||||||
class DBType(Enum):
|
class DBType(Enum):
|
||||||
"""
|
|
||||||
Supported database types
|
|
||||||
"""
|
|
||||||
|
|
||||||
sqlite = "sqlite"
|
sqlite = "sqlite"
|
||||||
mysql = "mysql"
|
mysql = "mysql"
|
||||||
|
|
||||||
|
|
||||||
class DBConfig(BaseModel):
|
class DBConfig(BaseModel):
|
||||||
"""
|
|
||||||
Database connection configuration
|
|
||||||
"""
|
|
||||||
|
|
||||||
type: DBType = DBType.sqlite
|
type: DBType = DBType.sqlite
|
||||||
user: str | None = None
|
user: str | None = None
|
||||||
password: str | None = None
|
password: str | None = None
|
||||||
host: str | None = None
|
host: str | None = None
|
||||||
database: str | None = Settings.get().data_dir.joinpath("vpn.db")
|
database: str | None = "./tmp/vpn.db"
|
||||||
|
|
||||||
mysql_driver: str = "pymysql"
|
mysql_driver: str = "pymysql"
|
||||||
mysql_args: list[str] = ["charset=utf8mb4"]
|
mysql_args: list[str] = ["charset=utf8mb4"]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
async def db_engine(self) -> Engine:
|
async def db_engine(self) -> Engine:
|
||||||
"""
|
|
||||||
Construct an SQLAlchemy engine
|
|
||||||
"""
|
|
||||||
|
|
||||||
if self.type is DBType.sqlite:
|
if self.type is DBType.sqlite:
|
||||||
# SQLite backend
|
# SQLite backend
|
||||||
return create_engine(
|
return create_engine(
|
||||||
|
|
@ -97,10 +67,6 @@ class DBConfig(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class JWTConfig(BaseModel):
|
class JWTConfig(BaseModel):
|
||||||
"""
|
|
||||||
Configuration for JSON Web Tokens
|
|
||||||
"""
|
|
||||||
|
|
||||||
secret: str | None = None
|
secret: str | None = None
|
||||||
hash_algorithm: str = ALGORITHMS.HS256
|
hash_algorithm: str = ALGORITHMS.HS256
|
||||||
expiry_minutes: int = 30
|
expiry_minutes: int = 30
|
||||||
|
|
@ -108,12 +74,8 @@ class JWTConfig(BaseModel):
|
||||||
@validator("secret")
|
@validator("secret")
|
||||||
@classmethod
|
@classmethod
|
||||||
def ensure_secret(cls, value: str | None) -> str:
|
def ensure_secret(cls, value: str | None) -> str:
|
||||||
"""
|
|
||||||
Generate a per-run secret if `None` was loaded from the config file
|
|
||||||
"""
|
|
||||||
|
|
||||||
if value is None:
|
if value is None:
|
||||||
return token_urlsafe(128)
|
return token_hex(32)
|
||||||
|
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
@ -122,10 +84,6 @@ class JWTConfig(BaseModel):
|
||||||
username: str,
|
username: str,
|
||||||
expiry_minutes: int | None = None,
|
expiry_minutes: int | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
|
||||||
Build and sign a JSON Web Token
|
|
||||||
"""
|
|
||||||
|
|
||||||
if expiry_minutes is None:
|
if expiry_minutes is None:
|
||||||
expiry_minutes = self.expiry_minutes
|
expiry_minutes = self.expiry_minutes
|
||||||
|
|
||||||
|
|
@ -142,10 +100,6 @@ class JWTConfig(BaseModel):
|
||||||
self,
|
self,
|
||||||
token: str,
|
token: str,
|
||||||
) -> str | None:
|
) -> str | None:
|
||||||
"""
|
|
||||||
Verify a JSON Web Token, then extract the username
|
|
||||||
"""
|
|
||||||
|
|
||||||
# decode JWT token
|
# decode JWT token
|
||||||
try:
|
try:
|
||||||
payload = jwt.decode(
|
payload = jwt.decode(
|
||||||
|
|
@ -174,10 +128,6 @@ class JWTConfig(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class CryptoConfig(BaseModel):
|
class CryptoConfig(BaseModel):
|
||||||
"""
|
|
||||||
Configuration for hash algorithms
|
|
||||||
"""
|
|
||||||
|
|
||||||
schemes: list[str] = ["bcrypt"]
|
schemes: list[str] = ["bcrypt"]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|
@ -189,20 +139,12 @@ class CryptoConfig(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class Config(BaseModel):
|
class Config(BaseModel):
|
||||||
"""
|
|
||||||
Configuration for `kiwi-vpn-api`
|
|
||||||
"""
|
|
||||||
|
|
||||||
db: DBConfig = Field(default_factory=DBConfig)
|
db: DBConfig = Field(default_factory=DBConfig)
|
||||||
jwt: JWTConfig = Field(default_factory=JWTConfig)
|
jwt: JWTConfig = Field(default_factory=JWTConfig)
|
||||||
crypto: CryptoConfig = Field(default_factory=CryptoConfig)
|
crypto: CryptoConfig = Field(default_factory=CryptoConfig)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def load() -> Config | None:
|
async def load() -> Config | None:
|
||||||
"""
|
|
||||||
Load configuration from config file
|
|
||||||
"""
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with open(Settings.get().config_file, "r") as config_file:
|
with open(Settings.get().config_file, "r") as config_file:
|
||||||
return Config.parse_obj(json.load(config_file))
|
return Config.parse_obj(json.load(config_file))
|
||||||
|
|
@ -211,9 +153,5 @@ class Config(BaseModel):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def save(self) -> None:
|
async def save(self) -> None:
|
||||||
"""
|
|
||||||
Save configuration to config file
|
|
||||||
"""
|
|
||||||
|
|
||||||
with open(Settings.get().config_file, "w") as config_file:
|
with open(Settings.get().config_file, "w") as config_file:
|
||||||
config_file.write(self.json(indent=2))
|
config_file.write(self.json(indent=2))
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,3 @@
|
||||||
"""
|
|
||||||
Utilities for handling SQLAlchemy database connections.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Generator
|
from typing import Generator
|
||||||
|
|
||||||
from sqlalchemy.engine import Engine
|
from sqlalchemy.engine import Engine
|
||||||
|
|
@ -11,10 +7,6 @@ from .models import ORMBaseModel
|
||||||
|
|
||||||
|
|
||||||
class SessionManager:
|
class SessionManager:
|
||||||
"""
|
|
||||||
Simple context manager for an ORM session.
|
|
||||||
"""
|
|
||||||
|
|
||||||
__session: Session
|
__session: Session
|
||||||
|
|
||||||
def __init__(self, session: Session) -> None:
|
def __init__(self, session: Session) -> None:
|
||||||
|
|
@ -28,19 +20,11 @@ class SessionManager:
|
||||||
|
|
||||||
|
|
||||||
class Connection:
|
class Connection:
|
||||||
"""
|
|
||||||
Namespace for the database connection.
|
|
||||||
"""
|
|
||||||
|
|
||||||
engine: Engine | None = None
|
engine: Engine | None = None
|
||||||
session_local: sessionmaker | None = None
|
session_local: sessionmaker | None = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def connect(cls, engine: Engine) -> None:
|
def connect(cls, engine: Engine) -> None:
|
||||||
"""
|
|
||||||
Connect ORM to a database engine.
|
|
||||||
"""
|
|
||||||
|
|
||||||
cls.engine = engine
|
cls.engine = engine
|
||||||
cls.session_local = sessionmaker(
|
cls.session_local = sessionmaker(
|
||||||
autocommit=False, autoflush=False, bind=engine,
|
autocommit=False, autoflush=False, bind=engine,
|
||||||
|
|
@ -49,10 +33,6 @@ class Connection:
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def use(cls) -> SessionManager | None:
|
def use(cls) -> SessionManager | None:
|
||||||
"""
|
|
||||||
Create an ORM session using a context manager.
|
|
||||||
"""
|
|
||||||
|
|
||||||
if cls.session_local is None:
|
if cls.session_local is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
@ -60,10 +40,6 @@ class Connection:
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get(cls) -> Generator[Session | None, None, None]:
|
async def get(cls) -> Generator[Session | None, None, None]:
|
||||||
"""
|
|
||||||
Create an ORM session using a FastAPI compatible async generator.
|
|
||||||
"""
|
|
||||||
|
|
||||||
if cls.session_local is None:
|
if cls.session_local is None:
|
||||||
yield None
|
yield None
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,3 @@
|
||||||
"""
|
|
||||||
SQLAlchemy representation of database contents.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import datetime
|
import datetime
|
||||||
|
|
@ -25,10 +21,6 @@ class User(ORMBaseModel):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load(cls, db: Session, name: str) -> User | None:
|
def load(cls, db: Session, name: str) -> User | None:
|
||||||
"""
|
|
||||||
Load user from database by name.
|
|
||||||
"""
|
|
||||||
|
|
||||||
return (db
|
return (db
|
||||||
.query(User)
|
.query(User)
|
||||||
.filter(User.name == name)
|
.filter(User.name == name)
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,7 @@
|
||||||
"""
|
|
||||||
Pydantic representation of database contents.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from passlib.context import CryptContext
|
from passlib.context import CryptContext
|
||||||
from pydantic import BaseModel, validator
|
from pydantic import BaseModel, validator
|
||||||
|
|
@ -16,10 +10,6 @@ from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from . import models
|
from . import models
|
||||||
|
|
||||||
##########
|
|
||||||
# table: certificates
|
|
||||||
##########
|
|
||||||
|
|
||||||
|
|
||||||
class CertificateBase(BaseModel):
|
class CertificateBase(BaseModel):
|
||||||
expiry: datetime
|
expiry: datetime
|
||||||
|
|
@ -36,36 +26,24 @@ class Certificate(CertificateBase):
|
||||||
class Config:
|
class Config:
|
||||||
orm_mode = True
|
orm_mode = True
|
||||||
|
|
||||||
##########
|
|
||||||
# table: user_capabilities
|
|
||||||
##########
|
|
||||||
|
|
||||||
|
|
||||||
class UserCapability(Enum):
|
class UserCapability(Enum):
|
||||||
admin = "admin"
|
admin = "admin"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_value(cls, value) -> UserCapability:
|
def from_value(cls, value) -> UserCapability:
|
||||||
"""
|
|
||||||
Create UserCapability from various formats
|
|
||||||
"""
|
|
||||||
|
|
||||||
if isinstance(value, cls):
|
if isinstance(value, cls):
|
||||||
# value is already a UserCapability, use that
|
# use simple value
|
||||||
return value
|
return value
|
||||||
|
|
||||||
elif isinstance(value, models.UserCapability):
|
elif isinstance(value, models.UserCapability):
|
||||||
# create from db format
|
# create from db
|
||||||
return cls(value.capability)
|
return cls(value.capability)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# create from string representation
|
# create from string representation
|
||||||
return cls(str(value))
|
return cls(str(value))
|
||||||
|
|
||||||
##########
|
|
||||||
# table: users
|
|
||||||
##########
|
|
||||||
|
|
||||||
|
|
||||||
class UserBase(BaseModel):
|
class UserBase(BaseModel):
|
||||||
name: str
|
name: str
|
||||||
|
|
@ -84,11 +62,10 @@ class User(UserBase):
|
||||||
|
|
||||||
@validator("capabilities", pre=True)
|
@validator("capabilities", pre=True)
|
||||||
@classmethod
|
@classmethod
|
||||||
def unify_capabilities(cls, value: list[Any]) -> list[UserCapability]:
|
def unify_capabilities(
|
||||||
"""
|
cls,
|
||||||
Import the capabilities from various formats
|
value: list[models.UserCapability | UserCapability | str]
|
||||||
"""
|
) -> list[UserCapability]:
|
||||||
|
|
||||||
return [
|
return [
|
||||||
UserCapability.from_value(capability)
|
UserCapability.from_value(capability)
|
||||||
for capability in value
|
for capability in value
|
||||||
|
|
@ -100,29 +77,21 @@ class User(UserBase):
|
||||||
db: Session,
|
db: Session,
|
||||||
name: str,
|
name: str,
|
||||||
) -> User | None:
|
) -> User | None:
|
||||||
"""
|
|
||||||
Load user from database by name.
|
|
||||||
"""
|
|
||||||
|
|
||||||
if (db_user := models.User.load(db, name)) is None:
|
if (db_user := models.User.load(db, name)) is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return cls.from_orm(db_user)
|
return cls.from_orm(db_user)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def authenticate(
|
def login(
|
||||||
cls,
|
cls,
|
||||||
db: Session,
|
db: Session,
|
||||||
name: str,
|
name: str,
|
||||||
password: str,
|
password: str,
|
||||||
crypt_context: CryptContext,
|
crypt_context: CryptContext,
|
||||||
) -> User | None:
|
) -> User | None:
|
||||||
"""
|
|
||||||
Authenticate with name/password against users in database.
|
|
||||||
"""
|
|
||||||
|
|
||||||
if (db_user := models.User.load(db, name)) is None:
|
if (db_user := models.User.load(db, name)) is None:
|
||||||
# nonexistent user, fake doing password verification
|
# inexistent user, fake doing password verification
|
||||||
crypt_context.dummy_verify()
|
crypt_context.dummy_verify()
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
@ -139,10 +108,6 @@ class User(UserBase):
|
||||||
user: UserCreate,
|
user: UserCreate,
|
||||||
crypt_context: CryptContext,
|
crypt_context: CryptContext,
|
||||||
) -> User | None:
|
) -> User | None:
|
||||||
"""
|
|
||||||
Create a new user in the database.
|
|
||||||
"""
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
user = models.User(
|
user = models.User(
|
||||||
name=user.name,
|
name=user.name,
|
||||||
|
|
@ -157,7 +122,6 @@ class User(UserBase):
|
||||||
return cls.from_orm(user)
|
return cls.from_orm(user)
|
||||||
|
|
||||||
except IntegrityError:
|
except IntegrityError:
|
||||||
# user already existed
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def add_capabilities(
|
def add_capabilities(
|
||||||
|
|
@ -165,23 +129,16 @@ class User(UserBase):
|
||||||
db: Session,
|
db: Session,
|
||||||
capabilities: list[UserCapability],
|
capabilities: list[UserCapability],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
|
||||||
Add some capabilities to this user.
|
|
||||||
"""
|
|
||||||
|
|
||||||
for capability in capabilities:
|
for capability in capabilities:
|
||||||
if capability not in self.capabilities:
|
if capability not in self.capabilities:
|
||||||
db.add(models.UserCapability(
|
cap = models.UserCapability(
|
||||||
user_name=self.name,
|
user_name=self.name,
|
||||||
capability=capability.value,
|
capability=capability.value,
|
||||||
))
|
)
|
||||||
|
db.add(cap)
|
||||||
|
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
##########
|
|
||||||
# table: distinguished_names
|
|
||||||
##########
|
|
||||||
|
|
||||||
|
|
||||||
class DistinguishedNameBase(BaseModel):
|
class DistinguishedNameBase(BaseModel):
|
||||||
cn_only: bool
|
cn_only: bool
|
||||||
|
|
|
||||||
|
|
@ -1,62 +1,50 @@
|
||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
"""
|
|
||||||
Main executable of `kiwi-vpn-api`.
|
|
||||||
|
|
||||||
Creates the main `FastAPI` app, mounts endpoints and connects to the
|
|
||||||
configured database.
|
|
||||||
|
|
||||||
If run directly, uses `uvicorn` to run the app.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
|
|
||||||
from .config import Config, Settings
|
from .config import Config, Settings
|
||||||
from .db import Connection
|
from .db import Connection, schemas
|
||||||
from .db.schemas import User
|
|
||||||
from .routers import admin, user
|
from .routers import admin, user
|
||||||
|
|
||||||
|
settings = Settings.get()
|
||||||
|
|
||||||
|
|
||||||
|
api = FastAPI(
|
||||||
|
title="kiwi-vpn API",
|
||||||
|
description="This API enables the `kiwi-vpn` service.",
|
||||||
|
contact={
|
||||||
|
"name": "Jörn-Michael Miehe",
|
||||||
|
"email": "40151420+ldericher@users.noreply.github.com",
|
||||||
|
},
|
||||||
|
license_info={
|
||||||
|
"name": "MIT License",
|
||||||
|
"url": "https://opensource.org/licenses/mit-license.php",
|
||||||
|
},
|
||||||
|
openapi_url=settings.openapi_url,
|
||||||
|
docs_url=settings.docs_url if not settings.production_mode else None,
|
||||||
|
redoc_url=settings.redoc_url if not settings.production_mode else None,
|
||||||
|
)
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
app.mount("/api", api)
|
||||||
|
|
||||||
|
api.include_router(admin.router)
|
||||||
|
api.include_router(user.router)
|
||||||
|
|
||||||
|
|
||||||
@app.on_event("startup")
|
@app.on_event("startup")
|
||||||
async def on_startup() -> None:
|
async def on_startup() -> None:
|
||||||
# check if configured
|
|
||||||
if (current_config := await Config.load()) is not None:
|
if (current_config := await Config.load()) is not None:
|
||||||
# connect to database
|
|
||||||
Connection.connect(await current_config.db.db_engine)
|
Connection.connect(await current_config.db.db_engine)
|
||||||
|
|
||||||
# some testing
|
# some testing
|
||||||
with Connection.use() as db:
|
with Connection.use() as db:
|
||||||
print(User.from_db(db, "admin"))
|
print(schemas.User.from_db(db, "admin"))
|
||||||
print(User.from_db(db, "nonexistent"))
|
print(schemas.User.from_db(db, "nonexistent"))
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
settings = Settings.get()
|
|
||||||
|
|
||||||
api = FastAPI(
|
|
||||||
title="kiwi-vpn API",
|
|
||||||
description="This API enables the `kiwi-vpn` service.",
|
|
||||||
contact={
|
|
||||||
"name": "Jörn-Michael Miehe",
|
|
||||||
"email": "40151420+ldericher@users.noreply.github.com",
|
|
||||||
},
|
|
||||||
license_info={
|
|
||||||
"name": "MIT License",
|
|
||||||
"url": "https://opensource.org/licenses/mit-license.php",
|
|
||||||
},
|
|
||||||
openapi_url=settings.openapi_url,
|
|
||||||
docs_url=settings.docs_url if not settings.production_mode else None,
|
|
||||||
redoc_url=settings.redoc_url if not settings.production_mode else None,
|
|
||||||
)
|
|
||||||
|
|
||||||
api.include_router(admin.router)
|
|
||||||
api.include_router(user.router)
|
|
||||||
|
|
||||||
app.mount("/api", api)
|
|
||||||
|
|
||||||
uvicorn.run(
|
uvicorn.run(
|
||||||
"kiwi_vpn_api.main:app",
|
"kiwi_vpn_api.main:app",
|
||||||
host="0.0.0.0",
|
host="0.0.0.0",
|
||||||
|
|
|
||||||
43
api/kiwi_vpn_api/plan.md
Normal file
43
api/kiwi_vpn_api/plan.md
Normal file
|
|
@ -0,0 +1,43 @@
|
||||||
|
# Startup
|
||||||
|
|
||||||
|
if config file present:
|
||||||
|
|
||||||
|
- load config file
|
||||||
|
- connect to DB
|
||||||
|
- mount all routers
|
||||||
|
|
||||||
|
else:
|
||||||
|
|
||||||
|
- mount admin router
|
||||||
|
|
||||||
|
# PUT admin/config
|
||||||
|
|
||||||
|
if config file present:
|
||||||
|
|
||||||
|
- if user is admin:
|
||||||
|
- overwrite config
|
||||||
|
- reload config, reconnect to DB
|
||||||
|
|
||||||
|
else:
|
||||||
|
|
||||||
|
- overwrite config
|
||||||
|
- reload config, connect to DB
|
||||||
|
- mount all routers
|
||||||
|
|
||||||
|
# POST admin/user
|
||||||
|
|
||||||
|
if no config file present:
|
||||||
|
|
||||||
|
- error
|
||||||
|
|
||||||
|
elif user table is empty:
|
||||||
|
|
||||||
|
- create new user
|
||||||
|
- give "admin" cap to new user
|
||||||
|
|
||||||
|
else:
|
||||||
|
|
||||||
|
- if user is admin:
|
||||||
|
- create new user
|
||||||
|
|
||||||
|
...
|
||||||
|
|
@ -1,46 +1,35 @@
|
||||||
"""
|
|
||||||
Common dependencies for routers.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
from fastapi import Depends
|
from fastapi import Depends
|
||||||
from fastapi.security import OAuth2PasswordBearer
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from ..config import Config
|
from ..config import Config
|
||||||
from ..db import Connection
|
from ..db import Connection, schemas
|
||||||
from ..db.schemas import User
|
|
||||||
|
|
||||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="user/authenticate")
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="user/auth")
|
||||||
|
|
||||||
|
|
||||||
|
# just a namespace
|
||||||
class Responses:
|
class Responses:
|
||||||
"""
|
ok = {
|
||||||
Just a namespace.
|
|
||||||
|
|
||||||
Describes API response status codes.
|
|
||||||
"""
|
|
||||||
|
|
||||||
OK = {
|
|
||||||
"content": None,
|
"content": None,
|
||||||
}
|
}
|
||||||
INSTALLED = {
|
installed = {
|
||||||
"description": "kiwi-vpn already installed",
|
"description": "kiwi-vpn already installed",
|
||||||
"content": None,
|
"content": None,
|
||||||
}
|
}
|
||||||
NOT_INSTALLED = {
|
not_installed = {
|
||||||
"description": "kiwi-vpn not installed",
|
"description": "kiwi-vpn not installed",
|
||||||
"content": None,
|
"content": None,
|
||||||
}
|
}
|
||||||
NEEDS_USER = {
|
needs_user = {
|
||||||
"description": "Must be logged in",
|
"description": "Must be logged in",
|
||||||
"content": None,
|
"content": None,
|
||||||
}
|
}
|
||||||
NEEDS_ADMIN = {
|
needs_admin = {
|
||||||
"description": "Must be admin",
|
"description": "Must be admin",
|
||||||
"content": None,
|
"content": None,
|
||||||
}
|
}
|
||||||
ENTRY_EXISTS = {
|
entry_exists = {
|
||||||
"description": "Entry exists in database",
|
"description": "Entry exists in database",
|
||||||
"content": None,
|
"content": None,
|
||||||
}
|
}
|
||||||
|
|
@ -51,15 +40,10 @@ async def get_current_user(
|
||||||
db: Session | None = Depends(Connection.get),
|
db: Session | None = Depends(Connection.get),
|
||||||
current_config: Config | None = Depends(Config.load),
|
current_config: Config | None = Depends(Config.load),
|
||||||
):
|
):
|
||||||
"""
|
|
||||||
Get the currently logged-in user from the database.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# can't connect to an unconfigured database
|
|
||||||
if current_config is None:
|
if current_config is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
username = await current_config.jwt.decode_token(token)
|
username = await current_config.jwt.decode_token(token)
|
||||||
user = User.from_db(db, username)
|
user = schemas.User.from_db(db, username)
|
||||||
|
|
||||||
return user
|
return user
|
||||||
|
|
@ -1,14 +1,8 @@
|
||||||
"""
|
|
||||||
/admin endpoints.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
|
||||||
from ..config import Config
|
from ..config import Config
|
||||||
from ..db import Connection
|
from ..db import Connection, schemas
|
||||||
from ..db.schemas import User, UserCapability, UserCreate
|
from . import _deps
|
||||||
from ._common import Responses, get_current_user
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/admin")
|
router = APIRouter(prefix="/admin")
|
||||||
|
|
||||||
|
|
@ -16,66 +10,54 @@ router = APIRouter(prefix="/admin")
|
||||||
@router.put(
|
@router.put(
|
||||||
"/install",
|
"/install",
|
||||||
responses={
|
responses={
|
||||||
status.HTTP_200_OK: Responses.OK,
|
status.HTTP_200_OK: _deps.Responses.ok,
|
||||||
status.HTTP_400_BAD_REQUEST: Responses.INSTALLED,
|
status.HTTP_400_BAD_REQUEST: _deps.Responses.installed,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
async def install(
|
async def install(
|
||||||
config: Config,
|
config: Config,
|
||||||
admin_user: UserCreate,
|
admin_user: schemas.UserCreate,
|
||||||
current_config: Config | None = Depends(Config.load),
|
current_config: Config | None = Depends(Config.load),
|
||||||
):
|
):
|
||||||
"""
|
|
||||||
PUT ./install: Install `kiwi-vpn`.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# fail if already installed
|
|
||||||
if current_config is not None:
|
if current_config is not None:
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
|
||||||
|
|
||||||
# create config file, connect to database
|
|
||||||
await config.save()
|
await config.save()
|
||||||
Connection.connect(await config.db.db_engine)
|
Connection.connect(await config.db.db_engine)
|
||||||
|
|
||||||
# create an administrative user
|
|
||||||
with Connection.use() as db:
|
with Connection.use() as db:
|
||||||
User.create(
|
admin_user = schemas.User.create(
|
||||||
db=db,
|
db=db,
|
||||||
user=admin_user,
|
user=admin_user,
|
||||||
crypt_context=await config.crypto.crypt_context,
|
crypt_context=await config.crypto.crypt_context,
|
||||||
).add_capabilities(
|
)
|
||||||
|
|
||||||
|
admin_user.add_capabilities(
|
||||||
db=db,
|
db=db,
|
||||||
capabilities=[UserCapability.admin],
|
capabilities=[schemas.UserCapability.admin],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.put(
|
@router.put(
|
||||||
"/config",
|
"/config",
|
||||||
responses={
|
responses={
|
||||||
status.HTTP_200_OK: Responses.OK,
|
status.HTTP_200_OK: _deps.Responses.ok,
|
||||||
status.HTTP_400_BAD_REQUEST: Responses.NOT_INSTALLED,
|
status.HTTP_400_BAD_REQUEST: _deps.Responses.not_installed,
|
||||||
status.HTTP_401_UNAUTHORIZED: Responses.NEEDS_USER,
|
status.HTTP_401_UNAUTHORIZED: _deps.Responses.needs_user,
|
||||||
status.HTTP_403_FORBIDDEN: Responses.NEEDS_ADMIN,
|
status.HTTP_403_FORBIDDEN: _deps.Responses.needs_admin,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
async def set_config(
|
async def set_config(
|
||||||
new_config: Config,
|
new_config: Config,
|
||||||
current_config: Config | None = Depends(Config.load),
|
current_config: Config | None = Depends(Config.load),
|
||||||
current_user: User | None = Depends(get_current_user),
|
current_user: schemas.User | None = Depends(_deps.get_current_user),
|
||||||
):
|
):
|
||||||
"""
|
|
||||||
PUT ./config: Edit `kiwi-vpn` main config.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# fail if not installed
|
|
||||||
if current_config is None:
|
if current_config is None:
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
|
||||||
|
|
||||||
# fail if not requested by an admin
|
|
||||||
if (current_user is None
|
if (current_user is None
|
||||||
or UserCapability.admin not in current_user.capabilities):
|
or schemas.UserCapability.admin not in current_user.capabilities):
|
||||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
|
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
|
||||||
|
|
||||||
# update config file, reconnect to database
|
|
||||||
await new_config.save()
|
await new_config.save()
|
||||||
Connection.connect(await new_config.db.db_engine)
|
Connection.connect(await new_config.db.db_engine)
|
||||||
|
|
|
||||||
|
|
@ -1,52 +1,36 @@
|
||||||
"""
|
|
||||||
/user endpoints.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
from fastapi.security import OAuth2PasswordRequestForm
|
from fastapi.security import OAuth2PasswordRequestForm
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from ..config import Config
|
from ..config import Config
|
||||||
from ..db import Connection
|
from ..db import Connection, schemas
|
||||||
from ..db.schemas import User, UserCapability, UserCreate
|
from . import _deps
|
||||||
from ._common import Responses, get_current_user
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/user")
|
router = APIRouter(prefix="/user")
|
||||||
|
|
||||||
|
|
||||||
class Token(BaseModel):
|
class Token(BaseModel):
|
||||||
"""
|
|
||||||
Response model for issuing tokens.
|
|
||||||
"""
|
|
||||||
|
|
||||||
access_token: str
|
access_token: str
|
||||||
token_type: str
|
token_type: str
|
||||||
|
|
||||||
|
|
||||||
@router.post("/authenticate", response_model=Token)
|
@router.post("/auth", response_model=Token)
|
||||||
async def login(
|
async def login(
|
||||||
form_data: OAuth2PasswordRequestForm = Depends(),
|
form_data: OAuth2PasswordRequestForm = Depends(),
|
||||||
current_config: Config | None = Depends(Config.load),
|
current_config: Config | None = Depends(Config.load),
|
||||||
db: Session | None = Depends(Connection.get),
|
db: Session | None = Depends(Connection.get),
|
||||||
):
|
):
|
||||||
"""
|
|
||||||
POST ./authenticate: Authenticate a user. Issues a bearer token.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# fail if not installed
|
|
||||||
if current_config is None:
|
if current_config is None:
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
|
||||||
|
|
||||||
# try logging in
|
user = schemas.User.login(
|
||||||
user = User.authenticate(
|
|
||||||
db=db,
|
db=db,
|
||||||
name=form_data.username,
|
name=form_data.username,
|
||||||
password=form_data.password,
|
password=form_data.password,
|
||||||
crypt_context=await current_config.crypto.crypt_context,
|
crypt_context=await current_config.crypto.crypt_context,
|
||||||
)
|
)
|
||||||
|
|
||||||
# authentication failed
|
|
||||||
if user is None:
|
if user is None:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
|
@ -54,62 +38,48 @@ async def login(
|
||||||
headers={"WWW-Authenticate": "Bearer"},
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
)
|
)
|
||||||
|
|
||||||
# authentication succeeded
|
|
||||||
access_token = await current_config.jwt.create_token(user.name)
|
access_token = await current_config.jwt.create_token(user.name)
|
||||||
return {"access_token": access_token, "token_type": "bearer"}
|
return {"access_token": access_token, "token_type": "bearer"}
|
||||||
|
|
||||||
|
|
||||||
@router.get("/current", response_model=User)
|
@router.get("/current", response_model=schemas.User)
|
||||||
async def get_current_user(
|
async def get_current_user(
|
||||||
current_user: User | None = Depends(get_current_user),
|
current_user: schemas.User | None = Depends(_deps.get_current_user),
|
||||||
):
|
):
|
||||||
"""
|
|
||||||
GET ./current: Respond with the currently logged-in user.
|
|
||||||
"""
|
|
||||||
|
|
||||||
return current_user
|
return current_user
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
"/new",
|
"/new",
|
||||||
responses={
|
responses={
|
||||||
status.HTTP_200_OK: Responses.OK,
|
status.HTTP_200_OK: _deps.Responses.ok,
|
||||||
status.HTTP_400_BAD_REQUEST: Responses.NOT_INSTALLED,
|
status.HTTP_400_BAD_REQUEST: _deps.Responses.not_installed,
|
||||||
status.HTTP_401_UNAUTHORIZED: Responses.NEEDS_USER,
|
status.HTTP_401_UNAUTHORIZED: _deps.Responses.needs_user,
|
||||||
status.HTTP_403_FORBIDDEN: Responses.NEEDS_ADMIN,
|
status.HTTP_403_FORBIDDEN: _deps.Responses.needs_admin,
|
||||||
status.HTTP_409_CONFLICT: Responses.ENTRY_EXISTS,
|
status.HTTP_409_CONFLICT: _deps.Responses.entry_exists,
|
||||||
},
|
},
|
||||||
response_model=User,
|
response_model=schemas.User,
|
||||||
)
|
)
|
||||||
async def add_user(
|
async def add_user(
|
||||||
user: UserCreate,
|
user: schemas.UserCreate,
|
||||||
current_config: Config | None = Depends(Config.load),
|
current_config: Config | None = Depends(Config.load),
|
||||||
current_user: User | None = Depends(get_current_user),
|
current_user: schemas.User | None = Depends(_deps.get_current_user),
|
||||||
db: Session | None = Depends(Connection.get),
|
db: Session | None = Depends(Connection.get),
|
||||||
):
|
):
|
||||||
"""
|
|
||||||
POST ./new: Create a new user in the database.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# fail if not installed
|
|
||||||
if current_config is None:
|
if current_config is None:
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
|
||||||
|
|
||||||
# fail if not requested by an admin
|
|
||||||
if (current_user is None
|
if (current_user is None
|
||||||
or UserCapability.admin not in current_user.capabilities):
|
or schemas.UserCapability.admin not in current_user.capabilities):
|
||||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
|
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
|
||||||
|
|
||||||
# actually create the new user
|
new_user = schemas.User.create(
|
||||||
new_user = User.create(
|
|
||||||
db=db,
|
db=db,
|
||||||
user=user,
|
user=user,
|
||||||
crypt_context=await current_config.crypto.crypt_context,
|
crypt_context=await current_config.crypto.crypt_context,
|
||||||
)
|
)
|
||||||
|
|
||||||
# fail if creation was unsuccessful
|
|
||||||
if new_user is None:
|
if new_user is None:
|
||||||
raise HTTPException(status_code=status.HTTP_409_CONFLICT)
|
raise HTTPException(status_code=status.HTTP_409_CONFLICT)
|
||||||
|
|
||||||
# return the created user on success
|
|
||||||
return new_user
|
return new_user
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue