Added image-to-image task for ORT Pipeline #2031

Merged
2 changes: 2 additions & 0 deletions optimum/onnxruntime/__init__.py
@@ -44,6 +44,7 @@
"ORTModelForSemanticSegmentation",
"ORTModelForSequenceClassification",
"ORTModelForTokenClassification",
"ORTModelForImageToImage",
],
"modeling_seq2seq": [
"ORTModelForSeq2SeqLM",
@@ -112,6 +113,7 @@
        ORTModelForCustomTasks,
        ORTModelForFeatureExtraction,
        ORTModelForImageClassification,
        ORTModelForImageToImage,
        ORTModelForMaskedLM,
        ORTModelForMultipleChoice,
        ORTModelForQuestionAnswering,
73 changes: 73 additions & 0 deletions optimum/onnxruntime/modeling_ort.py
@@ -34,6 +34,7 @@
    AutoModelForAudioXVector,
    AutoModelForCTC,
    AutoModelForImageClassification,
    AutoModelForImageToImage,
    AutoModelForMaskedLM,
    AutoModelForMultipleChoice,
    AutoModelForQuestionAnswering,
@@ -47,6 +48,7 @@
    BaseModelOutput,
    CausalLMOutput,
    ImageClassifierOutput,
    ImageSuperResolutionOutput,
    MaskedLMOutput,
    ModelOutput,
    MultipleChoiceModelOutput,
@@ -2183,6 +2185,77 @@ def forward(
        return TokenClassifierOutput(logits=logits)


IMAGE_TO_IMAGE_EXAMPLE = r"""
Example of image-to-image (Super Resolution):

```python
>>> from transformers import {processor_class}
>>> from optimum.onnxruntime import {model_class}
>>> from PIL import Image

>>> image = Image.open("path/to/image.jpg")

>>> image_processor = {processor_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")

>>> inputs = image_processor(images=image, return_tensors="pt")

>>> with torch.no_grad():
... logits = model(**inputs).logits
```
"""


@add_end_docstrings(ONNX_MODEL_END_DOCSTRING)
class ORTModelForImageToImage(ORTModel):
    """
    ONNX Model for image-to-image tasks. This class officially supports swin2sr.
    """

    auto_model_class = AutoModelForImageToImage

    @add_start_docstrings_to_model_forward(
        ONNX_IMAGE_INPUTS_DOCSTRING.format("batch_size, num_channels, height, width")
        + IMAGE_TO_IMAGE_EXAMPLE.format(
            processor_class=_PROCESSOR_FOR_DOC,
            model_class="ORTModelForImageToImage",
            checkpoint="caidas/swin2SR-realworld-sr-x4-64-bsrgan-psnr",
        )
    )
    def forward(
        self,
        pixel_values: Union[torch.Tensor, np.ndarray],
        **kwargs,
    ):
        use_torch = isinstance(pixel_values, torch.Tensor)
        self.raise_on_numpy_input_io_binding(use_torch)
        if self.device.type == "cuda" and self.use_io_binding:
            # With IO binding, the output buffer must be allocated up front, so the
            # output shape is derived from the input shape: the spatial dimensions
            # grow by the model's `config.upscale` factor.
            input_shapes = pixel_values.shape
            io_binding, output_shapes, output_buffers = self.prepare_io_binding(
                pixel_values,
                ordered_input_names=self._ordered_input_names,
                known_output_shapes={
                    "reconstruction": [
                        input_shapes[0],
                        input_shapes[1],
                        input_shapes[2] * self.config.upscale,
                        input_shapes[3] * self.config.upscale,
                    ]
                },
            )
            io_binding.synchronize_inputs()
            self.model.run_with_iobinding(io_binding)
            io_binding.synchronize_outputs()
            reconstruction = output_buffers["reconstruction"].view(output_shapes["reconstruction"])
        else:
            model_inputs = {"pixel_values": pixel_values}
            onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs)
            onnx_outputs = self.model.run(None, onnx_inputs)
            model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs)
            reconstruction = model_outputs["reconstruction"]
        return ImageSuperResolutionOutput(reconstruction=reconstruction)


CUSTOM_TASKS_EXAMPLE = r"""
Example of custom tasks(e.g. a sentence transformers taking `pooler_output` as output):

8 changes: 8 additions & 0 deletions optimum/pipelines/pipelines_base.py
@@ -24,6 +24,7 @@
    FillMaskPipeline,
    ImageClassificationPipeline,
    ImageSegmentationPipeline,
    ImageToImagePipeline,
    ImageToTextPipeline,
    Pipeline,
    PreTrainedTokenizer,
@@ -55,6 +56,7 @@
        ORTModelForCausalLM,
        ORTModelForFeatureExtraction,
        ORTModelForImageClassification,
        ORTModelForImageToImage,
        ORTModelForMaskedLM,
        ORTModelForQuestionAnswering,
        ORTModelForSemanticSegmentation,
@@ -157,6 +159,12 @@
"default": "superb/hubert-base-superb-ks",
"type": "audio",
},
"image-to-image": {
"impl": ImageToImagePipeline,
"class": (ORTModelForImageToImage,),
"default": "caidas/swin2SR-classical-sr-x2-64",
"type": "image",
},
}
else:
ORT_SUPPORTED_TASKS = {}
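
With the entry above registered, the `optimum.pipelines.pipeline` factory can resolve the task by name. A usage sketch (assuming onnxruntime is installed; the default checkpoint comes from the registration above, and the input path is hypothetical):

```python
from optimum.pipelines import pipeline
from PIL import Image

# Resolves "image-to-image" via ORT_SUPPORTED_TASKS; with no model given,
# the default checkpoint caidas/swin2SR-classical-sr-x2-64 is loaded.
upscaler = pipeline("image-to-image", accelerator="ort")

image = Image.open("path/to/image.jpg")  # hypothetical input path
upscaled = upscaler(image)  # ImageToImagePipeline returns a PIL.Image.Image
upscaled.save("upscaled.jpg")
```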
136 changes: 135 additions & 1 deletion tests/onnxruntime/test_modeling.py
@@ -42,6 +42,7 @@
    AutoModelForCausalLM,
    AutoModelForCTC,
    AutoModelForImageClassification,
    AutoModelForImageToImage,
    AutoModelForMaskedLM,
    AutoModelForMultipleChoice,
    AutoModelForQuestionAnswering,
@@ -57,7 +58,9 @@
    PretrainedConfig,
    set_seed,
)
from transformers.modeling_outputs import ImageSuperResolutionOutput
from transformers.modeling_utils import no_init_weights
from transformers.models.swin2sr.configuration_swin2sr import Swin2SRConfig
from transformers.onnx.utils import get_preprocessor
from transformers.testing_utils import get_gpu_count, require_torch_gpu, slow
from utils_onnxruntime_tests import MODEL_NAMES, SEED, ORTModelTestMixin
@@ -79,6 +82,7 @@
    ORTModelForCustomTasks,
    ORTModelForFeatureExtraction,
    ORTModelForImageClassification,
    ORTModelForImageToImage,
    ORTModelForMaskedLM,
    ORTModelForMultipleChoice,
    ORTModelForPix2Struct,
@@ -4704,6 +4708,136 @@ def test_compare_generation_to_io_binding(
        gc.collect()


class ORTModelForImageToImageIntegrationTest(ORTModelTestMixin):
    SUPPORTED_ARCHITECTURES = ["swin2sr"]

    ORTMODEL_CLASS = ORTModelForImageToImage

    TASK = "image-to-image"

    def _get_sample_image(self):
        url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        image = Image.open(requests.get(url, stream=True).raw)
        return image

    def _get_preprocessors(self, model_id):
        image_processor = AutoImageProcessor.from_pretrained(model_id)

        return image_processor

    def test_load_vanilla_transformers_which_is_not_supported(self):
        with self.assertRaises(Exception) as context:
            _ = ORTModelForImageToImage.from_pretrained(MODEL_NAMES["bert"], export=True)

        self.assertIn("only supports the tasks", str(context.exception))

    @parameterized.expand(SUPPORTED_ARCHITECTURES)
    def test_compare_to_transformers(self, model_arch: str):
        model_args = {"test_name": model_arch, "model_arch": model_arch}
        self._setup(model_args)
        model_id = MODEL_NAMES[model_arch]
        onnx_model = ORTModelForImageToImage.from_pretrained(self.onnx_model_dirs[model_arch])
        self.assertIsInstance(onnx_model.config, Swin2SRConfig)
        set_seed(SEED)

        transformers_model = AutoModelForImageToImage.from_pretrained(model_id)
        image_processor = self._get_preprocessors(model_id)

        data = self._get_sample_image()
        features = image_processor(data, return_tensors="pt")

        with torch.no_grad():
            transformers_outputs = transformers_model(**features)

        onnx_outputs = onnx_model(**features)
        self.assertIsInstance(onnx_outputs, ImageSuperResolutionOutput)
        self.assertTrue("reconstruction" in onnx_outputs)
        self.assertIsInstance(onnx_outputs.reconstruction, torch.Tensor)
        self.assertTrue(torch.allclose(onnx_outputs.reconstruction, transformers_outputs.reconstruction, atol=1e-4))

        gc.collect()

    @parameterized.expand(SUPPORTED_ARCHITECTURES)
    def test_generate_utils(self, model_arch: str):
        model_args = {"test_name": model_arch, "model_arch": model_arch}
        self._setup(model_args)
        model_id = MODEL_NAMES[model_arch]
        onnx_model = ORTModelForImageToImage.from_pretrained(self.onnx_model_dirs[model_arch])
        image_processor = self._get_preprocessors(model_id)

        data = self._get_sample_image()
        features = image_processor(data, return_tensors="pt")

        outputs = onnx_model(**features)
        self.assertIsInstance(outputs, ImageSuperResolutionOutput)

        gc.collect()

    @parameterized.expand(SUPPORTED_ARCHITECTURES)
    def test_pipeline_image_to_image(self, model_arch: str):
        model_args = {"test_name": model_arch, "model_arch": model_arch}
        self._setup(model_args)
        model_id = MODEL_NAMES[model_arch]
        onnx_model = ORTModelForImageToImage.from_pretrained(self.onnx_model_dirs[model_arch])
        image_processor = self._get_preprocessors(model_id)
        pipe = pipeline(
            "image-to-image",
            model=onnx_model,
            feature_extractor=image_processor,
        )
        data = self._get_sample_image()
        outputs = pipe(data)
        self.assertEqual(pipe.device, onnx_model.device)
        self.assertIsInstance(outputs, Image.Image)

        gc.collect()

    @parameterized.expand(SUPPORTED_ARCHITECTURES)
    @require_torch_gpu
    @pytest.mark.cuda_ep_test
    def test_pipeline_on_gpu(self, model_arch: str):
        model_args = {"test_name": model_arch, "model_arch": model_arch}
        self._setup(model_args)
        model_id = MODEL_NAMES[model_arch]
        onnx_model = ORTModelForImageToImage.from_pretrained(self.onnx_model_dirs[model_arch])
        image_processor = self._get_preprocessors(model_id)
        pipe = pipeline(
            "image-to-image",
            model=onnx_model,
            feature_extractor=image_processor,
            device=0,
        )

        data = self._get_sample_image()
        outputs = pipe(data)

        self.assertEqual(pipe.model.device.type.lower(), "cuda")
        self.assertIsInstance(outputs, Image.Image)

    @parameterized.expand(SUPPORTED_ARCHITECTURES)
    @require_torch_gpu
    @require_ort_rocm
    @pytest.mark.rocm_ep_test
    def test_pipeline_on_rocm(self, model_arch: str):
        model_args = {"test_name": model_arch, "model_arch": model_arch}
        self._setup(model_args)
        model_id = MODEL_NAMES[model_arch]
        onnx_model = ORTModelForImageToImage.from_pretrained(self.onnx_model_dirs[model_arch])
        image_processor = self._get_preprocessors(model_id)
        pipe = pipeline(
            "image-to-image",
            model=onnx_model,
            feature_extractor=image_processor,
            device=0,
        )

        data = self._get_sample_image()
        outputs = pipe(data)

        self.assertEqual(pipe.model.device.type.lower(), "cuda")
        self.assertIsInstance(outputs, Image.Image)


class ORTModelForVision2SeqIntegrationTest(ORTModelTestMixin):
SUPPORTED_ARCHITECTURES = ["vision-encoder-decoder", "trocr", "donut"]

@@ -4831,7 +4965,6 @@ def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cach
len(onnx_outputs["past_key_values"][0]), len(transformers_outputs["past_key_values"][0])
)
for i in range(len(onnx_outputs["past_key_values"])):
print(onnx_outputs["past_key_values"][i])
for ort_pkv, trfs_pkv in zip(
onnx_outputs["past_key_values"][i], transformers_outputs["past_key_values"][i]
):
@@ -5517,6 +5650,7 @@ class TestBothExportersORTModel(unittest.TestCase):
["automatic-speech-recognition", ORTModelForCTCIntegrationTest],
["audio-xvector", ORTModelForAudioXVectorIntegrationTest],
["audio-frame-classification", ORTModelForAudioFrameClassificationIntegrationTest],
["image-to-image", ORTModelForImageToImageIntegrationTest],
]
)
    def test_find_untested_architectures(self, task: str, test_class):
1 change: 1 addition & 0 deletions tests/onnxruntime/utils_onnxruntime_tests.py
@@ -144,6 +144,7 @@
"stable-diffusion-xl": "echarlaix/tiny-random-stable-diffusion-xl",
"swin": "hf-internal-testing/tiny-random-SwinModel",
"swin-window": "yujiepan/tiny-random-swin-patch4-window7-224",
"swin2sr": "hf-internal-testing/tiny-random-Swin2SRForImageSuperResolution",
"t5": "hf-internal-testing/tiny-random-t5",
"table-transformer": "hf-internal-testing/tiny-random-TableTransformerModel",
"trocr": "microsoft/trocr-small-handwritten",