resolved all warnings
This commit is contained in:
parent
69b0a619e0
commit
5d0d996288
9 changed files with 39 additions and 33 deletions
|
@ -73,7 +73,7 @@ class DBConfig(BaseModel):
|
|||
user: str | None = None
|
||||
password: str | None = None
|
||||
host: str | None = None
|
||||
database: str | None = Settings._.data_dir.joinpath("kiwi-vpn.db")
|
||||
database: str | None = str(Settings._.data_dir.joinpath("kiwi-vpn.db"))
|
||||
|
||||
mysql_driver: str = "pymysql"
|
||||
mysql_args: list[str] = ["charset=utf8mb4"]
|
||||
|
@ -99,6 +99,8 @@ class DBConfig(BaseModel):
|
|||
f"{self.user}:{self.password}@{self.host}"
|
||||
f"/{self.database}{args_str}")
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
class JWTConfig(BaseModel):
|
||||
"""
|
||||
|
@ -170,11 +172,7 @@ class JWTConfig(BaseModel):
|
|||
return None
|
||||
|
||||
# get username
|
||||
username = payload.get("sub")
|
||||
if username is None:
|
||||
return None
|
||||
|
||||
return username
|
||||
return payload.get("sub")
|
||||
|
||||
|
||||
class LockableString(BaseModel):
|
||||
|
@ -191,7 +189,7 @@ class LockableCountry(LockableString):
|
|||
Like `LockableString`, but with a `value` constrained two characters
|
||||
"""
|
||||
|
||||
value: constr(max_length=2)
|
||||
value: constr(max_length=2) # type: ignore
|
||||
|
||||
|
||||
class ServerDN(BaseModel):
|
||||
|
@ -276,12 +274,15 @@ class Config(BaseModel):
|
|||
|
||||
@classmethod
|
||||
@property
|
||||
def _(cls) -> Config | None:
|
||||
def _(cls) -> Config:
|
||||
"""
|
||||
Shorthand for load()
|
||||
Shorthand for load(), but config file must exist
|
||||
"""
|
||||
|
||||
return cls.load()
|
||||
if (config := cls.load()) is None:
|
||||
raise FileNotFoundError(Settings._.config_file)
|
||||
|
||||
return config
|
||||
|
||||
def save(self) -> None:
|
||||
"""
|
||||
|
|
|
@ -23,12 +23,12 @@ class Connection:
|
|||
|
||||
@classmethod
|
||||
@property
|
||||
def session(cls) -> Session | None:
|
||||
def session(cls) -> Session:
|
||||
"""
|
||||
Create an ORM session using a context manager.
|
||||
"""
|
||||
|
||||
if cls.engine is None:
|
||||
return None
|
||||
raise ValueError("Not connected to database, can't create session")
|
||||
|
||||
return Session(cls.engine)
|
||||
|
|
|
@ -106,7 +106,7 @@ class Device(DeviceBase, table=True):
|
|||
db.commit()
|
||||
db.refresh(self)
|
||||
|
||||
def delete(self) -> bool:
|
||||
def delete(self) -> None:
|
||||
"""
|
||||
Delete this device from the database.
|
||||
"""
|
||||
|
|
|
@ -32,7 +32,7 @@ class TagValue(Enum):
|
|||
"""
|
||||
|
||||
return Tag(
|
||||
user=user,
|
||||
user_name=user.name,
|
||||
tag_value=self.value,
|
||||
)
|
||||
|
||||
|
|
|
@ -53,7 +53,7 @@ class UserCreate(UserBase):
|
|||
if (password_clear := values.get("password_clear")) is None:
|
||||
raise ValueError("No password to hash")
|
||||
|
||||
if (current_config := Config._) is None:
|
||||
if (current_config := Config.load()) is None:
|
||||
raise ValueError("Not configured")
|
||||
|
||||
values["password"] = current_config.crypto.context.hash(
|
||||
|
@ -225,8 +225,8 @@ class User(UserBase, table=True):
|
|||
return True
|
||||
|
||||
# user can "edit" itself
|
||||
if isinstance(target, User) and target != self:
|
||||
return False
|
||||
if isinstance(target, User):
|
||||
return target == self
|
||||
|
||||
# user can edit its owned devices
|
||||
return target.owner == self
|
||||
|
|
|
@ -26,7 +26,7 @@ class DistinguishedName(BaseModel):
|
|||
city: str
|
||||
organization: str
|
||||
organizational_unit: str
|
||||
email: str
|
||||
email: str | None
|
||||
common_name: str
|
||||
|
||||
@classmethod
|
||||
|
@ -79,12 +79,12 @@ class DistinguishedName(BaseModel):
|
|||
return result
|
||||
|
||||
@property
|
||||
def easyrsa_args(self) -> tuple[str]:
|
||||
def easyrsa_args(self) -> list[str]:
|
||||
"""
|
||||
Pass this DN as arguments to easyrsa
|
||||
"""
|
||||
|
||||
return (
|
||||
return [
|
||||
"--dn-mode=org",
|
||||
|
||||
f"--req-c={self.country}",
|
||||
|
@ -94,7 +94,7 @@ class DistinguishedName(BaseModel):
|
|||
f"--req-ou={self.organizational_unit}",
|
||||
f"--req-email={self.email}",
|
||||
f"--req-cn={self.common_name}",
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
class EasyRSA:
|
||||
|
@ -159,13 +159,13 @@ class EasyRSA:
|
|||
|
||||
config = Config._
|
||||
|
||||
extra_args: tuple[str] = (
|
||||
extra_args: list[str] = [
|
||||
f"--passout=pass:{self.ca_password}",
|
||||
f"--passin=pass:{self.ca_password}",
|
||||
)
|
||||
]
|
||||
|
||||
if expiry_days is not None:
|
||||
extra_args += tuple([f"--days={expiry_days}"])
|
||||
extra_args += [f"--days={expiry_days}"]
|
||||
|
||||
if (algo := config.crypto.cert_algo) is not None:
|
||||
if algo is CertificateAlgo.rsa2048:
|
||||
|
@ -192,13 +192,13 @@ class EasyRSA:
|
|||
)
|
||||
|
||||
with open(
|
||||
self.output_directory.joinpath(cert_filename), "r"
|
||||
self.output_directory.joinpath(cert_filename), "rb"
|
||||
) as cert_file:
|
||||
return crypto.load_certificate(
|
||||
crypto.FILETYPE_PEM, cert_file.read()
|
||||
)
|
||||
|
||||
def init_pki(self) -> bool:
|
||||
def init_pki(self) -> None:
|
||||
"""
|
||||
Clean the working directory
|
||||
"""
|
||||
|
@ -255,7 +255,7 @@ if __name__ == "__main__":
|
|||
client = None
|
||||
|
||||
# check if configured
|
||||
if (current_config := Config._) is not None:
|
||||
if (current_config := Config.load()) is not None:
|
||||
# connect to database
|
||||
Connection.connect(current_config.db.uri)
|
||||
|
||||
|
@ -272,5 +272,6 @@ if __name__ == "__main__":
|
|||
if cert is not None:
|
||||
print(cert.get_subject().CN)
|
||||
print(cert.get_signature_algorithm().decode(encoding))
|
||||
print(datetime.strptime(
|
||||
cert.get_notAfter().decode(encoding), date_format))
|
||||
|
||||
assert (na := cert.get_notAfter()) is not None
|
||||
print(datetime.strptime(na.decode(encoding), date_format))
|
||||
|
|
|
@ -40,7 +40,7 @@ app.include_router(main_router)
|
|||
@app.on_event("startup")
|
||||
async def on_startup() -> None:
|
||||
# check if configured
|
||||
if (current_config := Config._) is not None:
|
||||
if (current_config := Config.load()) is not None:
|
||||
# connect to database
|
||||
Connection.connect(current_config.db.uri)
|
||||
|
||||
|
|
|
@ -71,11 +71,13 @@ async def get_current_user(
|
|||
Get the currently logged-in user if it exists.
|
||||
"""
|
||||
|
||||
username = await current_config.jwt.decode_token(token)
|
||||
# don't use error 404 here - possible user enumeration
|
||||
|
||||
# fail if not requested by a user
|
||||
if (username := await current_config.jwt.decode_token(token)) is None:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
|
||||
|
||||
if (user := User.get(username)) is None:
|
||||
# don't use error 404 here: possible user enumeration
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
|
||||
|
||||
return user
|
||||
|
|
|
@ -58,7 +58,9 @@ async def create_initial_admin(
|
|||
raise HTTPException(status_code=status.HTTP_409_CONFLICT)
|
||||
|
||||
# create an administrative user
|
||||
new_user = User.create(user=admin_user)
|
||||
if (new_user := User.create(user=admin_user)) is None:
|
||||
raise HTTPException(status_code=status.HTTP_409_CONFLICT)
|
||||
|
||||
new_user.add_tags([TagValue.admin])
|
||||
new_user.update()
|
||||
|
||||
|
|
Loading…
Reference in a new issue