Skip to content

Commit

Permalink
An example for Timm ViT
Browse files Browse the repository at this point in the history
  • Loading branch information
VainF committed Jul 26, 2023
1 parent a07659e commit 71028c1
Showing 1 changed file with 0 additions and 20 deletions.
20 changes: 0 additions & 20 deletions examples/timm_models/timm_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,26 +68,6 @@ def get_channel_groups(self, layer):

timm_atten_pruner = TimmAttentionPruner()


from transformers import ViTImageProcessor, ViTForImageClassification
from PIL import Image
import requests

url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)

processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
print(model)
inputs = processor(images=image, return_tensors="pt")
outputs = model(**inputs)
logits = outputs.logits
# model predicts one of the 1000 ImageNet classes
predicted_class_idx = logits.argmax(-1).item()
print("Predicted class:", model.config.id2label[predicted_class_idx])



for i, model_name in enumerate(timm_models):
if not model_name=='vit_base_patch8_224':
continue
Expand Down

0 comments on commit 71028c1

Please sign in to comment.