Compare commits

..

No commits in common. "008f0b2cf6235ad8dfd9d57b8d92e36efc2dff97" and "69b0a619e041fb626588cb824f6c99e90685b5ab" have entirely different histories.

9 changed files with 33 additions and 39 deletions

View file

@ -73,7 +73,7 @@ class DBConfig(BaseModel):
user: str | None = None user: str | None = None
password: str | None = None password: str | None = None
host: str | None = None host: str | None = None
database: str | None = str(Settings._.data_dir.joinpath("kiwi-vpn.db")) database: str | None = Settings._.data_dir.joinpath("kiwi-vpn.db")
mysql_driver: str = "pymysql" mysql_driver: str = "pymysql"
mysql_args: list[str] = ["charset=utf8mb4"] mysql_args: list[str] = ["charset=utf8mb4"]
@ -99,8 +99,6 @@ class DBConfig(BaseModel):
f"{self.user}:{self.password}@{self.host}" f"{self.user}:{self.password}@{self.host}"
f"/{self.database}{args_str}") f"/{self.database}{args_str}")
return ""
class JWTConfig(BaseModel): class JWTConfig(BaseModel):
""" """
@ -172,7 +170,11 @@ class JWTConfig(BaseModel):
return None return None
# get username # get username
return payload.get("sub") username = payload.get("sub")
if username is None:
return None
return username
class LockableString(BaseModel): class LockableString(BaseModel):
@ -189,7 +191,7 @@ class LockableCountry(LockableString):
Like `LockableString`, but with a `value` constrained two characters Like `LockableString`, but with a `value` constrained two characters
""" """
value: constr(max_length=2) # type: ignore value: constr(max_length=2)
class ServerDN(BaseModel): class ServerDN(BaseModel):
@ -274,15 +276,12 @@ class Config(BaseModel):
@classmethod @classmethod
@property @property
def _(cls) -> Config: def _(cls) -> Config | None:
""" """
Shorthand for load(), but config file must exist Shorthand for load()
""" """
if (config := cls.load()) is None: return cls.load()
raise FileNotFoundError(Settings._.config_file)
return config
def save(self) -> None: def save(self) -> None:
""" """

View file

@ -23,12 +23,12 @@ class Connection:
@classmethod @classmethod
@property @property
def session(cls) -> Session: def session(cls) -> Session | None:
""" """
Create an ORM session using a context manager. Create an ORM session using a context manager.
""" """
if cls.engine is None: if cls.engine is None:
raise ValueError("Not connected to database, can't create session") return None
return Session(cls.engine) return Session(cls.engine)

View file

@ -106,7 +106,7 @@ class Device(DeviceBase, table=True):
db.commit() db.commit()
db.refresh(self) db.refresh(self)
def delete(self) -> None: def delete(self) -> bool:
""" """
Delete this device from the database. Delete this device from the database.
""" """

View file

@ -32,7 +32,7 @@ class TagValue(Enum):
""" """
return Tag( return Tag(
user_name=user.name, user=user,
tag_value=self.value, tag_value=self.value,
) )

View file

