Skip to content

Commit

Permalink
feat(video2x): dynamically import optional dependencies
Browse files Browse the repository at this point in the history
Signed-off-by: k4yt3x <[email protected]>
  • Loading branch information
k4yt3x committed Sep 24, 2023
1 parent 37bdfdd commit b382f39
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 18 deletions.
13 changes: 10 additions & 3 deletions video2x/interpolator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,16 @@
"""

import time
from importlib import import_module

from loguru import logger
from PIL import ImageChops, ImageStat
from rife_ncnn_vulkan_python.rife_ncnn_vulkan import Rife

from .processor import Processor


class Interpolator:
ALGORITHM_CLASSES = {"rife": Rife}
ALGORITHM_CLASSES = {"rife": "rife_ncnn_vulkan_python.rife_ncnn_vulkan.Rife"}

processor_objects = {}

Expand All @@ -43,9 +43,16 @@ def interpolate_image(self, image0, image1, difference_threshold, algorithm):

if difference_ratio < difference_threshold:
processor_object = self.processor_objects.get(algorithm)

if processor_object is None:
processor_object = self.ALGORITHM_CLASSES[algorithm](0)
module_name, class_name = self.ALGORITHM_CLASSES[algorithm].rsplit(
".", 1
)
processor_module = import_module(module_name)
processor_class = getattr(processor_module, class_name)
processor_object = processor_class(0)
self.processor_objects[algorithm] = processor_object

interpolated_image = processor_object.process(image0, image1)

else:
Expand Down
23 changes: 11 additions & 12 deletions video2x/upscaler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,9 @@

import math
import time
from importlib import import_module

from anime4k_python import Anime4K
from PIL import Image
from realcugan_ncnn_vulkan_python import Realcugan
from realsr_ncnn_vulkan_python import Realsr
from srmd_ncnn_vulkan_python import Srmd
from waifu2x_ncnn_vulkan_python import Waifu2x

from .processor import Processor

Expand All @@ -45,11 +41,11 @@ class Upscaler:
}

ALGORITHM_CLASSES = {
"anime4k": Anime4K,
"realcugan": Realcugan,
"realsr": Realsr,
"srmd": Srmd,
"waifu2x": Waifu2x,
"anime4k": "anime4k_python.Anime4K",
"realcugan": "realcugan_ncnn_vulkan_python.Realcugan",
"realsr": "realsr_ncnn_vulkan_python.Realsr",
"srmd": "srmd_ncnn_vulkan_python.Srmd",
"waifu2x": "waifu2x_ncnn_vulkan_python.Waifu2x",
}

processor_objects = {}
Expand Down Expand Up @@ -148,9 +144,12 @@ def upscale_image(
# create a new object if none are available
processor_object = self.processor_objects.get((algorithm, task))
if processor_object is None:
processor_object = self.ALGORITHM_CLASSES[algorithm](
noise=noise, scale=task
module_name, class_name = self.ALGORITHM_CLASSES[algorithm].rsplit(
".", 1
)
processor_module = import_module(module_name)
processor_class = getattr(processor_module, class_name)
processor_object = processor_class(noise=noise, scale=task)
self.processor_objects[(algorithm, task)] = processor_object

# process the image with the selected algorithm
Expand Down
13 changes: 10 additions & 3 deletions video2x/video2x.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,10 @@
import sys
import time
from enum import Enum
from importlib import import_module
from multiprocessing import Manager, Pool, Queue, Value
from pathlib import Path
from typing import Any, Callable, Optional
from typing import Callable, Optional

import ffmpeg
from cv2 import cv2
Expand Down Expand Up @@ -156,9 +157,12 @@ def _run(
# process by directly invoking the
# if the selected algorithm does not support frameserving
if mode == ProcessingMode.UPSCALE:
standalone_processor: Any = Upscaler.ALGORITHM_CLASSES[
standalone_processor_path: str = Upscaler.ALGORITHM_CLASSES[
processing_settings[2]
]
module_name, class_name = standalone_processor_path.rsplit(".", 1)
processor_module = import_module(module_name)
standalone_processor = getattr(processor_module, class_name)
if getattr(standalone_processor, "process", None) is None:
logger.warning("No progress bar available for this processor")
standalone_processor().process_video(
Expand All @@ -172,9 +176,12 @@ def _run(
return
# elif mode == ProcessingMode.INTERPOLATE:
else:
standalone_processor: Any = Interpolator.ALGORITHM_CLASSES[
standalone_processor_path: str = Interpolator.ALGORITHM_CLASSES[
processing_settings[1]
]
module_name, class_name = standalone_processor_path.rsplit(".", 1)
processor_module = import_module(module_name)
standalone_processor = getattr(processor_module, class_name)
if getattr(standalone_processor, "process", None) is None:
logger.warning("No progress bar available for this processor")
standalone_processor().process_video(
Expand Down

0 comments on commit b382f39

Please sign in to comment.