type checking "basic"

This commit is contained in:
Jörn-Michael Miehe 2022-09-08 00:24:36 +00:00
parent a02198dec0
commit baeff5c294
11 changed files with 44 additions and 38 deletions

View file

@ -12,5 +12,6 @@
"editor.codeActionsOnSave": { "editor.codeActionsOnSave": {
"source.organizeImports": true "source.organizeImports": true
}, },
"git.closeDiffOnOperation": true "git.closeDiffOnOperation": true,
"python.analysis.typeCheckingMode": "basic"
} }

View file

@ -5,25 +5,30 @@ Some useful helpers for working in async contexts.
from asyncio import get_running_loop from asyncio import get_running_loop
from functools import partial, wraps from functools import partial, wraps
from time import time from time import time
from typing import Awaitable, Callable, TypeVar
from async_lru import alru_cache from async_lru import alru_cache
from .settings import SETTINGS from .settings import SETTINGS
RT = TypeVar("RT")
def run_in_executor(f):
def run_in_executor(
function: Callable[..., RT]
) -> Callable[..., Awaitable[RT]]:
""" """
Decorator to make blocking a function call asyncio compatible. Decorator to make blocking a function call asyncio compatible.
https://stackoverflow.com/questions/41063331/how-to-use-asyncio-with-existing-blocking-library/ https://stackoverflow.com/questions/41063331/how-to-use-asyncio-with-existing-blocking-library/
https://stackoverflow.com/a/53719009 https://stackoverflow.com/a/53719009
""" """
@wraps(f) @wraps(function)
async def wrapper(*args, **kwargs): async def wrapper(*args, **kwargs) -> RT:
loop = get_running_loop() loop = get_running_loop()
return await loop.run_in_executor( return await loop.run_in_executor(
None, None,
partial(f, *args, **kwargs), partial(function, *args, **kwargs),
) )
return wrapper return wrapper

View file

@ -69,7 +69,7 @@ class Config(BaseModel):
try: try:
return cls.parse_obj( return cls.parse_obj(
toml_loads(await dav_file.string) toml_loads(await dav_file.as_string)
) )
except RemoteResourceNotFound: except RemoteResourceNotFound:

View file