@ -22,7 +22,7 @@ class UserBase(SQLModel):
""" """
name: str = Field(primary_key=True) name: str = Field(primary_key=True)
email: str email: str | None = Field(default=None)
country: str | None = Field(default=None, max_length=2) country: str | None = Field(default=None, max_length=2)
state: str | None = Field(default=None) state: str | None = Field(default=None)
@ -53,7 +53,7 @@ class UserCreate(UserBase):
if (password_clear := values.get("password_clear")) is None: if (password_clear := values.get("password_clear")) is None:
raise ValueError("No password to hash") raise ValueError("No password to hash")
if (current_config := Config.load()) is None: if (current_config := Config._) is None:
raise ValueError("Not configured") raise ValueError("Not configured")
values["password"] = current_config.crypto.context.hash( values["password"] = current_config.crypto.context.hash(
@ -225,8 +225,8 @@ class User(UserBase, table=True):
return True return True
# user can "edit" itself # user can "edit" itself
if isinstance(target, User): if isinstance(target, User) and target != self:
return target == self return False
# user can edit its owned devices # user can edit its owned devices
return target.owner == self return target.owner == self

View file

@ -79,12 +79,12 @@ class DistinguishedName(BaseModel):
return result return result
@property @property
def easyrsa_args(self) -> list[str]: def easyrsa_args(self) -> tuple[str]:
""" """
Pass this DN as arguments to easyrsa Pass this DN as arguments to easyrsa
""" """
return [ return (
"--dn-mode=org", "--dn-mode=org",
f"--req-c={self.country}", f"--req-c={self.country}",
@ -94,7 +94,7 @@ class DistinguishedName(BaseModel):
f"--req-ou={self.organizational_unit}", f"--req-ou={self.organizational_unit}",
f"--req-email={self.email}", f"--req-email={self.email}",
f"--req-cn={self.common_name}", f"--req-cn={self.common_name}",
] )
class EasyRSA: class EasyRSA:
@ -159,13 +159,13 @@ class EasyRSA:
config = Config._ config = Config._
extra_args: list[str] = [ extra_args: tuple[str] = (
f"--passout=pass:{self.ca_password}", f"--passout=pass:{self.ca_password}",
f"--passin=pass:{self.ca_password}", f"--passin=pass:{self.ca_password}",
] )
if expiry_days is not None: if expiry_days is not None:
extra_args += [f"--days={expiry_days}"] extra_args += tuple([f"--days={expiry_days}"])
if (algo := config.crypto.cert_algo) is not None: if (algo := config.crypto.cert_algo) is not None:
if algo is CertificateAlgo.rsa2048: if algo is CertificateAlgo.rsa2048:
@ -192,13 +192,13 @@ class EasyRSA:
) )
with open( with open(
self.output_directory.joinpath(cert_filename), "rb" self.output_directory.joinpath(cert_filename), "r"
) as cert_file: ) as cert_file:
return crypto.load_certificate( return crypto.load_certificate(
crypto.FILETYPE_PEM, cert_file.read() crypto.FILETYPE_PEM, cert_file.read()
) )
def init_pki(self) -> None: def init_pki(self) -> bool:
""" """
Clean the working directory Clean the working directory
""" """
@ -255,7 +255,7 @@ if __name__ == "__main__":
client = None client = None
# check if configured # check if configured
if (current_config := Config.load()) is not None: if (current_config := Config._) is not None:
# connect to database # connect to database
Connection.connect(current_config.db.uri) Connection.connect(current_config.db.uri)
@ -272,6 +272,5 @@ if __name__ == "__main__":
if cert is not None: if cert is not None:
print(cert.get_subject().CN) print(cert.get_subject().CN)
print(cert.get_signature_algorithm().decode(encoding)) print(cert.get_signature_algorithm().decode(encoding))
print(datetime.strptime(
assert (na := cert.get_notAfter()) is not None cert.get_notAfter().decode(encoding), date_format))
print(datetime.strptime(na.decode(encoding), date_format))

View file

@ -40,7 +40,7 @@ app.include_router(main_router)
@app.on_event("startup") @app.on_event("startup")
async def on_startup() -> None: async def on_startup() -> None:
# check if configured # check if configured
if (current_config := Config.load()) is not None: if (current_config := Config._) is not None:
# connect to database # connect to database
Connection.connect(current_config.db.uri) Connection.connect(current_config.db.uri)

View file

@ -71,13 +71,11 @@ async def get_current_user(
Get the currently logged-in user if it exists. Get the currently logged-in user if it exists.
""" """
# don't use error 404 here - possible user enumeration username = await current_config.jwt.decode_token(token)
# fail if not requested by a user # fail if not requested by a user
if (username := await current_config.jwt.decode_token(token)) is None:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
if (user := User.get(username)) is None: if (user := User.get(username)) is None:
# don't use error 404 here: possible user enumeration
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
return user return user

View file

@ -58,9 +58,7 @@ async def create_initial_admin(
raise HTTPException(status_code=status.HTTP_409_CONFLICT) raise HTTPException(status_code=status.HTTP_409_CONFLICT)
# create an administrative user # create an administrative user
if (new_user := User.create(user=admin_user)) is None: new_user = User.create(user=admin_user)
raise HTTPException(status_code=status.HTTP_409_CONFLICT)
new_user.add_tags([TagValue.admin]) new_user.add_tags([TagValue.admin])
new_user.update() new_user.update()