diff --git a/optimum/onnxruntime/utils.py b/optimum/onnxruntime/utils.py
index 37d0feefcc..6079799c20 100644
--- a/optimum/onnxruntime/utils.py
+++ b/optimum/onnxruntime/utils.py
@@ -129,6 +129,7 @@ class ORTConfigManager:
         "pegasus": "bert",
         "roberta": "bert",
         "segformer": "vit",
+        "table-transformer": "vit",
         "t5": "bert",
         "vit": "vit",
         "whisper": "bart",
diff --git a/optimum/utils/normalized_config.py b/optimum/utils/normalized_config.py
index 81207b7649..22b37d2d59 100644
--- a/optimum/utils/normalized_config.py
+++ b/optimum/utils/normalized_config.py
@@ -217,7 +217,6 @@ class NormalizedConfigManager:
         'perceiver',
         'roformer',
         'squeezebert',
-        'table-transformer',
     """

     # Contribution note: Please add new models in alphabetical order
@@ -273,6 +272,7 @@ class NormalizedConfigManager:
         "segformer": NormalizedSegformerConfig,
         "speech-to-text": SpeechToTextLikeNormalizedTextConfig,
         "splinter": NormalizedTextConfig,
+        "table-transformer": NormalizedVisionConfig,
         "t5": T5LikeNormalizedTextConfig,
         "trocr": TrOCRLikeNormalizedTextConfig,
         "vision-encoder-decoder": NormalizedEncoderDecoderConfig,
diff --git a/tests/onnxruntime/test_optimization.py b/tests/onnxruntime/test_optimization.py
index 82109fcd11..e326910cb4 100644
--- a/tests/onnxruntime/test_optimization.py
+++ b/tests/onnxruntime/test_optimization.py
@@ -35,6 +35,7 @@
 from optimum.onnxruntime import (
     AutoOptimizationConfig,
     ORTConfig,
+    ORTModelForCustomTasks,
     ORTModelForImageClassification,
     ORTModelForSemanticSegmentation,
     ORTModelForSequenceClassification,
@@ -172,6 +173,7 @@ def test_compare_original_seq2seq_model_with_optimized_model(self, model_cls, mo

     # Contribution note: Please add test models in alphabetical order. Find test models here: https://huggingface.co/hf-internal-testing.
     SUPPORTED_IMAGE_ARCHITECTURES_WITH_MODEL_ID = (
+        (ORTModelForCustomTasks, "hf-internal-testing/tiny-random-TableTransformerModel"),
         (ORTModelForSemanticSegmentation, "hf-internal-testing/tiny-random-segformer"),
         (ORTModelForImageClassification, "hf-internal-testing/tiny-random-vit"),
     )
@@ -191,11 +193,18 @@ def test_compare_original_image_model_with_optimized_model(self, model_cls, mode
         # Verify the ORTConfig was correctly created and saved
         self.assertEqual(ort_config.to_dict(), expected_ort_config.to_dict())

-        image = torch.ones((1, model.config.num_channels, model.config.image_size, model.config.image_size))
-        model_outputs = model(image)
-        optimized_model_outputs = optimized_model(image)
+        image_size = getattr(model.config, "image_size", 224)
+        image = torch.ones((1, model.config.num_channels, image_size, image_size))
+        model_outputs = model(pixel_values=image)
+        optimized_model_outputs = optimized_model(pixel_values=image)
+
         # Compare tensors outputs
-        self.assertTrue(torch.equal(model_outputs.logits, optimized_model_outputs.logits))
+        if hasattr(model_outputs, "logits"):
+            self.assertTrue(torch.equal(model_outputs.logits, optimized_model_outputs.logits), "Logits do not match")
+        elif hasattr(model_outputs, "last_hidden_state"):
+            self.assertTrue(torch.equal(model_outputs.last_hidden_state, optimized_model_outputs.last_hidden_state), "last_hidden_state does not match")
+        else:
+            raise ValueError("Model outputs do not have logits or last_hidden_state")
         gc.collect()

     def test_optimization_details(self):
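
For reference, a minimal usage sketch of the support added above. It exercises the same classes as the new test (`ORTModelForCustomTasks`, `ORTOptimizer`, `AutoOptimizationConfig`), assumes a recent `optimum` where `export=True` triggers the ONNX export, and uses the tiny `hf-internal-testing` checkpoint as a stand-in for a real Table Transformer model:

```python
# Sketch only: export a Table Transformer checkpoint to ONNX and run the O1
# optimization preset, which ORTConfigManager can now resolve for
# "table-transformer" (mapped to the "vit" optimization recipe).
from optimum.onnxruntime import AutoOptimizationConfig, ORTModelForCustomTasks, ORTOptimizer

model_id = "hf-internal-testing/tiny-random-TableTransformerModel"

# Table Transformer has no task-specific ORT class, so ORTModelForCustomTasks
# wraps the raw model outputs (e.g. last_hidden_state) after export.
model = ORTModelForCustomTasks.from_pretrained(model_id, export=True)

# Optimize the exported graph with the O1 preset and save the result.
optimizer = ORTOptimizer.from_pretrained(model)
optimizer.optimize(save_dir="table_transformer_onnx_o1", optimization_config=AutoOptimizationConfig.O1())
```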