mirror of
https://github.com/ldericher/fftcgtool
synced 2025-01-15 15:02:59 +00:00
imageloader.py simplification: multiprocessing
This commit is contained in:
parent
63127a7d9a
commit
4e94ae995d
2 changed files with 18 additions and 60 deletions
|
@ -1,7 +1,6 @@
|
|||
import io
|
||||
import logging
|
||||
import queue
|
||||
import threading
|
||||
import multiprocessing
|
||||
|
||||
import requests
|
||||
from PIL import Image
|
||||
|
@ -9,67 +8,26 @@ from PIL import Image
|
|||
from fftcg.utils import RESOLUTION
|
||||
|
||||
|
||||
class ImageLoader(threading.Thread):
|
||||
def __init__(self, url_queue: queue.Queue):
|
||||
super().__init__()
|
||||
|
||||
self.__queue = url_queue
|
||||
self.__images = {}
|
||||
|
||||
def run(self) -> None:
|
||||
class ImageLoader:
|
||||
@classmethod
|
||||
def _load(cls, url: str) -> Image.Image:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
while not self.__queue.empty():
|
||||
# take next url
|
||||
url = self.__queue.get()
|
||||
# fetch image (retry on fail)
|
||||
while True:
|
||||
logger.info(f"downloading image {url}")
|
||||
try:
|
||||
res = requests.get(url)
|
||||
image = Image.open(io.BytesIO(res.content))
|
||||
|
||||
# fetch image (retry on fail)
|
||||
while True:
|
||||
logger.info(f"downloading image {url}")
|
||||
try:
|
||||
res = requests.get(url)
|
||||
image = Image.open(io.BytesIO(res.content))
|
||||
# unify images
|
||||
image.convert(mode="RGB")
|
||||
return image.resize(RESOLUTION, Image.BICUBIC)
|
||||
|
||||
# unify images
|
||||
image.convert("RGB")
|
||||
image = image.resize(RESOLUTION, Image.BICUBIC)
|
||||
break
|
||||
except requests.exceptions.RequestException:
|
||||
pass
|
||||
|
||||
# put image in correct position
|
||||
self.__images[url] = image
|
||||
|
||||
# image is processed
|
||||
self.__queue.task_done()
|
||||
except requests.exceptions.RequestException:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def load(cls, urls: list[str], num_threads: int) -> list[Image.Image]:
|
||||
url_queue = queue.Queue()
|
||||
for url in urls:
|
||||
url_queue.put(url)
|
||||
|
||||
loaders = []
|
||||
for _ in range(num_threads):
|
||||
loader = cls(url_queue)
|
||||
loaders.append(loader)
|
||||
loader.start()
|
||||
|
||||
url_queue.join()
|
||||
|
||||
# stitch all "images" dicts together
|
||||
images = {}
|
||||
for loader in loaders:
|
||||
images |= loader.images
|
||||
|
||||
# sort images to match the initial "urls" list
|
||||
images = [
|
||||
images[url]
|
||||
for url in urls
|
||||
]
|
||||
|
||||
return images
|
||||
|
||||
@property
|
||||
def images(self) -> dict[str, Image.Image]:
|
||||
return self.__images
|
||||
with multiprocessing.Pool(num_threads) as p:
|
||||
return p.map(ImageLoader._load, urls)
|
||||
|
|
2
main.py
2
main.py
|
@ -31,7 +31,7 @@ def main() -> None:
|
|||
args = parser.parse_args()
|
||||
|
||||
# set up logging
|
||||
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(threadName)s %(message)s")
|
||||
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(processName)s %(message)s")
|
||||
|
||||
# output directory
|
||||
if not os.path.exists("out"):
|
||||
|
|
Loading…
Reference in a new issue