diff --git a/fftcg/imageloader.py b/fftcg/imageloader.py index 28459b8..26cdb0b 100644 --- a/fftcg/imageloader.py +++ b/fftcg/imageloader.py @@ -1,7 +1,6 @@ import io import logging -import queue -import threading +import multiprocessing import requests from PIL import Image @@ -9,67 +8,26 @@ from PIL import Image from fftcg.utils import RESOLUTION -class ImageLoader(threading.Thread): - def __init__(self, url_queue: queue.Queue): - super().__init__() - - self.__queue = url_queue - self.__images = {} - - def run(self) -> None: +class ImageLoader: + @classmethod + def _load(cls, url: str) -> Image.Image: logger = logging.getLogger(__name__) - while not self.__queue.empty(): - # take next url - url = self.__queue.get() + # fetch image (retry on fail) + while True: + logger.info(f"downloading image {url}") + try: + res = requests.get(url) + image = Image.open(io.BytesIO(res.content)) - # fetch image (retry on fail) - while True: - logger.info(f"downloading image {url}") - try: - res = requests.get(url) - image = Image.open(io.BytesIO(res.content)) + # unify images + image.convert(mode="RGB") + return image.resize(RESOLUTION, Image.BICUBIC) - # unify images - image.convert("RGB") - image = image.resize(RESOLUTION, Image.BICUBIC) - break - except requests.exceptions.RequestException: - pass - - # put image in correct position - self.__images[url] = image - - # image is processed - self.__queue.task_done() + except requests.exceptions.RequestException: + pass @classmethod def load(cls, urls: list[str], num_threads: int) -> list[Image.Image]: - url_queue = queue.Queue() - for url in urls: - url_queue.put(url) - - loaders = [] - for _ in range(num_threads): - loader = cls(url_queue) - loaders.append(loader) - loader.start() - - url_queue.join() - - # stitch all "images" dicts together - images = {} - for loader in loaders: - images |= loader.images - - # sort images to match the initial "urls" list - images = [ - images[url] - for url in urls - ] - - return images - - @property - def images(self) -> dict[str, Image.Image]: - return self.__images + with multiprocessing.Pool(num_threads) as p: + return p.map(ImageLoader._load, urls) diff --git a/main.py b/main.py index ee97525..27ea821 100755 --- a/main.py +++ b/main.py @@ -31,7 +31,7 @@ def main() -> None: args = parser.parse_args() # set up logging - logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(threadName)s %(message)s") + logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(processName)s %(message)s") # output directory if not os.path.exists("out"):