Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Jun 21, 2024
1 parent 16b3126 commit 7cedba5
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions examples/post_training_quantization/torch/ssd300_vgg16/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,13 +274,19 @@ def get_image():


def main_detr():
from transformers import DetrImageProcessor, DetrForObjectDetection
from transformers import DetrImageProcessor, DetrForObjectDetection # noqa
from transformers import AutoImageProcessor, AutoModelForObjectDetection, ConditionalDetrForObjectDetection # noqa
from transformers import OwlViTProcessor, OwlViTForObjectDetection # noqa
import torch

device = torch.device("cpu")
# you can specify the revision tag if you don't want the timm dependency
processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
# processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
# model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
processor = AutoImageProcessor.from_pretrained("microsoft/conditional-detr-resnet-50")
model = ConditionalDetrForObjectDetection.from_pretrained("microsoft/conditional-detr-resnet-50")
# processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
# model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32")
model.eval()

dataset_path = download_dataset()
Expand Down

0 comments on commit 7cedba5

Please sign in to comment.