resolved all warnings

This commit is contained in:
Jörn-Michael Miehe 2022-03-31 16:32:07 +00:00
parent 69b0a619e0
commit 5d0d996288
9 changed files with 39 additions and 33 deletions

View file

@ -73,7 +73,7 @@ class DBConfig(BaseModel):
user: str | None = None user: str | None = None
password: str | None = None password: str | None = None
host: str | None = None host: str | None = None
database: str | None = Settings._.data_dir.joinpath("kiwi-vpn.db") database: str | None = str(Settings._.data_dir.joinpath("kiwi-vpn.db"))
mysql_driver: str = "pymysql" mysql_driver: str = "pymysql"
mysql_args: list[str] = ["charset=utf8mb4"] mysql_args: list[str] = ["charset=utf8mb4"]
@ -99,6 +99,8 @@ class DBConfig(BaseModel):
f"{self.user}:{self.password}@{self.host}" f"{self.user}:{self.password}@{self.host}"
f"/{self.database}{args_str}") f"/{self.database}{args_str}")
return ""
class JWTConfig(BaseModel): class JWTConfig(BaseModel):
""" """
@ -170,11 +172,7 @@ class JWTConfig(BaseModel):
return None return None
# get username # get username
username = payload.get("sub") return payload.get("sub")
if username is None:
return None
return username
class LockableString(BaseModel): class LockableString(BaseModel):
@ -191,7 +189,7 @@ class LockableCountry(LockableString):
Like `LockableString`, but with a `value` constrained two characters Like `LockableString`, but with a `value` constrained two characters
""" """
value: constr(max_length=2) value: constr(max_length=2) # type: ignore
class ServerDN(BaseModel): class ServerDN(BaseModel):
@ -276,12 +274,15 @@ class Config(BaseModel):
@classmethod @classmethod
@property @property
def _(cls) -> Config | None: def _(cls) -> Config:
""" """
Shorthand for load() Shorthand for load(), but config file must exist
""" """
return cls.load() if (config := cls.load()) is None:
raise FileNotFoundError(Settings._.config_file)
return config
def save(self) -> None: def save(self) -> None:
""" """

View file

@ -23,12 +23,12 @@ class Connection:
@classmethod @classmethod
@property @property
def session(cls) -> Session | None: def session(cls) -> Session:
""" """
Create an ORM session using a context manager. Create an ORM session using a context manager.
""" """
if cls.engine is None: if cls.engine is None:
return None raise ValueError("Not connected to database, can't create session")
return Session(cls.engine) return Session(cls.engine)

View file

@ -106,7 +106,7 @@ class Device(DeviceBase, table=True):
db.commit() db.commit()
db.refresh(self) db.refresh(self)
def delete(self) -> bool: def delete(self) -> None:
""" """
Delete this device from the database. Delete this device from the database.
""" """

View file

@ -32,7 +32,7 @@ class TagValue(Enum):
""" """
return Tag( return Tag(
user=user, user_name=user.name,
tag_value=self.value, tag_value=self.value,
) )

View file

