Merge pull request 'develop1' (#1) from Yavook.de/kiwi-vpn:develop into develop

Reviewed-on: penner/kiwi-vpn#1
This commit is contained in:
penner 2022-03-28 22:27:11 +00:00
commit d42ef089ff
18 changed files with 580 additions and 626 deletions

View file

@ -11,5 +11,6 @@
"editor.formatOnSave": true, "editor.formatOnSave": true,
"editor.codeActionsOnSave": { "editor.codeActionsOnSave": {
"source.organizeImports": true "source.organizeImports": true
} },
"git.closeDiffOnOperation": true
} }

View file

@ -20,8 +20,6 @@ from jose import JWTError, jwt
from jose.constants import ALGORITHMS from jose.constants import ALGORITHMS
from passlib.context import CryptContext from passlib.context import CryptContext
from pydantic import BaseModel, BaseSettings, Field, validator from pydantic import BaseModel, BaseSettings, Field, validator
from sqlalchemy import create_engine
from sqlalchemy.engine import Engine
class Settings(BaseSettings): class Settings(BaseSettings):
@ -31,18 +29,21 @@ class Settings(BaseSettings):
production_mode: bool = False production_mode: bool = False
data_dir: Path = Path("./tmp") data_dir: Path = Path("./tmp")
config_file_name: Path = Path("config.json")
api_v1_prefix: str = "api/v1"
openapi_url: str = "/openapi.json" openapi_url: str = "/openapi.json"
docs_url: str | None = "/docs" docs_url: str | None = "/docs"
redoc_url: str | None = "/redoc" redoc_url: str | None = "/redoc"
@staticmethod @classmethod
@property
@functools.lru_cache @functools.lru_cache
def get() -> Settings: def _(cls) -> Settings:
return Settings() return cls()
@property @property
def config_file(self) -> Path: def config_file(self) -> Path:
return self.data_dir.joinpath("config.json") return self.data_dir.joinpath(self.config_file_name)
class DBType(Enum): class DBType(Enum):
@ -63,23 +64,20 @@ class DBConfig(BaseModel):
user: str | None = None user: str | None = None
password: str | None = None password: str | None = None
host: str | None = None host: str | None = None
database: str | None = Settings.get().data_dir.joinpath("vpn.db") database: str | None = Settings._.data_dir.joinpath("vpn.db")
mysql_driver: str = "pymysql" mysql_driver: str = "pymysql"
mysql_args: list[str] = ["charset=utf8mb4"] mysql_args: list[str] = ["charset=utf8mb4"]
@property @property
async def db_engine(self) -> Engine: def uri(self) -> str:
""" """
Construct an SQLAlchemy engine Construct a database connection string
""" """
if self.type is DBType.sqlite: if self.type is DBType.sqlite:
# SQLite backend # SQLite backend
return create_engine( return f"sqlite:///{self.database}"
f"sqlite:///{self.database}",
connect_args={"check_same_thread": False},
)
elif self.type is DBType.mysql: elif self.type is DBType.mysql:
# MySQL backend # MySQL backend
@ -88,12 +86,9 @@ class DBConfig(BaseModel):
else: else:
args_str = "" args_str = ""
return create_engine( return (f"mysql+{self.mysql_driver}://"
f"mysql+{self.mysql_driver}://" f"{self.user}:{self.password}@{self.host}"
f"{self.user}:{self.password}@{self.host}" f"/{self.database}{args_str}")
f"/{self.database}{args_str}",
pool_recycle=3600,
)
class JWTConfig(BaseModel): class JWTConfig(BaseModel):
@ -181,7 +176,7 @@ class CryptoConfig(BaseModel):
schemes: list[str] = ["bcrypt"] schemes: list[str] = ["bcrypt"]
@property @property
async def crypt_context(self) -> CryptContext: def crypt_context(self) -> CryptContext:
return CryptContext( return CryptContext(
schemes=self.schemes, schemes=self.schemes,
deprecated="auto", deprecated="auto",
@ -197,23 +192,38 @@ class Config(BaseModel):
jwt: JWTConfig = Field(default_factory=JWTConfig) jwt: JWTConfig = Field(default_factory=JWTConfig)
crypto: CryptoConfig = Field(default_factory=CryptoConfig) crypto: CryptoConfig = Field(default_factory=CryptoConfig)
@staticmethod __singleton: Config | None = None
async def load() -> Config | None:
@classmethod
def load(cls) -> Config | None:
""" """
Load configuration from config file Load configuration from config file
""" """
if cls.__singleton is not None:
return cls.__singleton
try: try:
with open(Settings.get().config_file, "r") as config_file: with open(Settings._.config_file, "r") as config_file:
return Config.parse_obj(json.load(config_file)) cls.__singleton = Config.parse_obj(json.load(config_file))
return cls.__singleton
except FileNotFoundError: except FileNotFoundError:
return None return None
async def save(self) -> None: @classmethod
@property
def _(cls) -> Config | None:
"""
Shorthand for load()
"""
return cls.load()
def save(self) -> None:
""" """
Save configuration to config file Save configuration to config file
""" """
with open(Settings.get().config_file, "w") as config_file: with open(Settings._.config_file, "w") as config_file:
config_file.write(self.json(indent=2)) config_file.write(self.json(indent=2))

View file

@ -1,4 +1,20 @@
from . import models, schemas """
from .connection import Connection Package `db`: ORM and schemas for database content.
"""
__all__ = ["Connection", "models", "schemas"] from .connection import Connection
from .device import Device, DeviceBase, DeviceCreate
from .user import User, UserBase, UserCreate, UserRead
from .user_capability import UserCapabilityType
__all__ = [
"Connection",
"Device",
"DeviceBase",
"DeviceCreate",
"User",
"UserBase",
"UserCreate",
"UserRead",
"UserCapabilityType",
]

View file

@ -1,75 +1,34 @@
""" """
Utilities for handling SQLAlchemy database connections. Database connection management.
""" """
from typing import Generator from sqlmodel import Session, SQLModel, create_engine
from sqlalchemy.engine import Engine
from sqlalchemy.orm import Session, sessionmaker
from .models import ORMBaseModel
class SessionManager:
"""
Simple context manager for an ORM session.
"""
__session: Session
def __init__(self, session: Session) -> None:
self.__session = session
def __enter__(self) -> Session:
return self.__session
def __exit__(self, *args) -> None:
self.__session.close()
class Connection: class Connection:
""" """
Namespace for the database connection. Namespace for the database connection
""" """
engine: Engine | None = None engine = None
session_local: sessionmaker | None = None
@classmethod @classmethod
def connect(cls, engine: Engine) -> None: def connect(cls, connection_url: str) -> None:
""" """
Connect ORM to a database engine. Connect ORM to a database engine.
""" """
cls.engine = engine cls.engine = create_engine(connection_url)
cls.session_local = sessionmaker( SQLModel.metadata.create_all(cls.engine)
autocommit=False, autoflush=False, bind=engine,
)
ORMBaseModel.metadata.create_all(bind=engine)
@classmethod @classmethod
def use(cls) -> SessionManager | None: @property
def session(cls) -> Session | None:
""" """
Create an ORM session using a context manager. Create an ORM session using a context manager.
""" """
if cls.session_local is None: if cls.engine is None:
return None return None
return SessionManager(cls.session_local()) return Session(cls.engine)
@classmethod
async def get(cls) -> Generator[Session | None, None, None]:
"""
Create an ORM session using a FastAPI compatible async generator.
"""
if cls.session_local is None:
yield None
else:
db = cls.session_local()
try:
yield db
finally:
db.close()

View file

@ -0,0 +1,101 @@
"""
Python representation of `device` table.
"""
from __future__ import annotations
from datetime import datetime
from typing import TYPE_CHECKING
from sqlalchemy.exc import IntegrityError
from sqlmodel import Field, Relationship, SQLModel, UniqueConstraint
from .connection import Connection
if TYPE_CHECKING:
from .user import User
class DeviceBase(SQLModel):
"""
Common to all representations of devices
"""
name: str
type: str
expiry: datetime | None
class DeviceCreate(DeviceBase):
"""
Representation of a newly created device
"""
owner_name: str | None
class DeviceRead(DeviceBase):
"""
Representation of a device read via the API
"""
owner_name: str | None
class Device(DeviceBase, table=True):
"""
Representation of `device` table
"""
__table_args__ = (UniqueConstraint(
"owner_name",
"name",
),)
id: int | None = Field(primary_key=True)
owner_name: str | None = Field(foreign_key="user.name")
# no idea, but "User" (in quotes) doesn't work here
# might be a future problem?
owner: User = Relationship(
back_populates="devices",
)
@classmethod
def create(cls, **kwargs) -> Device | None:
"""
Create a new device in the database.
"""
try:
with Connection.session as db:
device = cls.from_orm(DeviceCreate(**kwargs))
db.add(device)
db.commit()
db.refresh(device)
return device
except IntegrityError:
# device already existed
return None
def update(self) -> None:
"""
Update this device in the database.
"""
with Connection.session as db:
db.add(self)
db.commit()
db.refresh(self)
def delete(self) -> bool:
"""
Delete this device from the database.
"""
with Connection.session as db:
db.delete(self)
db.commit()

View file

@ -1,106 +0,0 @@
"""
SQLAlchemy representation of database contents.
"""
from __future__ import annotations
import datetime
from sqlalchemy import (Boolean, Column, DateTime, ForeignKey, Integer, String,
UniqueConstraint)
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session, relationship
ORMBaseModel = declarative_base()
class User(ORMBaseModel):
__tablename__ = "users"
name = Column(String, primary_key=True, index=True)
password = Column(String, nullable=False)
capabilities: list[UserCapability] = relationship(
"UserCapability", lazy="joined", cascade="all, delete-orphan"
)
certificates: list[Certificate] = relationship(
"Certificate", lazy="select", back_populates="owner"
)
distinguished_names: list[DistinguishedName] = relationship(
"DistinguishedName", lazy="select", back_populates="owner"
)
@classmethod
def load(cls, db: Session, name: str) -> User | None:
"""
Load user from database by name.
"""
return (db
.query(User)
.filter(User.name == name)
.first())
class UserCapability(ORMBaseModel):
__tablename__ = "user_capabilities"
user_name = Column(
String,
ForeignKey("users.name"),
primary_key=True,
index=True,
)
capability = Column(String, primary_key=True)
class DistinguishedName(ORMBaseModel):
__tablename__ = "distinguished_names"
id = Column(Integer, primary_key=True, autoincrement=True)
owner_name = Column(String, ForeignKey("users.name"))
cn_only = Column(Boolean, default=True, nullable=False)
country = Column(String(2))
state = Column(String)
city = Column(String)
organization = Column(String)
organizational_unit = Column(String)
email = Column(String)
common_name = Column(String, nullable=False)
owner: User = relationship(
"User", lazy="joined", back_populates="distinguished_names"
)
UniqueConstraint(
country,
state,
city,
organization,
organizational_unit,
email,
common_name,
)
class Certificate(ORMBaseModel):
__tablename__ = "certificates"
id = Column(Integer, primary_key=True, autoincrement=True)
owner_name = Column(String, ForeignKey("users.name"))
dn_id = Column(
Integer,
ForeignKey("distinguished_names.id"),
nullable=False,
)
expiry = Column(DateTime, default=datetime.datetime.now)
distinguished_name: DistinguishedName = relationship(
"DistinguishedName", lazy="joined"
)
owner: User = relationship(
"User", lazy="joined", back_populates="certificates"
)

View file

@ -1,275 +0,0 @@
"""
Pydantic representation of database contents.
"""
from __future__ import annotations
from datetime import datetime
from enum import Enum
from typing import Any
from passlib.context import CryptContext
from pydantic import BaseModel, Field, constr, validator
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from . import models
##########
# table: distinguished_names
##########
class DistinguishedNameBase(BaseModel):
cn_only: bool
country: constr(max_length=2) | None
state: str | None
city: str | None
organization: str | None
organizational_unit: str | None
email: str | None
common_name: str
class DistinguishedNameCreate(DistinguishedNameBase):
pass
class DistinguishedName(DistinguishedNameBase):
class Config:
orm_mode = True
@classmethod
def create(
cls,
db: Session,
dn: DistinguishedNameCreate,
owner: User,
) -> User | None:
"""
Create a new distinguished name in the database.
"""
try:
db_owner = models.User.load(
db=db,
name=owner.name,
)
dn = models.DistinguishedName(
cn_only=dn.cn_only,
country=dn.country,
state=dn.state,
city=dn.city,
organization=dn.organization,
organizational_unit=dn.organizational_unit,
email=dn.email,
common_name=dn.common_name,
owner=db_owner,
)
db.add(dn)
db.commit()
db.refresh(dn)
return cls.from_orm(dn)
except IntegrityError:
# distinguished name already existed
pass
##########
# table: certificates
##########
class CertificateBase(BaseModel):
expiry: datetime
class CertificateCreate(CertificateBase):
pass
class Certificate(CertificateBase):
distinguished_name: DistinguishedName
class Config:
orm_mode = True
##########
# table: user_capabilities
##########
class UserCapability(Enum):
admin = "admin"
def __repr__(self) -> str:
return self.value
@classmethod
def from_value(cls, value) -> UserCapability:
"""
Create UserCapability from various formats
"""
if isinstance(value, cls):
# value is already a UserCapability, use that
return value
elif isinstance(value, models.UserCapability):
# create from db format
return cls(value.capability)
else:
# create from string representation
return cls(str(value))
##########
# table: users
##########
class UserBase(BaseModel):
name: str
class UserCreate(UserBase):
password: str
class User(UserBase):
capabilities: list[UserCapability] = []
distinguished_names: list[DistinguishedName] = Field(
default=[], repr=False
)
certificates: list[Certificate] = Field(
default=[], repr=False
)
class Config:
orm_mode = True
@validator("capabilities", pre=True)
@classmethod
def unify_capabilities(cls, value: list[Any]) -> list[UserCapability]:
"""
Import the capabilities from various formats
"""
return [
UserCapability.from_value(capability)
for capability in value
]
@classmethod
def from_db(
cls,
db: Session,
name: str,
) -> User | None:
"""
Load user from database by name.
"""
if (db_user := models.User.load(db, name)) is None:
return None
return cls.from_orm(db_user)
@classmethod
def create(
cls,
db: Session,
user: UserCreate,
crypt_context: CryptContext,
) -> User | None:
"""
Create a new user in the database.
"""
try:
user = models.User(
name=user.name,
password=crypt_context.hash(user.password),
capabilities=[],
)
db.add(user)
db.commit()
db.refresh(user)
return cls.from_orm(user)
except IntegrityError:
# user already existed
pass
def is_admin(self) -> bool:
return UserCapability.admin in self.capabilities
def authenticate(
self,
db: Session,
password: str,
crypt_context: CryptContext,
) -> User | None:
"""
Authenticate with name/password against users in database.
"""
if (db_user := models.User.load(db, self.name)) is None:
# nonexistent user, fake doing password verification
crypt_context.dummy_verify()
return False
if not crypt_context.verify(password, db_user.password):
# password hash mismatch
return False
self.from_orm(db_user)
return True
def update(
self,
db: Session,
) -> None:
"""
Update this user in the database.
"""
old_dbuser = models.User.load(db, self.name)
old_user = self.from_orm(old_dbuser)
for capability in self.capabilities:
if capability not in old_user.capabilities:
old_dbuser.capabilities.append(
models.UserCapability(capability=capability.value)
)
for capability in old_dbuser.capabilities:
if UserCapability.from_value(capability) not in self.capabilities:
db.delete(capability)
db.commit()
def delete(
self,
db: Session,
) -> bool:
"""
Delete this user from the database.
"""
if (db_user := models.User.load(db, self.name)) is None:
# nonexistent user
return False
db.delete(db_user)
db.commit()
return True

199
api/kiwi_vpn_api/db/user.py Normal file
View file

@ -0,0 +1,199 @@
"""
Python representation of `user` table.
"""
from __future__ import annotations
from typing import Any, Sequence
from pydantic import root_validator
from sqlalchemy.exc import IntegrityError
from sqlmodel import Field, Relationship, SQLModel
from ..config import Config
from .connection import Connection
from .device import Device
from .user_capability import UserCapability, UserCapabilityType
class UserBase(SQLModel):
"""
Common to all representations of users
"""
name: str = Field(primary_key=True)
email: str | None = Field(default=None)
country: str | None = Field(default=None)
state: str | None = Field(default=None)
city: str | None = Field(default=None)
organization: str | None = Field(default=None)
organizational_unit: str | None = Field(default=None)
class UserCreate(UserBase):
"""
Representation of a newly created user
"""
password: str | None = Field(default=None)
password_clear: str | None = Field(default=None)
@root_validator
@classmethod
def hash_password(cls, values: dict[str, Any]) -> dict[str, Any]:
"""
Ensure the `password` value of this user gets set.
"""
if (values.get("password")) is not None:
# password is set
return values
if (password_clear := values.get("password_clear")) is None:
raise ValueError("No password to hash")
if (current_config := Config._) is None:
raise ValueError("Not configured")
values["password"] = current_config.crypto.crypt_context.hash(
password_clear)
return values
class UserRead(UserBase):
"""
Representation of a user read via the API
"""
pass
class User(UserBase, table=True):
"""
Representation of `user` table
"""
password: str
capabilities: list[UserCapability] = Relationship(
back_populates="user",
sa_relationship_kwargs={
"lazy": "joined",
"cascade": "all, delete-orphan",
},
)
devices: list[Device] = Relationship(
back_populates="owner",
)
@classmethod
def create(cls, **kwargs) -> User | None:
"""
Create a new user in the database.
"""
try:
with Connection.session as db:
user = cls.from_orm(UserCreate(**kwargs))
db.add(user)
db.commit()
db.refresh(user)
return user
except IntegrityError:
# user already existed
return None
@classmethod
def get(cls, name: str) -> User | None:
"""
Load user from database by name.
"""
with Connection.session as db:
return db.get(cls, name)
@classmethod
def authenticate(
cls,
name: str,
password: str,
) -> User | None:
"""
Authenticate with name/password against users in database.
"""
crypt_context = Config._.crypto.crypt_context
if (user := cls.get(name)) is None:
# nonexistent user, fake doing password verification
crypt_context.dummy_verify()
return None
if not crypt_context.verify(password, user.password):
# password hash mismatch
return None
return user
def update(self) -> None:
"""
Update this user in the database.
"""
with Connection.session as db:
db.add(self)
db.commit()
db.refresh(self)
def delete(self) -> None:
"""
Delete this user from the database.
"""
with Connection.session as db:
db.delete(self)
db.commit()
def get_capabilities(self) -> set[UserCapabilityType]:
"""
Return the capabilities of this user.
"""
return set(
capability._
for capability in self.capabilities
)
def can(
self,
capability: UserCapabilityType,
) -> bool:
"""
Check if this user has a capability.
"""
return (
capability in self.get_capabilities()
or UserCapabilityType.admin in self.get_capabilities()
)
def set_capabilities(
self,
capabilities: Sequence[UserCapabilityType],
) -> None:
"""
Change the capabilities of this user.
"""
self.capabilities = [
UserCapability(
user_name=self.name,
capability_name=capability.value,
) for capability in capabilities
]

View file

@ -0,0 +1,56 @@
"""
Python representation of `user_capability` table.
"""
from enum import Enum
from typing import TYPE_CHECKING
from sqlmodel import Field, Relationship, SQLModel
if TYPE_CHECKING:
from .user import User
class UserCapabilityType(Enum):
"""
Allowed values for capabilities
"""
admin = "admin"
login = "login"
issue = "issue"
renew = "renew"
def __repr__(self) -> str:
return self.value
class UserCapabilityBase(SQLModel):
"""
Common to all representations of capabilities
"""
capability_name: str = Field(primary_key=True)
@property
def _(self) -> UserCapabilityType:
"""
Transform into a `Capability`.
"""
return UserCapabilityType(self.capability_name)
def __repr__(self) -> str:
return self.capability_name
class UserCapability(UserCapabilityBase, table=True):
"""
Representation of `user_capability` table
"""
user_name: str = Field(primary_key=True, foreign_key="user.name")
user: "User" = Relationship(
back_populates="capabilities",
)

View file

@ -13,13 +13,9 @@ import uvicorn
from fastapi import FastAPI from fastapi import FastAPI
from .config import Config, Settings from .config import Config, Settings
from .db import Connection from .db import Connection, User
from .db.schemas import User
from .routers import main_router from .routers import main_router
settings = Settings.get()
app = FastAPI( app = FastAPI(
title="kiwi-vpn API", title="kiwi-vpn API",
description="This API enables the `kiwi-vpn` service.", description="This API enables the `kiwi-vpn` service.",
@ -31,30 +27,29 @@ app = FastAPI(
"name": "MIT License", "name": "MIT License",
"url": "https://opensource.org/licenses/mit-license.php", "url": "https://opensource.org/licenses/mit-license.php",
}, },
openapi_url=settings.openapi_url, openapi_url=Settings._.openapi_url,
docs_url=settings.docs_url if not settings.production_mode else None, docs_url=Settings._.docs_url if not Settings._.production_mode else None,
redoc_url=settings.redoc_url if not settings.production_mode else None, redoc_url=Settings._.redoc_url if not Settings._.production_mode else None,
) )
app.include_router(main_router) app.include_router(main_router, prefix=f"/{Settings._.api_v1_prefix}")
@app.on_event("startup") @app.on_event("startup")
async def on_startup() -> None: async def on_startup() -> None:
# check if configured # check if configured
if (current_config := await Config.load()) is not None: if (current_config := Config._) is not None:
# connect to database # connect to database
Connection.connect(await current_config.db.db_engine) Connection.connect(current_config.db.uri)
# some testing # some testing
with Connection.use() as db: print(User.get("admin"))
print(User.from_db(db, "admin")) print(User.get("nonexistent"))
print(User.from_db(db, "nonexistent"))
def main() -> None: def main() -> None:
uvicorn.run( uvicorn.run(
"kiwi_vpn_api.main:app", app="kiwi_vpn_api.main:app",
host="0.0.0.0", host="0.0.0.0",
port=8000, port=8000,
reload=True, reload=True,

View file

@ -1,10 +1,18 @@
"""
Package `routers`: Each module contains the path operations for their prefixes.
This file: Main API router definition.
"""
from fastapi import APIRouter from fastapi import APIRouter
from . import admin, user from . import admin, user
main_router = APIRouter(prefix="/api/v1") main_router = APIRouter()
main_router.include_router(admin.router) main_router.include_router(admin.router)
main_router.include_router(user.router) main_router.include_router(user.router)
__all__ = ["main_router"] __all__ = [
"main_router",
]

View file

@ -5,13 +5,13 @@ Common dependencies for routers.
from fastapi import Depends, HTTPException, status from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer from fastapi.security import OAuth2PasswordBearer
from sqlalchemy.orm import Session
from ..config import Config from ..config import Config, Settings
from ..db import Connection from ..db import User, UserCapabilityType
from ..db.schemas import User
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="user/authenticate") oauth2_scheme = OAuth2PasswordBearer(
tokenUrl=f"{Settings._.api_v1_prefix}/user/authenticate"
)
class Responses: class Responses:
@ -56,7 +56,6 @@ class Responses:
async def get_current_user( async def get_current_user(
token: str = Depends(oauth2_scheme), token: str = Depends(oauth2_scheme),
db: Session | None = Depends(Connection.get),
current_config: Config | None = Depends(Config.load), current_config: Config | None = Depends(Config.load),
) -> User | None: ) -> User | None:
""" """
@ -68,13 +67,11 @@ async def get_current_user(
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
username = await current_config.jwt.decode_token(token) username = await current_config.jwt.decode_token(token)
user = User.from_db(db, username)
return user return User.get(username)
async def get_current_user_if_exists( async def get_current_user_if_exists(
current_config: Config | None = Depends(Config.load),
current_user: User | None = Depends(get_current_user), current_user: User | None = Depends(get_current_user),
) -> User: ) -> User:
""" """
@ -89,7 +86,6 @@ async def get_current_user_if_exists(
async def get_current_user_if_admin( async def get_current_user_if_admin(
current_config: Config | None = Depends(Config.load),
current_user: User = Depends(get_current_user_if_exists), current_user: User = Depends(get_current_user_if_exists),
) -> User: ) -> User:
""" """
@ -97,7 +93,7 @@ async def get_current_user_if_admin(
""" """
# fail if not requested by an admin # fail if not requested by an admin
if not current_user.is_admin(): if not current_user.can(UserCapabilityType.admin):
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
return current_user return current_user
@ -105,7 +101,6 @@ async def get_current_user_if_admin(
async def get_current_user_if_admin_or_self( async def get_current_user_if_admin_or_self(
user_name: str, user_name: str,
current_config: Config | None = Depends(Config.load),
current_user: User = Depends(get_current_user_if_exists), current_user: User = Depends(get_current_user_if_exists),
) -> User: ) -> User:
""" """
@ -116,7 +111,8 @@ async def get_current_user_if_admin_or_self(
""" """
# fail if not requested by an admin or self # fail if not requested by an admin or self
if not (current_user.is_admin() or current_user.name == user_name): if not (current_user.can(UserCapabilityType.admin)
or current_user.name == user_name):
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
return current_user return current_user

View file

@ -4,49 +4,68 @@
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException, status
from sqlmodel import select
from ..config import Config from ..config import Config
from ..db import Connection from ..db import Connection, User, UserCapabilityType, UserCreate
from ..db.schemas import User, UserCapability, UserCreate from ._common import Responses, get_current_user_if_admin
from ._common import Responses, get_current_user
router = APIRouter(prefix="/admin", tags=["admin"]) router = APIRouter(prefix="/admin", tags=["admin"])
@router.put( @router.put(
"/install", "/install/config",
responses={ responses={
status.HTTP_200_OK: Responses.OK, status.HTTP_200_OK: Responses.OK,
status.HTTP_400_BAD_REQUEST: Responses.INSTALLED, status.HTTP_400_BAD_REQUEST: Responses.INSTALLED,
}, },
) )
async def install( async def initial_configure(
config: Config, config: Config,
admin_user: UserCreate,
current_config: Config | None = Depends(Config.load), current_config: Config | None = Depends(Config.load),
): ):
""" """
PUT ./install: Install `kiwi-vpn`. PUT ./install/config: Configure `kiwi-vpn`.
""" """
# fail if already installed # fail if already configured
if current_config is not None: if current_config is not None:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
# create config file, connect to database # create config file, connect to database
await config.save() config.save()
Connection.connect(await config.db.db_engine) Connection.connect(current_config.db.uri)
@router.put(
"/install/admin",
responses={
status.HTTP_200_OK: Responses.OK,
status.HTTP_400_BAD_REQUEST: Responses.NOT_INSTALLED,
status.HTTP_409_CONFLICT: Responses.ENTRY_EXISTS,
},
)
async def create_initial_admin(
admin_user: UserCreate,
current_config: Config | None = Depends(Config.load),
):
"""
PUT ./install/admin: Create the first administrative user.
"""
# fail if not configured
if current_config is None:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
# fail if any user exists
with Connection.session as db:
if db.exec(select(User).limit(1)).first() is not None:
raise HTTPException(status_code=status.HTTP_409_CONFLICT)
# create an administrative user # create an administrative user
with Connection.use() as db: new_user = User.create(**admin_user.dict())
new_user = User.create( new_user.set_capabilities([UserCapabilityType.admin])
db=db, new_user.update()
user=admin_user,
crypt_context=await config.crypto.crypt_context,
)
new_user.capabilities.append(UserCapability.admin)
new_user.update(db)
@router.put( @router.put(
@ -59,23 +78,13 @@ async def install(
}, },
) )
async def set_config( async def set_config(
new_config: Config, config: Config,
current_config: Config | None = Depends(Config.load), _: User = Depends(get_current_user_if_admin),
current_user: User | None = Depends(get_current_user),
): ):
""" """
PUT ./config: Edit `kiwi-vpn` main config. PUT ./config: Edit `kiwi-vpn` main config.
""" """
# fail if not installed
if current_config is None:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
# fail if not requested by an admin
if (current_user is None
or UserCapability.admin not in current_user.capabilities):
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
# update config file, reconnect to database # update config file, reconnect to database
await new_config.save() config.save()
Connection.connect(await new_config.db.db_engine) Connection.connect(config.db.uri)

View file

@ -1,58 +0,0 @@
"""
/dn endpoints.
"""
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from ..db import Connection
from ..db.schemas import DistinguishedName, DistinguishedNameCreate, User
from ._common import Responses, get_current_user_if_admin_or_self
router = APIRouter(prefix="/dn")
@router.post(
"",
responses={
status.HTTP_200_OK: Responses.OK,
status.HTTP_400_BAD_REQUEST: Responses.NOT_INSTALLED,
status.HTTP_401_UNAUTHORIZED: Responses.NEEDS_USER,
status.HTTP_403_FORBIDDEN: Responses.NEEDS_ADMIN,
status.HTTP_404_NOT_FOUND: Responses.ENTRY_DOESNT_EXIST,
status.HTTP_409_CONFLICT: Responses.ENTRY_EXISTS,
},
)
async def add_distinguished_name(
user_name: str,
distinguished_name: DistinguishedNameCreate,
_: User = Depends(get_current_user_if_admin_or_self),
db: Session | None = Depends(Connection.get),
):
"""
POST ./: Create a new distinguished name in the database.
"""
owner = User.from_db(
db=db,
name=user_name,
)
# fail if owner doesn't exist
if owner is None:
raise HTTPException(status_code=status.HTTP_409_CONFLICT)
# actually create the new user
new_dn = DistinguishedName.create(
db=db,
dn=distinguished_name,
owner=owner,
)
# fail if creation was unsuccessful
if new_dn is None:
raise HTTPException(status_code=status.HTTP_409_CONFLICT)
# return the created user on success
return new_dn

View file

@ -5,11 +5,9 @@
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordRequestForm from fastapi.security import OAuth2PasswordRequestForm
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy.orm import Session
from ..config import Config from ..config import Config
from ..db import Connection from ..db import User, UserCapabilityType, UserCreate, UserRead
from ..db.schemas import User, UserCapability, UserCreate
from ._common import Responses, get_current_user, get_current_user_if_admin from ._common import Responses, get_current_user, get_current_user_if_admin
router = APIRouter(prefix="/user", tags=["user"]) router = APIRouter(prefix="/user", tags=["user"])
@ -28,7 +26,6 @@ class Token(BaseModel):
async def login( async def login(
form_data: OAuth2PasswordRequestForm = Depends(), form_data: OAuth2PasswordRequestForm = Depends(),
current_config: Config | None = Depends(Config.load), current_config: Config | None = Depends(Config.load),
db: Session | None = Depends(Connection.get),
): ):
""" """
POST ./authenticate: Authenticate a user. Issues a bearer token. POST ./authenticate: Authenticate a user. Issues a bearer token.
@ -39,12 +36,10 @@ async def login(
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
# try logging in # try logging in
user = User(name=form_data.username) if not (user := User.authenticate(
if not user.authenticate( name=form_data.username,
db=db,
password=form_data.password, password=form_data.password,
crypt_context=await current_config.crypto.crypt_context, )):
):
# authentication failed # authentication failed
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
@ -52,12 +47,16 @@ async def login(
headers={"WWW-Authenticate": "Bearer"}, headers={"WWW-Authenticate": "Bearer"},
) )
if not user.can(UserCapabilityType.login):
# user cannot login
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
# authentication succeeded # authentication succeeded
access_token = await current_config.jwt.create_token(user.name) access_token = await current_config.jwt.create_token(user.name)
return {"access_token": access_token, "token_type": "bearer"} return {"access_token": access_token, "token_type": "bearer"}
@router.get("/current", response_model=User) @router.get("/current", response_model=UserRead)
async def get_current_user( async def get_current_user(
current_user: User | None = Depends(get_current_user), current_user: User | None = Depends(get_current_user),
): ):
@ -81,20 +80,15 @@ async def get_current_user(
) )
async def add_user( async def add_user(
user: UserCreate, user: UserCreate,
current_config: Config | None = Depends(Config.load),
_: User = Depends(get_current_user_if_admin), _: User = Depends(get_current_user_if_admin),
db: Session | None = Depends(Connection.get),
): ):
""" """
POST ./: Create a new user in the database. POST ./: Create a new user in the database.
""" """
# actually create the new user # actually create the new user
new_user = User.create( new_user = User.create(**user.dict())
db=db, new_user.set_capabilities([UserCapabilityType.login])
user=user,
crypt_context=await current_config.crypto.crypt_context,
)
# fail if creation was unsuccessful # fail if creation was unsuccessful
if new_user is None: if new_user is None:
@ -118,22 +112,21 @@ async def add_user(
async def remove_user( async def remove_user(
user_name: str, user_name: str,
_: User = Depends(get_current_user_if_admin), _: User = Depends(get_current_user_if_admin),
db: Session | None = Depends(Connection.get),
): ):
""" """
DELETE ./{user_name}: Remove a user from the database. DELETE ./{user_name}: Remove a user from the database.
""" """
# get the user # get the user
user = User.from_db( user = User.get(user_name)
db=db,
name=user_name,
)
# fail if deletion was unsuccessful # fail if user not found
if user is None or not user.delete(db): if user is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
# delete user
user.delete()
@router.post( @router.post(
"/{user_name}/capabilities", "/{user_name}/capabilities",
@ -146,22 +139,19 @@ async def remove_user(
) )
async def extend_capabilities( async def extend_capabilities(
user_name: str, user_name: str,
capabilities: list[UserCapability], capabilities: list[UserCapabilityType],
_: User = Depends(get_current_user_if_admin), _: User = Depends(get_current_user_if_admin),
db: Session | None = Depends(Connection.get),
): ):
""" """
POST ./{user_name}/capabilities: Add capabilities to a user. POST ./{user_name}/capabilities: Add capabilities to a user.
""" """
# get and change the user # get and change the user
user = User.from_db( user = User.get(user_name)
db=db,
name=user_name,
)
user.capabilities.extend(capabilities) user.set_capabilities(user.get_capabilities() | set(capabilities))
user.update(db)
user.update()
@router.delete( @router.delete(
@ -175,21 +165,16 @@ async def extend_capabilities(
) )
async def remove_capabilities( async def remove_capabilities(
user_name: str, user_name: str,
capabilities: list[UserCapability], capabilities: list[UserCapabilityType],
_: User = Depends(get_current_user_if_admin), _: User = Depends(get_current_user_if_admin),
db: Session | None = Depends(Connection.get),
): ):
""" """
DELETE ./{user_name}/capabilities: Remove capabilities from a user. DELETE ./{user_name}/capabilities: Remove capabilities from a user.
""" """
# get and change the user # get and change the user
user = User.from_db( user = User.get(user_name)
db=db,
name=user_name,
)
for capability in capabilities: user.set_capabilities(user.get_capabilities() - set(capabilities))
user.capabilities.remove(capability)
user.update(db) user.update()

25
api/plan.md Normal file
View file

@ -0,0 +1,25 @@
## Server props
- default DN parts: country, state, city, org, OU
- "customizable" flags for DN parts
- flag: use client-to-client
- force cipher, tls-cipher, auth params
- server name
- default certification length
- default certificate algo
## User props
- username
- password
- custom DN parts: country, state, city, org, OU
- email
## User caps
- admin: administrator
- login: can log into the web interface
- issue: can certify own devices without approval
- renew: can renew certificates for own devices
## Device props
- name
- type (icon)
- expiry

34
api/poetry.lock generated
View file

@ -428,6 +428,30 @@ postgresql_psycopg2cffi = ["psycopg2cffi"]
pymysql = ["pymysql (<1)", "pymysql"] pymysql = ["pymysql (<1)", "pymysql"]
sqlcipher = ["sqlcipher3-binary"] sqlcipher = ["sqlcipher3-binary"]
[[package]]
name = "sqlalchemy2-stubs"
version = "0.0.2a21"
description = "Typing Stubs for SQLAlchemy 1.4"
category = "main"
optional = false
python-versions = ">=3.6"
[package.dependencies]
typing-extensions = ">=3.7.4"
[[package]]
name = "sqlmodel"
version = "0.0.6"
description = "SQLModel, SQL databases in Python, designed for simplicity, compatibility, and robustness."
category = "main"
optional = false
python-versions = ">=3.6.1,<4.0.0"
[package.dependencies]
pydantic = ">=1.8.2,<2.0.0"
SQLAlchemy = ">=1.4.17,<1.5.0"
sqlalchemy2-stubs = "*"
[[package]] [[package]]
name = "starlette" name = "starlette"
version = "0.17.1" version = "0.17.1"
@ -477,7 +501,7 @@ standard = ["websockets (>=10.0)", "httptools (>=0.4.0)", "watchgod (>=0.6)", "p
[metadata] [metadata]
lock-version = "1.1" lock-version = "1.1"
python-versions = "^3.10" python-versions = "^3.10"
content-hash = "432d2933102f8a0091cec1b5484944a0211ca74c5dc9b65877d99d7bd160e4bb" content-hash = "a580f9fe4c68667e4cbdf385ac11d5c7a2925e3c990b7164faa922ec8b6f9555"
[metadata.files] [metadata.files]
anyio = [ anyio = [
@ -834,6 +858,14 @@ sqlalchemy = [
{file = "SQLAlchemy-1.4.32-cp39-cp39-win_amd64.whl", hash = "sha256:b3f1d9b3aa09ab9adc7f8c4b40fc3e081eb903054c9a6f9ae1633fe15ae503b4"}, {file = "SQLAlchemy-1.4.32-cp39-cp39-win_amd64.whl", hash = "sha256:b3f1d9b3aa09ab9adc7f8c4b40fc3e081eb903054c9a6f9ae1633fe15ae503b4"},
{file = "SQLAlchemy-1.4.32.tar.gz", hash = "sha256:6fdd2dc5931daab778c2b65b03df6ae68376e028a3098eb624d0909d999885bc"}, {file = "SQLAlchemy-1.4.32.tar.gz", hash = "sha256:6fdd2dc5931daab778c2b65b03df6ae68376e028a3098eb624d0909d999885bc"},
] ]
sqlalchemy2-stubs = [
{file = "sqlalchemy2-stubs-0.0.2a21.tar.gz", hash = "sha256:207e3d8a36fc032d325f4eec89e0c6760efe81d07e978513d8c9b14f108dcd0c"},
{file = "sqlalchemy2_stubs-0.0.2a21-py3-none-any.whl", hash = "sha256:bd4a3d5ca7ff9d01b2245e1b26304d6b2ec4daf43a01faf40db9e09245679433"},
]
sqlmodel = [
{file = "sqlmodel-0.0.6-py3-none-any.whl", hash = "sha256:c5fd8719e09da348cd32ce2a5b6a44f289d3029fa8f1c9818229b6f34f1201b4"},
{file = "sqlmodel-0.0.6.tar.gz", hash = "sha256:3b4f966b9671b24d85529d274e6c4dbc7753b468e35d2d6a40bd75cad1f66813"},
]
starlette = [ starlette = [
{file = "starlette-0.17.1-py3-none-any.whl", hash = "sha256:26a18cbda5e6b651c964c12c88b36d9898481cd428ed6e063f5f29c418f73050"}, {file = "starlette-0.17.1-py3-none-any.whl", hash = "sha256:26a18cbda5e6b651c964c12c88b36d9898481cd428ed6e063f5f29c418f73050"},
{file = "starlette-0.17.1.tar.gz", hash = "sha256:57eab3cc975a28af62f6faec94d355a410634940f10b30d68d31cb5ec1b44ae8"}, {file = "starlette-0.17.1.tar.gz", hash = "sha256:57eab3cc975a28af62f6faec94d355a410634940f10b30d68d31cb5ec1b44ae8"},

View file

@ -13,6 +13,7 @@ python-jose = {extras = ["cryptography"], version = "^3.3.0"}
passlib = {extras = ["argon2", "bcrypt"], version = "^1.7.4"} passlib = {extras = ["argon2", "bcrypt"], version = "^1.7.4"}
SQLAlchemy = "^1.4.32" SQLAlchemy = "^1.4.32"
pyOpenSSL = "^22.0.0" pyOpenSSL = "^22.0.0"
sqlmodel = "^0.0.6"
[tool.poetry.dev-dependencies] [tool.poetry.dev-dependencies]
pytest = "^7.1.0" pytest = "^7.1.0"