Compare commits

..

No commits in common. "5b623e885c38f338df7a7b3bc1bb402d866e5a80" and "71ac02e5d770ebebefbd0e79e260dcb66432ec70" have entirely different histories.

17 changed files with 622 additions and 559 deletions

View file

@ -20,6 +20,8 @@ from jose import JWTError, jwt
from jose.constants import ALGORITHMS from jose.constants import ALGORITHMS
from passlib.context import CryptContext from passlib.context import CryptContext
from pydantic import BaseModel, BaseSettings, Field, validator from pydantic import BaseModel, BaseSettings, Field, validator
from sqlalchemy import create_engine
from sqlalchemy.engine import Engine
class Settings(BaseSettings): class Settings(BaseSettings):
@ -29,21 +31,18 @@ class Settings(BaseSettings):
production_mode: bool = False production_mode: bool = False
data_dir: Path = Path("./tmp") data_dir: Path = Path("./tmp")
config_file_name: Path = Path("config.json")
api_v1_prefix: str = "api/v1"
openapi_url: str = "/openapi.json" openapi_url: str = "/openapi.json"
docs_url: str | None = "/docs" docs_url: str | None = "/docs"
redoc_url: str | None = "/redoc" redoc_url: str | None = "/redoc"
@classmethod @staticmethod
@property
@functools.lru_cache @functools.lru_cache
def _(cls) -> Settings: def get() -> Settings:
return cls() return Settings()
@property @property
def config_file(self) -> Path: def config_file(self) -> Path:
return self.data_dir.joinpath(self.config_file_name) return self.data_dir.joinpath("config.json")
class DBType(Enum): class DBType(Enum):
@ -64,20 +63,23 @@ class DBConfig(BaseModel):
user: str | None = None user: str | None = None
password: str | None = None password: str | None = None
host: str | None = None host: str | None = None
database: str | None = Settings._.data_dir.joinpath("vpn.db") database: str | None = Settings.get().data_dir.joinpath("vpn.db")
mysql_driver: str = "pymysql" mysql_driver: str = "pymysql"
mysql_args: list[str] = ["charset=utf8mb4"] mysql_args: list[str] = ["charset=utf8mb4"]
@property @property
def uri(self) -> str: async def db_engine(self) -> Engine:
""" """
Construct a database connection string Construct an SQLAlchemy engine
""" """
if self.type is DBType.sqlite: if self.type is DBType.sqlite:
# SQLite backend # SQLite backend
return f"sqlite:///{self.database}" return create_engine(
f"sqlite:///{self.database}",
connect_args={"check_same_thread": False},
)
elif self.type is DBType.mysql: elif self.type is DBType.mysql:
# MySQL backend # MySQL backend
@ -86,9 +88,12 @@ class DBConfig(BaseModel):
else: else:
args_str = "" args_str = ""
return (f"mysql+{self.mysql_driver}://" return create_engine(
f"mysql+{self.mysql_driver}://"
f"{self.user}:{self.password}@{self.host}" f"{self.user}:{self.password}@{self.host}"
f"/{self.database}{args_str}") f"/{self.database}{args_str}",
pool_recycle=3600,
)
class JWTConfig(BaseModel): class JWTConfig(BaseModel):
@ -176,7 +181,7 @@ class CryptoConfig(BaseModel):
schemes: list[str] = ["bcrypt"] schemes: list[str] = ["bcrypt"]
@property @property
def crypt_context(self) -> CryptContext: async def crypt_context(self) -> CryptContext:
return CryptContext( return CryptContext(
schemes=self.schemes, schemes=self.schemes,
deprecated="auto", deprecated="auto",
@ -192,38 +197,23 @@ class Config(BaseModel):
jwt: JWTConfig = Field(default_factory=JWTConfig) jwt: JWTConfig = Field(default_factory=JWTConfig)
crypto: CryptoConfig = Field(default_factory=CryptoConfig) crypto: CryptoConfig = Field(default_factory=CryptoConfig)
__singleton: Config | None = None @staticmethod
async def load() -> Config | None:
@classmethod
def load(cls) -> Config | None:
""" """
Load configuration from config file Load configuration from config file
""" """
if cls.__singleton is not None:
return cls.__singleton
try: try:
with open(Settings._.config_file, "r") as config_file: with open(Settings.get().config_file, "r") as config_file:
cls.__singleton = Config.parse_obj(json.load(config_file)) return Config.parse_obj(json.load(config_file))
return cls.__singleton
except FileNotFoundError: except FileNotFoundError:
return None return None
@classmethod async def save(self) -> None:
@property
def _(cls) -> Config | None:
"""
Shorthand for load()
"""
return cls.load()
def save(self) -> None:
""" """
Save configuration to config file Save configuration to config file
""" """
with open(Settings._.config_file, "w") as config_file: with open(Settings.get().config_file, "w") as config_file:
config_file.write(self.json(indent=2)) config_file.write(self.json(indent=2))

View file

@ -1,11 +1,4 @@
""" from . import models, schemas
Package `db`: ORM and schemas for database content.
"""
from .connection import Connection from .connection import Connection
from .device import Device, DeviceBase, DeviceCreate
from .user import User, UserBase, UserCreate, UserRead
from .user_capability import Capability
__all__ = ["Capability", "Connection", "Device", "DeviceBase", "DeviceCreate", __all__ = ["Connection", "models", "schemas"]
"User", "UserBase", "UserCreate", "UserRead"]

View file

@ -1,34 +1,75 @@
""" """
Database connection management. Utilities for handling SQLAlchemy database connections.
""" """
from sqlmodel import Session, SQLModel, create_engine from typing import Generator
from sqlalchemy.engine import Engine
from sqlalchemy.orm import Session, sessionmaker
from .models import ORMBaseModel
class SessionManager:
"""
Simple context manager for an ORM session.
"""
__session: Session
def __init__(self, session: Session) -> None:
self.__session = session
def __enter__(self) -> Session:
return self.__session
def __exit__(self, *args) -> None:
self.__session.close()
class Connection: class Connection:
""" """
Namespace for the database connection Namespace for the database connection.
""" """
engine = None engine: Engine | None = None
session_local: sessionmaker | None = None
@classmethod @classmethod
def connect(cls, connection_url: str) -> None: def connect(cls, engine: Engine) -> None:
""" """
Connect ORM to a database engine. Connect ORM to a database engine.
""" """
cls.engine = create_engine(connection_url) cls.engine = engine
SQLModel.metadata.create_all(cls.engine) cls.session_local = sessionmaker(
autocommit=False, autoflush=False, bind=engine,
)
ORMBaseModel.metadata.create_all(bind=engine)
@classmethod @classmethod
@property def use(cls) -> SessionManager | None:
def session(cls) -> Session | None:
""" """
Create an ORM session using a context manager. Create an ORM session using a context manager.
""" """
if cls.engine is None: if cls.session_local is None:
return None return None
return Session(cls.engine) return SessionManager(cls.session_local())
@classmethod
async def get(cls) -> Generator[Session | None, None, None]:
"""
Create an ORM session using a FastAPI compatible async generator.
"""
if cls.session_local is None:
yield None
else:
db = cls.session_local()
try:
yield db
finally:
db.close()

View file

@ -1,103 +0,0 @@
"""
Python representation of `devices` table.
"""
from __future__ import annotations
from datetime import datetime
from typing import TYPE_CHECKING
from sqlalchemy.exc import IntegrityError
from sqlmodel import Field, Relationship, SQLModel, UniqueConstraint
from .connection import Connection
if TYPE_CHECKING:
from .user import User
class DeviceBase(SQLModel):
"""
Common to all representations of devices
"""
name: str
type: str
expiry: datetime | None
class DeviceCreate(DeviceBase):
"""
Representation of a newly created device
"""
owner_name: str | None
class DeviceRead(DeviceBase):
"""
Representation of a device read via the API
"""
owner_name: str | None
class Device(DeviceBase, table=True):
"""
Representation of `devices` table
"""
__tablename__ = "devices"
__table_args__ = (UniqueConstraint(
"owner_name",
"name",
),)
id: int | None = Field(primary_key=True)
owner_name: str | None = Field(foreign_key="users.name")
# no idea, but "User" (in quotes) doesn't work here
# might be a future problem?
owner: User = Relationship(
back_populates="devices",
)
@classmethod
def create(cls, **kwargs) -> Device | None:
"""
Create a new device in the database.
"""
try:
with Connection.session as db:
device = cls.from_orm(DeviceCreate(**kwargs))
db.add(device)
db.commit()
db.refresh(device)
return device
except IntegrityError:
# device already existed
return None
def update(self) -> None:
"""
Update this device in the database.
"""
with Connection.session as db:
db.add(self)
db.commit()
db.refresh(self)
def delete(self) -> bool:
"""
Delete this device from the database.
"""
with Connection.session as db:
db.delete(self)
db.commit()

View file

@ -0,0 +1,106 @@
"""
SQLAlchemy representation of database contents.
"""
from __future__ import annotations
import datetime
from sqlalchemy import (Boolean, Column, DateTime, ForeignKey, Integer, String,
UniqueConstraint)
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session, relationship
ORMBaseModel = declarative_base()
class User(ORMBaseModel):
__tablename__ = "users"
name = Column(String, primary_key=True, index=True)
password = Column(String, nullable=False)
capabilities: list[UserCapability] = relationship(
"UserCapability", lazy="joined", cascade="all, delete-orphan"
)
certificates: list[Certificate] = relationship(
"Certificate", lazy="select", back_populates="owner"
)
distinguished_names: list[DistinguishedName] = relationship(
"DistinguishedName", lazy="select", back_populates="owner"
)
@classmethod
def load(cls, db: Session, name: str) -> User | None:
"""
Load user from database by name.
"""
return (db
.query(User)
.filter(User.name == name)
.first())
class UserCapability(ORMBaseModel):
__tablename__ = "user_capabilities"
user_name = Column(
String,
ForeignKey("users.name"),
primary_key=True,
index=True,
)
capability = Column(String, primary_key=True)
class DistinguishedName(ORMBaseModel):
__tablename__ = "distinguished_names"
id = Column(Integer, primary_key=True, autoincrement=True)
owner_name = Column(String, ForeignKey("users.name"))
cn_only = Column(Boolean, default=True, nullable=False)
country = Column(String(2))
state = Column(String)
city = Column(String)
organization = Column(String)
organizational_unit = Column(String)
email = Column(String)
common_name = Column(String, nullable=False)
owner: User = relationship(
"User", lazy="joined", back_populates="distinguished_names"
)
UniqueConstraint(
country,
state,
city,
organization,
organizational_unit,
email,
common_name,
)
class Certificate(ORMBaseModel):
__tablename__ = "certificates"
id = Column(Integer, primary_key=True, autoincrement=True)
owner_name = Column(String, ForeignKey("users.name"))
dn_id = Column(
Integer,
ForeignKey("distinguished_names.id"),
nullable=False,
)
expiry = Column(DateTime, default=datetime.datetime.now)
distinguished_name: DistinguishedName = relationship(
"DistinguishedName", lazy="joined"
)
owner: User = relationship(
"User", lazy="joined", back_populates="certificates"
)

View file

@ -0,0 +1,275 @@
"""
Pydantic representation of database contents.
"""
from __future__ import annotations
from datetime import datetime
from enum import Enum
from typing import Any
from passlib.context import CryptContext
from pydantic import BaseModel, Field, constr, validator
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from . import models
##########
# table: distinguished_names
##########
class DistinguishedNameBase(BaseModel):
cn_only: bool
country: constr(max_length=2) | None
state: str | None
city: str | None
organization: str | None
organizational_unit: str | None
email: str | None
common_name: str
class DistinguishedNameCreate(DistinguishedNameBase):
pass
class DistinguishedName(DistinguishedNameBase):
class Config:
orm_mode = True
@classmethod
def create(
cls,
db: Session,
dn: DistinguishedNameCreate,
owner: User,
) -> User | None:
"""
Create a new distinguished name in the database.
"""
try:
db_owner = models.User.load(
db=db,
name=owner.name,
)
dn = models.DistinguishedName(
cn_only=dn.cn_only,
country=dn.country,
state=dn.state,
city=dn.city,
organization=dn.organization,
organizational_unit=dn.organizational_unit,
email=dn.email,
common_name=dn.common_name,
owner=db_owner,
)
db.add(dn)
db.commit()
db.refresh(dn)
return cls.from_orm(dn)
except IntegrityError:
# distinguished name already existed
pass
##########
# table: certificates
##########
class CertificateBase(BaseModel):
expiry: datetime
class CertificateCreate(CertificateBase):
pass
class Certificate(CertificateBase):
distinguished_name: DistinguishedName
class Config:
orm_mode = True
##########
# table: user_capabilities
##########
class UserCapability(Enum):
admin = "admin"
def __repr__(self) -> str:
return self.value
@classmethod
def from_value(cls, value) -> UserCapability:
"""
Create UserCapability from various formats
"""
if isinstance(value, cls):
# value is already a UserCapability, use that
return value
elif isinstance(value, models.UserCapability):
# create from db format
return cls(value.capability)
else:
# create from string representation
return cls(str(value))
##########
# table: users
##########
class UserBase(BaseModel):
name: str
class UserCreate(UserBase):
password: str
class User(UserBase):
capabilities: list[UserCapability] = []
distinguished_names: list[DistinguishedName] = Field(
default=[], repr=False
)
certificates: list[Certificate] = Field(
default=[], repr=False
)
class Config:
orm_mode = True
@validator("capabilities", pre=True)
@classmethod
def unify_capabilities(cls, value: list[Any]) -> list[UserCapability]:
"""
Import the capabilities from various formats
"""
return [
UserCapability.from_value(capability)
for capability in value
]
@classmethod
def from_db(
cls,
db: Session,
name: str,
) -> User | None:
"""
Load user from database by name.
"""
if (db_user := models.User.load(db, name)) is None:
return None
return cls.from_orm(db_user)
@classmethod
def create(
cls,
db: Session,
user: UserCreate,
crypt_context: CryptContext,
) -> User | None:
"""
Create a new user in the database.
"""
try:
user = models.User(
name=user.name,
password=crypt_context.hash(user.password),
capabilities=[],
)
db.add(user)
db.commit()
db.refresh(user)
return cls.from_orm(user)
except IntegrityError:
# user already existed
pass
def is_admin(self) -> bool:
return UserCapability.admin in self.capabilities
def authenticate(
self,
db: Session,
password: str,
crypt_context: CryptContext,
) -> User | None:
"""
Authenticate with name/password against users in database.
"""
if (db_user := models.User.load(db, self.name)) is None:
# nonexistent user, fake doing password verification
crypt_context.dummy_verify()
return False
if not crypt_context.verify(password, db_user.password):
# password hash mismatch
return False
self.from_orm(db_user)
return True
def update(
self,
db: Session,
) -> None:
"""
Update this user in the database.
"""
old_dbuser = models.User.load(db, self.name)
old_user = self.from_orm(old_dbuser)
for capability in self.capabilities:
if capability not in old_user.capabilities:
old_dbuser.capabilities.append(
models.UserCapability(capability=capability.value)
)
for capability in old_dbuser.capabilities:
if UserCapability.from_value(capability) not in self.capabilities:
db.delete(capability)
db.commit()
def delete(
self,
db: Session,
) -> bool:
"""
Delete this user from the database.
"""
if (db_user := models.User.load(db, self.name)) is None:
# nonexistent user
return False
db.delete(db_user)
db.commit()
return True

View file

@ -1,192 +0,0 @@
"""
Python representation of `users` table.
"""
from __future__ import annotations
from typing import Any
from pydantic import root_validator
from sqlalchemy.exc import IntegrityError
from sqlmodel import Field, Relationship, SQLModel
from ..config import Config
from .connection import Connection
from .device import Device
from .user_capability import Capability, UserCapability
class UserBase(SQLModel):
"""
Common to all representations of users
"""
name: str = Field(primary_key=True)
email: str | None = Field(default=None)
country: str | None = Field(default=None)
state: str | None = Field(default=None)
city: str | None = Field(default=None)
organization: str | None = Field(default=None)
organizational_unit: str | None = Field(default=None)
class UserCreate(UserBase):
"""
Representation of a newly created user
"""
password: str | None = Field(default=None)
password_clear: str | None = Field(default=None)
@root_validator
@classmethod
def hash_password(cls, values: dict[str, Any]) -> dict[str, Any]:
"""
Ensure the `password` value of this user gets set.
"""
if (values.get("password")) is not None:
# password is set
return values
if (password_clear := values.get("password_clear")) is None:
raise ValueError("No password to hash")
if (current_config := Config._) is None:
raise ValueError("Not configured")
values["password"] = current_config.crypto.crypt_context.hash(
password_clear)
return values
class UserRead(UserBase):
"""
Representation of a user read via the API
"""
pass
class User(UserBase, table=True):
"""
Representation of `users` table
"""
__tablename__ = "users"
password: str
capabilities: list[UserCapability] = Relationship(
back_populates="user",
sa_relationship_kwargs={
"lazy": "joined",
"cascade": "all, delete-orphan",
},
)
devices: list[Device] = Relationship(
back_populates="owner",
)
@classmethod
def create(cls, **kwargs) -> User | None:
"""
Create a new user in the database.
"""
try:
with Connection.session as db:
user = cls.from_orm(UserCreate(**kwargs))
db.add(user)
db.commit()
db.refresh(user)
return user
except IntegrityError:
# user already existed
return None
@classmethod
def get(cls, name: str) -> User | None:
"""
Load user from database by name.
"""
with Connection.session as db:
return db.get(cls, name)
@classmethod
def authenticate(
cls,
name: str,
password: str,
) -> User | None:
"""
Authenticate with name/password against users in database.
"""
crypt_context = Config._.crypto.crypt_context
if (user := cls.get(name)) is None:
# nonexistent user, fake doing password verification
crypt_context.dummy_verify()
return None
if not crypt_context.verify(password, user.password):
# password hash mismatch
return None
return user
def update(self) -> None:
"""
Update this user in the database.
"""
with Connection.session as db:
db.add(self)
db.commit()
db.refresh(self)
def delete(self) -> None:
"""
Delete this user from the database.
"""
with Connection.session as db:
db.delete(self)
db.commit()
def get_capabilities(self) -> set[Capability]:
"""
Return the capabilities of this user.
"""
return set(
capability._
for capability in self.capabilities
)
def can(self, capability: Capability) -> bool:
"""
Check if this user has a capability.
"""
return capability in self.get_capabilities()
def set_capabilities(self, capabilities: set[Capability]) -> None:
"""
Change the capabilities of this user.
"""
self.capabilities = [
UserCapability(
user_name=self.name,
capability_name=capability.value,
) for capability in capabilities
]

View file

@ -1,58 +0,0 @@
"""
Python representation of `user_capabilities` table.
"""
from enum import Enum
from typing import TYPE_CHECKING
from sqlmodel import Field, Relationship, SQLModel
if TYPE_CHECKING:
from .user import User
class Capability(Enum):
"""
Allowed values for capabilities
"""
admin = "admin"
login = "login"
issue = "issue"
renew = "renew"
def __repr__(self) -> str:
return self.value
class UserCapabilityBase(SQLModel):
"""
Common to all representations of capabilities
"""
capability_name: str = Field(primary_key=True)
@property
def _(self) -> Capability:
"""
Transform into a `Capability`.
"""
return Capability(self.capability_name)
def __repr__(self) -> str:
return self.capability_name
class UserCapability(UserCapabilityBase, table=True):
"""
Representation of `user_capabilities` table
"""
__tablename__ = "user_capabilities"
user_name: str = Field(primary_key=True, foreign_key="users.name")
user: "User" = Relationship(
back_populates="capabilities",
)

View file

@ -13,9 +13,13 @@ import uvicorn
from fastapi import FastAPI from fastapi import FastAPI
from .config import Config, Settings from .config import Config, Settings
from .db import Connection, User from .db import Connection
from .db.schemas import User
from .routers import main_router from .routers import main_router
settings = Settings.get()
app = FastAPI( app = FastAPI(
title="kiwi-vpn API", title="kiwi-vpn API",
description="This API enables the `kiwi-vpn` service.", description="This API enables the `kiwi-vpn` service.",
@ -27,24 +31,25 @@ app = FastAPI(
"name": "MIT License", "name": "MIT License",
"url": "https://opensource.org/licenses/mit-license.php", "url": "https://opensource.org/licenses/mit-license.php",
}, },
openapi_url=Settings._.openapi_url, openapi_url=settings.openapi_url,
docs_url=Settings._.docs_url if not Settings._.production_mode else None, docs_url=settings.docs_url if not settings.production_mode else None,
redoc_url=Settings._.redoc_url if not Settings._.production_mode else None, redoc_url=settings.redoc_url if not settings.production_mode else None,
) )
app.include_router(main_router, prefix=f"/{Settings._.api_v1_prefix}") app.include_router(main_router)
@app.on_event("startup") @app.on_event("startup")
async def on_startup() -> None: async def on_startup() -> None:
# check if configured # check if configured
if (current_config := Config._) is not None: if (current_config := await Config.load()) is not None:
# connect to database # connect to database
Connection.connect(current_config.db.uri) Connection.connect(await current_config.db.db_engine)
# some testing # some testing
print(User.get("admin")) with Connection.use() as db:
print(User.get("nonexistent")) print(User.from_db(db, "admin"))
print(User.from_db(db, "nonexistent"))
def main() -> None: def main() -> None:

View file

@ -1,14 +1,8 @@
"""
Package `routers`: Each module contains the path operations for their prefixes.
This file: Main API router definition.
"""
from fastapi import APIRouter from fastapi import APIRouter
from . import admin, user from . import admin, user
main_router = APIRouter() main_router = APIRouter(prefix="/api/v1")
main_router.include_router(admin.router) main_router.include_router(admin.router)
main_router.include_router(user.router) main_router.include_router(user.router)

View file

@ -5,13 +5,13 @@ Common dependencies for routers.
from fastapi import Depends, HTTPException, status from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer from fastapi.security import OAuth2PasswordBearer
from sqlalchemy.orm import Session
from ..config import Config, Settings from ..config import Config
from ..db import Capability, User from ..db import Connection
from ..db.schemas import User
oauth2_scheme = OAuth2PasswordBearer( oauth2_scheme = OAuth2PasswordBearer(tokenUrl="user/authenticate")
tokenUrl=f"{Settings._.api_v1_prefix}/user/authenticate"
)
class Responses: class Responses:
@ -56,6 +56,7 @@ class Responses:
async def get_current_user( async def get_current_user(
token: str = Depends(oauth2_scheme), token: str = Depends(oauth2_scheme),
db: Session | None = Depends(Connection.get),
current_config: Config | None = Depends(Config.load), current_config: Config | None = Depends(Config.load),
) -> User | None: ) -> User | None:
""" """
@ -67,11 +68,13 @@ async def get_current_user(
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
username = await current_config.jwt.decode_token(token) username = await current_config.jwt.decode_token(token)
user = User.from_db(db, username)
return User.get(username) return user
async def get_current_user_if_exists( async def get_current_user_if_exists(
current_config: Config | None = Depends(Config.load),
current_user: User | None = Depends(get_current_user), current_user: User | None = Depends(get_current_user),
) -> User: ) -> User:
""" """
@ -86,6 +89,7 @@ async def get_current_user_if_exists(
async def get_current_user_if_admin( async def get_current_user_if_admin(
current_config: Config | None = Depends(Config.load),
current_user: User = Depends(get_current_user_if_exists), current_user: User = Depends(get_current_user_if_exists),
) -> User: ) -> User:
""" """
@ -93,7 +97,7 @@ async def get_current_user_if_admin(
""" """
# fail if not requested by an admin # fail if not requested by an admin
if not current_user.can(Capability.admin): if not current_user.is_admin():
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
return current_user return current_user
@ -101,6 +105,7 @@ async def get_current_user_if_admin(
async def get_current_user_if_admin_or_self( async def get_current_user_if_admin_or_self(
user_name: str, user_name: str,
current_config: Config | None = Depends(Config.load),
current_user: User = Depends(get_current_user_if_exists), current_user: User = Depends(get_current_user_if_exists),
) -> User: ) -> User:
""" """
@ -111,8 +116,7 @@ async def get_current_user_if_admin_or_self(
""" """
# fail if not requested by an admin or self # fail if not requested by an admin or self
if not (current_user.can(Capability.admin) if not (current_user.is_admin() or current_user.name == user_name):
or current_user.name == user_name):
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
return current_user return current_user

View file

@ -4,67 +4,49 @@
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException, status
from sqlmodel import select
from ..config import Config from ..config import Config
from ..db import Capability, Connection, User, UserCreate from ..db import Connection
from ._common import Responses, get_current_user_if_admin from ..db.schemas import User, UserCapability, UserCreate
from ._common import Responses, get_current_user
router = APIRouter(prefix="/admin", tags=["admin"]) router = APIRouter(prefix="/admin", tags=["admin"])
@router.put( @router.put(
"/install/config", "/install",
responses={ responses={
status.HTTP_200_OK: Responses.OK, status.HTTP_200_OK: Responses.OK,
status.HTTP_400_BAD_REQUEST: Responses.INSTALLED, status.HTTP_400_BAD_REQUEST: Responses.INSTALLED,
}, },
) )
async def initial_configure( async def install(
config: Config, config: Config,
current_config: Config | None = Depends(Config.load),
):
"""
PUT ./install/config: Configure `kiwi-vpn`.
"""
# fail if already configured
if current_config is not None:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
# create config file, connect to database
config.save()
Connection.connect(current_config.db.uri)
@router.put(
"/install/admin",
responses={
status.HTTP_200_OK: Responses.OK,
status.HTTP_400_BAD_REQUEST: Responses.NOT_INSTALLED,
status.HTTP_409_CONFLICT: Responses.ENTRY_EXISTS,
},
)
async def create_initial_admin(
admin_user: UserCreate, admin_user: UserCreate,
current_config: Config | None = Depends(Config.load), current_config: Config | None = Depends(Config.load),
): ):
""" """
PUT ./install/admin: Create the first administrative user. PUT ./install: Install `kiwi-vpn`.
""" """
# fail if not configured # fail if already installed
if current_config is None: if current_config is not None:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
with Connection.session as db: # create config file, connect to database
if db.exec(select(User).limit(1)).first() is not None: await config.save()
raise HTTPException(status_code=status.HTTP_409_CONFLICT) Connection.connect(await config.db.db_engine)
# create an administrative user # create an administrative user
new_user = User.create(**admin_user.dict()) with Connection.use() as db:
new_user.set_capabilities([Capability.login, Capability.admin]) new_user = User.create(
new_user.update() db=db,
user=admin_user,
crypt_context=await config.crypto.crypt_context,
)
new_user.capabilities.append(UserCapability.admin)
new_user.update(db)
@router.put( @router.put(
@ -77,13 +59,23 @@ async def create_initial_admin(
}, },
) )
async def set_config( async def set_config(
config: Config, new_config: Config,
_: User = Depends(get_current_user_if_admin), current_config: Config | None = Depends(Config.load),
current_user: User | None = Depends(get_current_user),
): ):
""" """
PUT ./config: Edit `kiwi-vpn` main config. PUT ./config: Edit `kiwi-vpn` main config.
""" """
# fail if not installed
if current_config is None:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
# fail if not requested by an admin
if (current_user is None
or UserCapability.admin not in current_user.capabilities):
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
# update config file, reconnect to database # update config file, reconnect to database
config.save() await new_config.save()
Connection.connect(config.db.uri) Connection.connect(await new_config.db.db_engine)

View file

@ -0,0 +1,58 @@
"""
/dn endpoints.
"""
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from ..db import Connection
from ..db.schemas import DistinguishedName, DistinguishedNameCreate, User
from ._common import Responses, get_current_user_if_admin_or_self
router = APIRouter(prefix="/dn")
@router.post(
"",
responses={
status.HTTP_200_OK: Responses.OK,
status.HTTP_400_BAD_REQUEST: Responses.NOT_INSTALLED,
status.HTTP_401_UNAUTHORIZED: Responses.NEEDS_USER,
status.HTTP_403_FORBIDDEN: Responses.NEEDS_ADMIN,
status.HTTP_404_NOT_FOUND: Responses.ENTRY_DOESNT_EXIST,
status.HTTP_409_CONFLICT: Responses.ENTRY_EXISTS,
},
)
async def add_distinguished_name(
user_name: str,
distinguished_name: DistinguishedNameCreate,
_: User = Depends(get_current_user_if_admin_or_self),
db: Session | None = Depends(Connection.get),
):
"""
POST ./: Create a new distinguished name in the database.
"""
owner = User.from_db(
db=db,
name=user_name,
)
# fail if owner doesn't exist
if owner is None:
raise HTTPException(status_code=status.HTTP_409_CONFLICT)
# actually create the new user
new_dn = DistinguishedName.create(
db=db,
dn=distinguished_name,
owner=owner,
)
# fail if creation was unsuccessful
if new_dn is None:
raise HTTPException(status_code=status.HTTP_409_CONFLICT)
# return the created user on success
return new_dn

View file

@ -5,9 +5,11 @@
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordRequestForm from fastapi.security import OAuth2PasswordRequestForm
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy.orm import Session
from ..config import Config from ..config import Config
from ..db import Capability, User, UserCreate, UserRead from ..db import Connection
from ..db.schemas import User, UserCapability, UserCreate
from ._common import Responses, get_current_user, get_current_user_if_admin from ._common import Responses, get_current_user, get_current_user_if_admin
router = APIRouter(prefix="/user", tags=["user"]) router = APIRouter(prefix="/user", tags=["user"])
@ -26,6 +28,7 @@ class Token(BaseModel):
async def login( async def login(
form_data: OAuth2PasswordRequestForm = Depends(), form_data: OAuth2PasswordRequestForm = Depends(),
current_config: Config | None = Depends(Config.load), current_config: Config | None = Depends(Config.load),
db: Session | None = Depends(Connection.get),
): ):
""" """
POST ./authenticate: Authenticate a user. Issues a bearer token. POST ./authenticate: Authenticate a user. Issues a bearer token.
@ -36,10 +39,12 @@ async def login(
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
# try logging in # try logging in
if not (user := User.authenticate( user = User(name=form_data.username)
name=form_data.username, if not user.authenticate(
db=db,
password=form_data.password, password=form_data.password,
)): crypt_context=await current_config.crypto.crypt_context,
):
# authentication failed # authentication failed
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
@ -52,7 +57,7 @@ async def login(
return {"access_token": access_token, "token_type": "bearer"} return {"access_token": access_token, "token_type": "bearer"}
@router.get("/current", response_model=UserRead) @router.get("/current", response_model=User)
async def get_current_user( async def get_current_user(
current_user: User | None = Depends(get_current_user), current_user: User | None = Depends(get_current_user),
): ):
@ -76,14 +81,20 @@ async def get_current_user(
) )
async def add_user( async def add_user(
user: UserCreate, user: UserCreate,
current_config: Config | None = Depends(Config.load),
_: User = Depends(get_current_user_if_admin), _: User = Depends(get_current_user_if_admin),
db: Session | None = Depends(Connection.get),
): ):
""" """
POST ./: Create a new user in the database. POST ./: Create a new user in the database.
""" """
# actually create the new user # actually create the new user
new_user = User.create(**user.dict()) new_user = User.create(
db=db,
user=user,
crypt_context=await current_config.crypto.crypt_context,
)
# fail if creation was unsuccessful # fail if creation was unsuccessful
if new_user is None: if new_user is None:
@ -107,21 +118,22 @@ async def add_user(
async def remove_user( async def remove_user(
user_name: str, user_name: str,
_: User = Depends(get_current_user_if_admin), _: User = Depends(get_current_user_if_admin),
db: Session | None = Depends(Connection.get),
): ):
""" """
DELETE ./{user_name}: Remove a user from the database. DELETE ./{user_name}: Remove a user from the database.
""" """
# get the user # get the user
user = User.get(user_name) user = User.from_db(
db=db,
name=user_name,
)
# fail if user not found # fail if deletion was unsuccessful
if user is None: if user is None or not user.delete(db):
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
# delete user
user.delete()
@router.post( @router.post(
"/{user_name}/capabilities", "/{user_name}/capabilities",
@ -134,21 +146,22 @@ async def remove_user(
) )
async def extend_capabilities( async def extend_capabilities(
user_name: str, user_name: str,
capabilities: list[Capability], capabilities: list[UserCapability],
_: User = Depends(get_current_user_if_admin), _: User = Depends(get_current_user_if_admin),
db: Session | None = Depends(Connection.get),
): ):
""" """
POST ./{user_name}/capabilities: Add capabilities to a user. POST ./{user_name}/capabilities: Add capabilities to a user.
""" """
# get and change the user # get and change the user
user = User.get(user_name) user = User.from_db(
db=db,
user.set_capabilities( name=user_name,
user.get_capabilities() | set(capabilities)
) )
user.update() user.capabilities.extend(capabilities)
user.update(db)
@router.delete( @router.delete(
@ -162,18 +175,21 @@ async def extend_capabilities(
) )
async def remove_capabilities( async def remove_capabilities(
user_name: str, user_name: str,
capabilities: list[Capability], capabilities: list[UserCapability],
_: User = Depends(get_current_user_if_admin), _: User = Depends(get_current_user_if_admin),
db: Session | None = Depends(Connection.get),
): ):
""" """
DELETE ./{user_name}/capabilities: Remove capabilities from a user. DELETE ./{user_name}/capabilities: Remove capabilities from a user.
""" """
# get and change the user # get and change the user
user = User.get(user_name) user = User.from_db(
db=db,
user.set_capabilities( name=user_name,
user.get_capabilities() - set(capabilities)
) )
user.update() for capability in capabilities:
user.capabilities.remove(capability)
user.update(db)

View file

@ -1,25 +0,0 @@
## Server props
- default DN parts: country, state, city, org, OU
- "customizable" flags for DN parts
- flag: use client-to-client
- force cipher, tls-cipher, auth params
- server name
- default certification length
- default certificate algo
## User props
- username
- password
- custom DN parts: country, state, city, org, OU
- email
## User caps
- admin: administrator
- login: can log into the web interface
- issue: can certify own devices without approval
- renew: can renew certificates for own devices
## Device props
- name
- type (icon)
- expiry

34
api/poetry.lock generated
View file

@ -428,30 +428,6 @@ postgresql_psycopg2cffi = ["psycopg2cffi"]
pymysql = ["pymysql (<1)", "pymysql"] pymysql = ["pymysql (<1)", "pymysql"]
sqlcipher = ["sqlcipher3-binary"] sqlcipher = ["sqlcipher3-binary"]
[[package]]
name = "sqlalchemy2-stubs"
version = "0.0.2a21"
description = "Typing Stubs for SQLAlchemy 1.4"
category = "main"
optional = false
python-versions = ">=3.6"
[package.dependencies]
typing-extensions = ">=3.7.4"
[[package]]
name = "sqlmodel"
version = "0.0.6"
description = "SQLModel, SQL databases in Python, designed for simplicity, compatibility, and robustness."
category = "main"
optional = false
python-versions = ">=3.6.1,<4.0.0"
[package.dependencies]
pydantic = ">=1.8.2,<2.0.0"
SQLAlchemy = ">=1.4.17,<1.5.0"
sqlalchemy2-stubs = "*"
[[package]] [[package]]
name = "starlette" name = "starlette"
version = "0.17.1" version = "0.17.1"
@ -501,7 +477,7 @@ standard = ["websockets (>=10.0)", "httptools (>=0.4.0)", "watchgod (>=0.6)", "p
[metadata] [metadata]
lock-version = "1.1" lock-version = "1.1"
python-versions = "^3.10" python-versions = "^3.10"
content-hash = "a580f9fe4c68667e4cbdf385ac11d5c7a2925e3c990b7164faa922ec8b6f9555" content-hash = "432d2933102f8a0091cec1b5484944a0211ca74c5dc9b65877d99d7bd160e4bb"
[metadata.files] [metadata.files]
anyio = [ anyio = [
@ -858,14 +834,6 @@ sqlalchemy = [
{file = "SQLAlchemy-1.4.32-cp39-cp39-win_amd64.whl", hash = "sha256:b3f1d9b3aa09ab9adc7f8c4b40fc3e081eb903054c9a6f9ae1633fe15ae503b4"}, {file = "SQLAlchemy-1.4.32-cp39-cp39-win_amd64.whl", hash = "sha256:b3f1d9b3aa09ab9adc7f8c4b40fc3e081eb903054c9a6f9ae1633fe15ae503b4"},
{file = "SQLAlchemy-1.4.32.tar.gz", hash = "sha256:6fdd2dc5931daab778c2b65b03df6ae68376e028a3098eb624d0909d999885bc"}, {file = "SQLAlchemy-1.4.32.tar.gz", hash = "sha256:6fdd2dc5931daab778c2b65b03df6ae68376e028a3098eb624d0909d999885bc"},
] ]
sqlalchemy2-stubs = [
{file = "sqlalchemy2-stubs-0.0.2a21.tar.gz", hash = "sha256:207e3d8a36fc032d325f4eec89e0c6760efe81d07e978513d8c9b14f108dcd0c"},
{file = "sqlalchemy2_stubs-0.0.2a21-py3-none-any.whl", hash = "sha256:bd4a3d5ca7ff9d01b2245e1b26304d6b2ec4daf43a01faf40db9e09245679433"},
]
sqlmodel = [
{file = "sqlmodel-0.0.6-py3-none-any.whl", hash = "sha256:c5fd8719e09da348cd32ce2a5b6a44f289d3029fa8f1c9818229b6f34f1201b4"},
{file = "sqlmodel-0.0.6.tar.gz", hash = "sha256:3b4f966b9671b24d85529d274e6c4dbc7753b468e35d2d6a40bd75cad1f66813"},
]
starlette = [ starlette = [
{file = "starlette-0.17.1-py3-none-any.whl", hash = "sha256:26a18cbda5e6b651c964c12c88b36d9898481cd428ed6e063f5f29c418f73050"}, {file = "starlette-0.17.1-py3-none-any.whl", hash = "sha256:26a18cbda5e6b651c964c12c88b36d9898481cd428ed6e063f5f29c418f73050"},
{file = "starlette-0.17.1.tar.gz", hash = "sha256:57eab3cc975a28af62f6faec94d355a410634940f10b30d68d31cb5ec1b44ae8"}, {file = "starlette-0.17.1.tar.gz", hash = "sha256:57eab3cc975a28af62f6faec94d355a410634940f10b30d68d31cb5ec1b44ae8"},

View file

@ -13,7 +13,6 @@ python-jose = {extras = ["cryptography"], version = "^3.3.0"}
passlib = {extras = ["argon2", "bcrypt"], version = "^1.7.4"} passlib = {extras = ["argon2", "bcrypt"], version = "^1.7.4"}
SQLAlchemy = "^1.4.32" SQLAlchemy = "^1.4.32"
pyOpenSSL = "^22.0.0" pyOpenSSL = "^22.0.0"
sqlmodel = "^0.0.6"
[tool.poetry.dev-dependencies] [tool.poetry.dev-dependencies]
pytest = "^7.1.0" pytest = "^7.1.0"