@ -53,7 +53,7 @@ class UserCreate(UserBase):
if (password_clear := values.get("password_clear")) is None: if (password_clear := values.get("password_clear")) is None:
raise ValueError("No password to hash") raise ValueError("No password to hash")
if (current_config := Config._) is None: if (current_config := Config.load()) is None:
raise ValueError("Not configured") raise ValueError("Not configured")
values["password"] = current_config.crypto.context.hash( values["password"] = current_config.crypto.context.hash(
@ -225,8 +225,8 @@ class User(UserBase, table=True):
return True return True
# user can "edit" itself # user can "edit" itself
if isinstance(target, User) and target != self: if isinstance(target, User):
return False return target == self
# user can edit its owned devices # user can edit its owned devices
return target.owner == self return target.owner == self

View file

@ -26,7 +26,7 @@ class DistinguishedName(BaseModel):
city: str city: str
organization: str organization: str
organizational_unit: str organizational_unit: str
email: str email: str | None
common_name: str common_name: str
@classmethod @classmethod
@ -79,12 +79,12 @@ class DistinguishedName(BaseModel):
return result return result
@property @property
def easyrsa_args(self) -> tuple[str]: def easyrsa_args(self) -> list[str]:
""" """
Pass this DN as arguments to easyrsa Pass this DN as arguments to easyrsa
""" """
return ( return [
"--dn-mode=org", "--dn-mode=org",
f"--req-c={self.country}", f"--req-c={self.country}",
@ -94,7 +94,7 @@ class DistinguishedName(BaseModel):
f"--req-ou={self.organizational_unit}", f"--req-ou={self.organizational_unit}",
f"--req-email={self.email}", f"--req-email={self.email}",
f"--req-cn={self.common_name}", f"--req-cn={self.common_name}",
) ]
class EasyRSA: class EasyRSA:
@ -159,13 +159,13 @@ class EasyRSA:
config = Config._ config = Config._
extra_args: tuple[str] = ( extra_args: list[str] = [
f"--passout=pass:{self.ca_password}", f"--passout=pass:{self.ca_password}",
f"--passin=pass:{self.ca_password}", f"--passin=pass:{self.ca_password}",
) ]
if expiry_days is not None: if expiry_days is not None:
extra_args += tuple([f"--days={expiry_days}"]) extra_args += [f"--days={expiry_days}"]
if (algo := config.crypto.cert_algo) is not None: if (algo := config.crypto.cert_algo) is not None:
if algo is CertificateAlgo.rsa2048: if algo is CertificateAlgo.rsa2048:
@ -192,13 +192,13 @@ class EasyRSA:
) )
with open( with open(
self.output_directory.joinpath(cert_filename), "r" self.output_directory.joinpath(cert_filename), "rb"
) as cert_file: ) as cert_file:
return crypto.load_certificate( return crypto.load_certificate(
crypto.FILETYPE_PEM, cert_file.read() crypto.FILETYPE_PEM, cert_file.read()
) )
def init_pki(self) -> bool: def init_pki(self) -> None:
""" """
Clean the working directory Clean the working directory
""" """
@ -255,7 +255,7 @@ if __name__ == "__main__":
client = None client = None
# check if configured # check if configured
if (current_config := Config._) is not None: if (current_config := Config.load()) is not None:
# connect to database # connect to database
Connection.connect(current_config.db.uri) Connection.connect(current_config.db.uri)
@ -272,5 +272,6 @@ if __name__ == "__main__":
if cert is not None: if cert is not None:
print(cert.get_subject().CN) print(cert.get_subject().CN)
print(cert.get_signature_algorithm().decode(encoding)) print(cert.get_signature_algorithm().decode(encoding))
print(datetime.strptime(
cert.get_notAfter().decode(encoding), date_format)) assert (na := cert.get_notAfter()) is not None
print(datetime.strptime(na.decode(encoding), date_format))

View file

@ -40,7 +40,7 @@ app.include_router(main_router)
@app.on_event("startup") @app.on_event("startup")
async def on_startup() -> None: async def on_startup() -> None:
# check if configured # check if configured
if (current_config := Config._) is not None: if (current_config := Config.load()) is not None:
# connect to database # connect to database
Connection.connect(current_config.db.uri) Connection.connect(current_config.db.uri)

View file

@ -71,11 +71,13 @@ async def get_current_user(
Get the currently logged-in user if it exists. Get the currently logged-in user if it exists.
""" """
username = await current_config.jwt.decode_token(token) # don't use error 404 here - possible user enumeration
# fail if not requested by a user # fail if not requested by a user
if (username := await current_config.jwt.decode_token(token)) is None:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
if (user := User.get(username)) is None: if (user := User.get(username)) is None:
# don't use error 404 here: possible user enumeration
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
return user return user

View file

@ -58,7 +58,9 @@ async def create_initial_admin(
raise HTTPException(status_code=status.HTTP_409_CONFLICT) raise HTTPException(status_code=status.HTTP_409_CONFLICT)
# create an administrative user # create an administrative user
new_user = User.create(user=admin_user) if (new_user := User.create(user=admin_user)) is None:
raise HTTPException(status_code=status.HTTP_409_CONFLICT)
new_user.add_tags([TagValue.admin]) new_user.add_tags([TagValue.admin])
new_user.update() new_user.update()