1
0
Fork 0
mirror of https://github.com/ldericher/fftcgtool synced 2025-01-15 15:02:59 +00:00

imageloader.py simplification: multiprocessing

This commit is contained in:
Jörn-Michael Miehe 2021-08-23 13:36:19 +02:00
parent 63127a7d9a
commit 4e94ae995d
2 changed files with 18 additions and 60 deletions

View file

@ -1,7 +1,6 @@
import io
import logging
import queue
import threading
import multiprocessing
import requests
from PIL import Image
@ -9,67 +8,26 @@ from PIL import Image
from fftcg.utils import RESOLUTION
class ImageLoader(threading.Thread):
def __init__(self, url_queue: queue.Queue):
super().__init__()
self.__queue = url_queue
self.__images = {}
def run(self) -> None:
class ImageLoader:
@classmethod
def _load(cls, url: str) -> Image.Image:
logger = logging.getLogger(__name__)
while not self.__queue.empty():
# take next url
url = self.__queue.get()
# fetch image (retry on fail)
while True:
logger.info(f"downloading image {url}")
try:
res = requests.get(url)
image = Image.open(io.BytesIO(res.content))
# fetch image (retry on fail)
while True:
logger.info(f"downloading image {url}")
try:
res = requests.get(url)
image = Image.open(io.BytesIO(res.content))
# unify images
image.convert(mode="RGB")
return image.resize(RESOLUTION, Image.BICUBIC)
# unify images
image.convert("RGB")
image = image.resize(RESOLUTION, Image.BICUBIC)
break
except requests.exceptions.RequestException:
pass
# put image in correct position
self.__images[url] = image
# image is processed
self.__queue.task_done()
except requests.exceptions.RequestException:
pass
@classmethod
def load(cls, urls: list[str], num_threads: int) -> list[Image.Image]:
url_queue = queue.Queue()
for url in urls:
url_queue.put(url)
loaders = []
for _ in range(num_threads):
loader = cls(url_queue)
loaders.append(loader)
loader.start()
url_queue.join()
# stitch all "images" dicts together
images = {}
for loader in loaders:
images |= loader.images
# sort images to match the initial "urls" list
images = [
images[url]
for url in urls
]
return images
@property
def images(self) -> dict[str, Image.Image]:
return self.__images
with multiprocessing.Pool(num_threads) as p:
return p.map(ImageLoader._load, urls)

View file

@ -31,7 +31,7 @@ def main() -> None:
args = parser.parse_args()
# set up logging
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(threadName)s %(message)s")
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(processName)s %(message)s")
# output directory
if not os.path.exists("out"):