Commit 2dd43c5
updates per PR review
vidyasiv committed Jun 20, 2024
1 parent ca5aba0 commit 2dd43c5
Showing 3 changed files with 20 additions and 5 deletions.
Makefile (5 additions & 0 deletions)
@@ -71,6 +71,11 @@ slow_tests_text_generation_example: test_installs
 slow_tests_image_to_text_example: test_installs
 	python -m pytest tests/test_image_to_text_example.py -v -s --token $(TOKEN)
 
+# Run visual question answering tests
+slow_tests_openclip_vqa_example: test_installs
+	python -m pip install -r examples/visual-question-answering/openclip_requirements.txt
+	python -m pytest tests/test_openclip_vqa.py
+
 slow_tests_fsdp: test_installs
 	python -m pytest tests/test_fsdp_examples.py -v -s --token $(TOKEN)
 
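With this target in place, the VQA example gets its own slow-test entry point; on a machine set up for the test suite it can be invoked from the repository root as `make slow_tests_openclip_vqa_example`. Unlike the neighboring targets, it installs the example's requirements first and passes no --token to pytest, presumably because the checkpoints it exercises are publicly hosted.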
examples/visual-question-answering/run_openclip_vqa.py (6 additions & 4 deletions)
@@ -11,7 +11,7 @@
 import matplotlib.pyplot as plt
 import numpy
 import torch
-from open_clip import create_model_from_pretrained, get_tokenizer
+from open_clip import create_model_from_pretrained, get_tokenizer, model
 from PIL import Image
 
 from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
@@ -49,7 +49,9 @@
     'pie_chart.png'
 ]
 
-def plot_images_with_metadata(images, metadata, output_dir, plot_name):
+def plot_images_with_metadata(images: list, metadata, output_dir: str, plot_name: str) -> None:
+    print(f"plottypes {type(images)} {type(metadata)} {type(output_dir)} {type(plot_name)}")
+
     num_images = len(images)
     fig, axes = plt.subplots(nrows=num_images, ncols=1, figsize=(5, 5 * num_images))
 
@@ -67,15 +69,15 @@ def plot_images_with_metadata(images, metadata, output_dir, plot_name):
     plt.savefig(f'{output_dir}/{plot_name}.png')
 
 
-def run_qa(model, images, texts, device):
+def run_qa(model: model, images: torch.Tensor, texts: torch.Tensor, device: torch.device) -> tuple:
     with torch.no_grad():
         image_features, text_features, logit_scale = model(images, texts)
         logits = (logit_scale * image_features @ text_features.t()).detach().softmax(dim=-1)
         sorted_indices = torch.argsort(logits, dim=-1, descending=True)
     return sorted_indices, logits
 
 
-def postprocess(args, sorted_indices, logits, topk):
+def postprocess(args: argparse.Namespace, sorted_indices: torch.Tensor, logits: torch.Tensor, topk: int) -> list:
     logits = logits.float().cpu().numpy()
     sorted_indices = sorted_indices.int().cpu().numpy()
     metadata_list = []
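The substantive change in this file is the added type annotations; the bodies are untouched. For readers skimming the diff: run_qa leans on open_clip's joint forward pass, which returns image embeddings, text embeddings, and a learned logit scale, then ranks the candidate texts per image by a softmax over the scaled similarities. A minimal self-contained sketch of that computation, with random unit vectors standing in for real CLIP features (the shapes and the scale value here are illustrative assumptions, not taken from the example):

    import torch

    # Stand-ins for what a CLIP-style model returns from model(images, texts):
    # one image embedding and four candidate text embeddings, L2-normalized.
    image_features = torch.nn.functional.normalize(torch.randn(1, 512), dim=-1)
    text_features = torch.nn.functional.normalize(torch.randn(4, 512), dim=-1)
    logit_scale = torch.tensor(100.0)  # exp of the learned temperature, as in CLIP

    with torch.no_grad():
        # Scaled cosine similarities over the candidates -> probability per text.
        logits = (logit_scale * image_features @ text_features.t()).softmax(dim=-1)
        sorted_indices = torch.argsort(logits, dim=-1, descending=True)

    print(sorted_indices[0], logits[0])  # best-to-worst text indices and their scores

One side note: the annotation model: model appears to reuse open_clip's model module as a type, which type checkers will flag; torch.nn.Module would be the conventional choice.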
tests/test_openclip_vqa.py (9 additions & 1 deletion)
@@ -18,6 +18,15 @@
("microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224", 1816)
],
}
else:
# Gaudi1 CI baselines
MODELS_TO_TEST = {
"bf16": [
("laion/CLIP-ViT-g-14-laion2B-s12B-b42K", 550),
("microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224", 1200),
],
}


def _install_requirements():
PATH_TO_EXAMPLE_DIR = Path(__file__).resolve().parent.parent / "examples"
@@ -64,7 +73,6 @@ def _test_openclip_vqa(
         results = json.load(fp)
 
     # Ensure performance requirements (throughput) are met
-    t= results["throughput"]
     assert results["throughput"] >= (2 - TIME_PERF_FACTOR) * baseline
 
 
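The new else branch gives Gaudi1 its own throughput baselines, so the same test module can gate both device generations, and the stray t= results["throughput"] debug assignment is dropped in favor of asserting on the value directly. The assertion keeps the suite's usual tolerance scheme: with TIME_PERF_FACTOR a little above 1, (2 - TIME_PERF_FACTOR) * baseline lets measured throughput fall slightly below baseline before the test fails. A sketch of the check in isolation (the factor and the numbers are illustrative, not the CI values):

    TIME_PERF_FACTOR = 1.05  # illustrative; the test suite defines its own value

    def throughput_ok(measured: float, baseline: float) -> bool:
        # Permit measured throughput to undershoot baseline by
        # (TIME_PERF_FACTOR - 1), i.e. ~5% slack with the value above.
        return measured >= (2 - TIME_PERF_FACTOR) * baseline

    assert throughput_ok(540.0, baseline=550)       # 540 >= 522.5, passes
    assert not throughput_ok(500.0, baseline=550)   # 500 < 522.5, fails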