@ -13,7 +13,7 @@ from typing import Iterator
from caldav import Calendar from caldav import Calendar
from caldav.lib.error import ReportError from caldav.lib.error import ReportError
from pydantic import BaseModel, validator from pydantic import BaseModel, validator
from vobject.icalendar import VEvent from vobject.base import Component
from .async_helpers import get_ttl_hash, run_in_executor, timed_alru_cache from .async_helpers import get_ttl_hash, run_in_executor, timed_alru_cache
from .config import Config from .config import Config
@ -76,7 +76,7 @@ class CalEvent(BaseModel):
)(_string_strip) )(_string_strip)
@classmethod @classmethod
def from_vevent(cls, event: VEvent) -> "CalEvent": def from_vevent(cls, event: Component) -> "CalEvent":
""" """
Create a CalEvent instance from a `VObject.VEvent` object. Create a CalEvent instance from a `VObject.VEvent` object.
""" """
@ -85,7 +85,7 @@ class CalEvent(BaseModel):
for key in cls().dict().keys(): for key in cls().dict().keys():
try: try:
data[key] = event.contents[key][0].value data[key] = event.contents[key][0].value # type: ignore
except KeyError: except KeyError:
pass pass
@ -123,7 +123,7 @@ async def _get_calendar_events(
search_span = timedelta(days=cfg.calendar.future_days) search_span = timedelta(days=cfg.calendar.future_days)
@run_in_executor @run_in_executor
def _inner() -> Iterator[VEvent]: def _inner() -> Iterator[Component]:
""" """
Get events by CalDAV calendar name. Get events by CalDAV calendar name.
@ -156,11 +156,9 @@ async def _get_calendar_events(
expand=False, expand=False,
) )
return ( for event in search_result:
vevent vobject: Component = event.vobject_instance # type: ignore
for event in search_result yield from vobject.vevent_list
for vevent in event.vobject_instance.contents["vevent"]
)
return sorted([ return sorted([
CalEvent.from_vevent(vevent) CalEvent.from_vevent(vevent)
@ -183,7 +181,7 @@ class DavCalendar:
""" """
return await _get_calendar( return await _get_calendar(
ttl_hash=get_ttl_hash(), ttl_hash=get_ttl_hash(), # type: ignore
calendar_name=self.calendar_name, calendar_name=self.calendar_name,
) )
@ -194,6 +192,6 @@ class DavCalendar:
""" """
return await _get_calendar_events( return await _get_calendar_events(
ttl_hash=get_ttl_hash(), ttl_hash=get_ttl_hash(), # type: ignore
calendar_name=self.calendar_name, calendar_name=self.calendar_name,
) )

View file

@ -112,6 +112,6 @@ def caldav_list() -> Iterator[str]:
""" """
return ( return (
cal.name str(cal.name)
for cal in caldav_principal().calendars() for cal in caldav_principal().calendars()
) )

View file

@ -61,12 +61,12 @@ class DavFile:
""" """
return await _get_buffer( return await _get_buffer(
ttl_hash=get_ttl_hash(), ttl_hash=get_ttl_hash(), # type: ignore
remote_path=self.remote_path, remote_path=self.remote_path,
) )
@property @property
async def bytes(self) -> bytes: async def as_bytes(self) -> bytes:
""" """
File contents as binary data. File contents as binary data.
""" """
@ -77,12 +77,12 @@ class DavFile:
return buffer.read() return buffer.read()
@property @property
async def string(self) -> str: async def as_string(self) -> str:
""" """
File contents as string. File contents as string.
""" """
bytes = await self.bytes bytes = await self.as_bytes
return bytes.decode(encoding="utf-8") return bytes.decode(encoding="utf-8")
async def write(self, content: bytes) -> None: async def write(self, content: bytes) -> None:

View file

@ -13,7 +13,6 @@ from ..config import Config
from ..dav_common import caldav_list, webdav_list from ..dav_common import caldav_list, webdav_list
@dataclass(frozen=True)
class NameLister(Protocol): class NameLister(Protocol):
""" """
Can be called to create an iterator containing some names. Can be called to create an iterator containing some names.

View file

@ -50,6 +50,6 @@ async def get_aggregate_calendar(
return sorted([ return sorted([
event event
async for calendar in calendars async for calendar in calendars # type: ignore
for event in (await calendar.events) for event in (await calendar.events)
]) ])

View file

@ -62,12 +62,12 @@ async def find_images(
async def get_image( async def get_image(
prefix: str, prefix: str,
name: str = Depends(image_unique), name: str = Depends(image_unique),
) -> str: ) -> StreamingResponse:
cfg = await Config.get() cfg = await Config.get()
dav_file = DavFile(f"{image_lister.remote_path}/{name}") dav_file = DavFile(f"{image_lister.remote_path}/{name}")
img = Image.open( img = Image.open(
BytesIO(await dav_file.bytes) BytesIO(await dav_file.as_bytes)
).convert( ).convert(
cfg.image.mode cfg.image.mode
) )

View file

@ -34,7 +34,7 @@ text_unique = PrefixUnique(text_finder)
async def get_ticker_lines() -> Iterator[str]: async def get_ticker_lines() -> Iterator[str]:
ticker = await DavFile("text/ticker.txt").string ticker = await DavFile("text/ticker.txt").as_string
return ( return (
line.strip() line.strip()
@ -59,8 +59,9 @@ async def get_ticker_content(
ticker_content_lines: Iterator[str] = Depends(get_ticker_content_lines), ticker_content_lines: Iterator[str] = Depends(get_ticker_content_lines),
) -> str: ) -> str:
cfg = await Config.get() cfg = await Config.get()
ticker_content_lines = ["", *ticker_content_lines, ""] ticker_content = cfg.ticker.separator.join(
ticker_content = cfg.ticker.separator.join(ticker_content_lines) ["", *ticker_content_lines, ""],
)
return ticker_content.strip() return ticker_content.strip()
@ -104,7 +105,7 @@ async def find_texts(
async def get_text_content( async def get_text_content(
name: str = Depends(text_unique), name: str = Depends(text_unique),
) -> str: ) -> str:
return await DavFile(f"{text_lister.remote_path}/{name}").string return await DavFile(f"{text_lister.remote_path}/{name}").as_string
@router.get( @router.get(

View file

@ -7,7 +7,7 @@ Converts per-run (environment) variables and config files into the
Pydantic models might have convenience methods attached. Pydantic models might have convenience methods attached.
""" """
from typing import Optional from typing import Any, Optional
from pydantic import BaseModel, BaseSettings, root_validator from pydantic import BaseModel, BaseSettings, root_validator
@ -17,11 +17,11 @@ class DavSettings(BaseModel):
Connection to a DAV server. Connection to a DAV server.
""" """
protocol: Optional[str] protocol: Optional[str] = None
host: Optional[str] host: Optional[str] = None
username: Optional[str] username: Optional[str] = None
password: Optional[str] password: Optional[str] = None
path: Optional[str] path: Optional[str] = None
@property @property
def url(self) -> str: def url(self) -> str:
@ -65,11 +65,13 @@ class Settings(BaseSettings):
caldav: DavSettings = DavSettings() caldav: DavSettings = DavSettings()
class Config: class Config:
env_file = ".env"
env_file_encoding = "utf-8"
env_nested_delimiter = "__" env_nested_delimiter = "__"
@root_validator(pre=True) @root_validator(pre=True)
@classmethod @classmethod
def validate_dav_settings(cls, values): def validate_dav_settings(cls, values: dict[str, Any]) -> dict[str, Any]:
# ensure both settings dicts are created # ensure both settings dicts are created
for key in ("webdav", "caldav"): for key in ("webdav", "caldav"):
if key not in values: if key not in values:
@ -96,4 +98,4 @@ class Settings(BaseSettings):
return values return values
SETTINGS = Settings(_env_file=".env") SETTINGS = Settings()