diff --git a/api/kiwi_vpn_api/db/__init__.py b/api/kiwi_vpn_api/db/__init__.py index 0e8041b..80b582e 100644 --- a/api/kiwi_vpn_api/db/__init__.py +++ b/api/kiwi_vpn_api/db/__init__.py @@ -1,4 +1,4 @@ -from . import models, schemas +from . import models, schemata from .connection import Connection -__all__ = ["Connection", "models", "schemas"] +__all__ = ["Connection", "models", "schemata"] diff --git a/api/kiwi_vpn_api/db/schemata/__init__.py b/api/kiwi_vpn_api/db/schemata/__init__.py new file mode 100644 index 0000000..5a991aa --- /dev/null +++ b/api/kiwi_vpn_api/db/schemata/__init__.py @@ -0,0 +1,6 @@ +from .device import Device, DeviceBase, DeviceCreate +from .user import User, UserBase, UserCreate +from .user_capability import UserCapability + +__all__ = ["Device", "DeviceBase", "DeviceCreate", + "User", "UserBase", "UserCreate", "UserCapability"] diff --git a/api/kiwi_vpn_api/db/schemata/device.py b/api/kiwi_vpn_api/db/schemata/device.py new file mode 100644 index 0000000..8e47f7e --- /dev/null +++ b/api/kiwi_vpn_api/db/schemata/device.py @@ -0,0 +1,79 @@ +""" +Pydantic representation of database contents. +""" + +from __future__ import annotations + +from datetime import datetime + +from pydantic import BaseModel +from sqlalchemy.exc import IntegrityError +from sqlalchemy.orm import Session + +from .. import models + + +class DeviceBase(BaseModel): + name: str + type: str + expiry: datetime + + +class DeviceCreate(DeviceBase): + owner_name: str + + +class Device(DeviceBase): + class Config: + orm_mode = True + + @classmethod + def create( + cls, + db: Session, + device: DeviceCreate, + ) -> Device | None: + """ + Create a new device in the database. + """ + + try: + db_device = models.Device( + owner_name=device.owner_name, + + name=device.name, + type=device.type, + expiry=device.expiry, + ) + + db.add(db_device) + db.commit() + db.refresh(db_device) + + return cls.from_orm(db_device) + + except IntegrityError: + # device already existed + return None + + def delete( + self, + db: Session, + ) -> bool: + """ + Delete this device from the database. + """ + + db_device = models.Device( + # owner_name= + name=self.name, + ) + db.refresh(db_device) + + if db_device is None: + # nonexistent device + return False + + db.delete(db_device) + db.commit() + return True diff --git a/api/kiwi_vpn_api/db/schemas.py b/api/kiwi_vpn_api/db/schemata/user.py similarity index 58% rename from api/kiwi_vpn_api/db/schemas.py rename to api/kiwi_vpn_api/db/schemata/user.py index 7fe003f..5789f05 100644 --- a/api/kiwi_vpn_api/db/schemas.py +++ b/api/kiwi_vpn_api/db/schemata/user.py @@ -2,62 +2,18 @@ Pydantic representation of database contents. """ - from __future__ import annotations -from datetime import datetime -from enum import Enum from typing import Any from passlib.context import CryptContext from pydantic import BaseModel, Field, validator -from sqlalchemy.exc import IntegrityError +from sqlalchemy.exc import IntegrityError, InvalidRequestError from sqlalchemy.orm import Session -from . import models - -########## -# table: user_capabilities -########## - - -class UserCapability(Enum): - admin = "admin" - login = "login" - issue = "issue" - renew = "renew" - - def __repr__(self) -> str: - return self.value - - @classmethod - def from_value(cls, value) -> UserCapability: - """ - Create UserCapability from various formats - """ - - if isinstance(value, cls): - # value is already a UserCapability, use that - return value - - elif isinstance(value, models.UserCapability): - # create from db format - return cls(value.capability) - - else: - # create from string representation - return cls(str(value)) - - @property - def model(self) -> models.UserCapability: - return models.UserCapability( - capability=self.value, - ) - - -########## -# table: users -########## +from .. import models +from .device import Device +from .user_capability import UserCapability class UserBase(BaseModel): @@ -71,14 +27,14 @@ class UserBase(BaseModel): email: str - capabilities: list[UserCapability] = [] - class UserCreate(UserBase): password: str class User(UserBase): + capabilities: list[UserCapability] = [] + devices: list[Device] = Field( default=[], repr=False ) @@ -108,10 +64,14 @@ class User(UserBase): Load user from database by name. """ - db_user = models.User(name=name) - db.refresh(db_user) + try: + db_user = models.User(name=name) + db.refresh(db_user) - return cls.from_orm(db_user) + return cls.from_orm(db_user) + + except InvalidRequestError: + return None @classmethod def create( @@ -142,7 +102,7 @@ class User(UserBase): except IntegrityError: # user already existed - pass + return None def is_admin(self) -> bool: return UserCapability.admin in self.capabilities @@ -212,74 +172,3 @@ class User(UserBase): db.delete(db_user) db.commit() return True - - -########## -# table: devices -########## - - -class DeviceBase(BaseModel): - name: str - type: str - expiry: datetime - - -class DeviceCreate(DeviceBase): - owner_name: str - - -class Device(DeviceBase): - class Config: - orm_mode = True - - @classmethod - def create( - cls, - db: Session, - device: DeviceCreate, - ) -> Device | None: - """ - Create a new device in the database. - """ - - try: - db_device = models.Device( - owner_name=device.owner_name, - - name=device.name, - type=device.type, - expiry=device.expiry, - ) - - db.add(db_device) - db.commit() - db.refresh(db_device) - - return cls.from_orm(db_device) - - except IntegrityError: - # device already existed - pass - - def delete( - self, - db: Session, - ) -> bool: - """ - Delete this device from the database. - """ - - db_device = models.Device( - # owner_name= - name=self.name, - ) - db.refresh(db_device) - - if db_device is None: - # nonexistent device - return False - - db.delete(db_device) - db.commit() - return True diff --git a/api/kiwi_vpn_api/db/schemata/user_capability.py b/api/kiwi_vpn_api/db/schemata/user_capability.py new file mode 100644 index 0000000..28e1870 --- /dev/null +++ b/api/kiwi_vpn_api/db/schemata/user_capability.py @@ -0,0 +1,43 @@ +""" +Pydantic representation of database contents. +""" + +from __future__ import annotations + +from enum import Enum + +from .. import models + + +class UserCapability(Enum): + admin = "admin" + login = "login" + issue = "issue" + renew = "renew" + + def __repr__(self) -> str: + return self.value + + @classmethod + def from_value(cls, value) -> UserCapability: + """ + Create UserCapability from various formats + """ + + if isinstance(value, cls): + # value is already a UserCapability, use that + return value + + elif isinstance(value, models.UserCapability): + # create from db format + return cls(value.capability) + + else: + # create from string representation + return cls(str(value)) + + @property + def model(self) -> models.UserCapability: + return models.UserCapability( + capability=self.value, + ) diff --git a/api/kiwi_vpn_api/main.py b/api/kiwi_vpn_api/main.py index 82c2fa8..9ed5b77 100755 --- a/api/kiwi_vpn_api/main.py +++ b/api/kiwi_vpn_api/main.py @@ -14,7 +14,7 @@ from fastapi import FastAPI from .config import Config, Settings from .db import Connection -from .db.schemas import User +from .db.schemata import User from .routers import main_router settings = Settings.get() diff --git a/api/kiwi_vpn_api/routers/_common.py b/api/kiwi_vpn_api/routers/_common.py index f9424bc..d7f3969 100644 --- a/api/kiwi_vpn_api/routers/_common.py +++ b/api/kiwi_vpn_api/routers/_common.py @@ -9,7 +9,7 @@ from sqlalchemy.orm import Session from ..config import Config from ..db import Connection -from ..db.schemas import User +from ..db.schemata import User oauth2_scheme = OAuth2PasswordBearer(tokenUrl="user/authenticate") diff --git a/api/kiwi_vpn_api/routers/admin.py b/api/kiwi_vpn_api/routers/admin.py index f3e6daf..9422eff 100644 --- a/api/kiwi_vpn_api/routers/admin.py +++ b/api/kiwi_vpn_api/routers/admin.py @@ -7,7 +7,7 @@ from fastapi import APIRouter, Depends, HTTPException, status from ..config import Config from ..db import Connection -from ..db.schemas import User, UserCapability, UserCreate +from ..db.schemata import User, UserCapability, UserCreate from ._common import Responses, get_current_user router = APIRouter(prefix="/admin", tags=["admin"]) diff --git a/api/kiwi_vpn_api/routers/dn.py b/api/kiwi_vpn_api/routers/dn.py index bdebb9c..80ed087 100644 --- a/api/kiwi_vpn_api/routers/dn.py +++ b/api/kiwi_vpn_api/routers/dn.py @@ -7,7 +7,7 @@ from fastapi import APIRouter, Depends, HTTPException, status from sqlalchemy.orm import Session from ..db import Connection -from ..db.schemas import DistinguishedName, DistinguishedNameCreate, User +from ..db.schemata import DistinguishedName, DistinguishedNameCreate, User from ._common import Responses, get_current_user_if_admin_or_self router = APIRouter(prefix="/dn") diff --git a/api/kiwi_vpn_api/routers/user.py b/api/kiwi_vpn_api/routers/user.py index 63e2b22..627c5fc 100644 --- a/api/kiwi_vpn_api/routers/user.py +++ b/api/kiwi_vpn_api/routers/user.py @@ -9,7 +9,7 @@ from sqlalchemy.orm import Session from ..config import Config from ..db import Connection -from ..db.schemas import User, UserCapability, UserCreate +from ..db.schemata import User, UserCapability, UserCreate from ._common import Responses, get_current_user, get_current_user_if_admin router = APIRouter(prefix="/user", tags=["user"])