OpenCLIP sample for visual question answering #977

Merged (10 commits) on Jul 10, 2024
5 changes: 5 additions & 0 deletions Makefile
@@ -71,6 +71,11 @@ slow_tests_text_generation_example: test_installs
slow_tests_image_to_text_example: test_installs
python -m pytest tests/test_image_to_text_example.py -v -s --token $(TOKEN)

# Run visual question answering tests
slow_tests_openclip_vqa_example: test_installs
python -m pip install -r examples/visual-question-answering/openclip_requirements.txt
python -m pytest tests/test_openclip_vqa.py

slow_tests_fsdp: test_installs
python -m pytest tests/test_fsdp_examples.py -v -s --token $(TOKEN)

39 changes: 36 additions & 3 deletions examples/visual-question-answering/README.md
@@ -16,10 +16,10 @@ limitations under the License.

# Visual Question Answering Examples

This directory contains a script that showcases how to use the Transformers pipeline API to run the visual question answering task on HPUs.

## Single-HPU inference

The `run_pipeline.py` script showcases how to use the Transformers pipeline API to run the visual question answering task on HPUs.

```bash
python3 run_pipeline.py \
--model_name_or_path Salesforce/blip-vqa-capfilt-large \
@@ -32,4 +32,37 @@ python3 run_pipeline.py \
Models that have been validated:
- [Salesforce/blip-vqa-base](https://huggingface.co/Salesforce/blip-vqa-base)
- [dandelin/vilt-b32-finetuned-vqa](https://huggingface.co/dandelin/vilt-b32-finetuned-vqa)
- [Salesforce/blip-vqa-capfilt-large](https://huggingface.co/Salesforce/blip-vqa-capfilt-large)

## OpenCLIP inference

The `run_openclip_vqa.py` script can be used to run zero-shot image classification with [OpenCLIP Hugging Face models](https://huggingface.co/docs/hub/en/open_clip#using-openclip-at-hugging-face).
The requirements for `run_openclip_vqa.py` can be installed from `openclip_requirements.txt` as follows:

```bash
pip install -r openclip_requirements.txt
```

Review comment (Contributor): Here I am getting very large throughput. Is this even a realistic use case?

```
06/18/2024 16:49:42 - INFO - __main__ - Running warmup
06/18/2024 16:49:43 - INFO - __main__ - Running inference
06/18/2024 16:49:43 - INFO - __main__ - Inference Time per iteration = 5.014ms
06/18/2024 16:49:43 - INFO - __main__ - Throughput = 1.795e+03 images/s
```

Reply (vidyasiv, Contributor Author, Jun 18, 2024): That does seem high; I don't recall seeing that large a throughput.

By default, the script runs the sample outlined in the [BiomedCLIP-PubMedBERT_256-vit_base_patch16_224 notebook](https://huggingface.co/microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224/blob/main/biomed_clip_example.ipynb), which can be run as follows:

```bash
python run_openclip_vqa.py \
--use_hpu_graphs \
--bf16
```
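
Under the hood, this default run follows the standard OpenCLIP zero-shot classification flow from the notebook above. The snippet below is a minimal, CPU-only sketch of that flow (the model name, prompt, labels, and image URL are the script's own defaults, shown here purely for illustration); the actual script additionally handles HPU placement, HPU graphs, autocast, and benchmarking:

```python
from urllib.request import urlopen

import torch
from open_clip import create_model_from_pretrained, get_tokenizer
from PIL import Image

model_id = "hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224"
model, preprocess = create_model_from_pretrained(model_id)
tokenizer = get_tokenizer(model_id)
model.eval()

labels = ["chest X-ray", "brain MRI", "pie chart"]
image_url = (
    "https://huggingface.co/microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224"
    "/resolve/main/example_data/biomed_image_classification_example_data/chest_X-ray.jpg"
)

# Preprocess the image and tokenize the prompts ("this is a picture of <label>")
image = preprocess(Image.open(urlopen(image_url))).unsqueeze(0)
texts = tokenizer(["this is a picture of " + label for label in labels])

with torch.no_grad():
    # The OpenCLIP model returns image features, text features and a learned logit scale
    image_features, text_features, logit_scale = model(image, texts)
    probs = (logit_scale * image_features @ text_features.t()).softmax(dim=-1)

for label, prob in zip(labels, probs[0].tolist()):
    print(f"{label}: {prob * 100:.1f}%")
```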

One can also run other OpenCLIP models by specifying the model, classifier labels, and image URL(s) like so:

```bash
python run_openclip_vqa.py \
--model_name_or_path laion/CLIP-ViT-g-14-laion2B-s12B-b42K \
--labels "a dog" "a cat" \
--image_path "http://images.cocodataset.org/val2017/000000039769.jpg" \
--use_hpu_graphs \
--bf16
```

Review comment (Contributor): Likely needs to be validated on Gaudi1 as well.
Models that have been validated:
- [BiomedCLIP-PubMedBERT_256-vit_base_patch16_224](https://huggingface.co/microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224)
- [laion/CLIP-ViT-g-14-laion2B-s12B-b42K](https://huggingface.co/laion/CLIP-ViT-g-14-laion2B-s12B-b42K)
- [apple/DFN5B-CLIP-ViT-H-14](https://huggingface.co/apple/DFN5B-CLIP-ViT-H-14/tree/main)
3 changes: 3 additions & 0 deletions examples/visual-question-answering/openclip_requirements.txt
@@ -0,0 +1,3 @@
open_clip_torch==2.23.0
matplotlib

232 changes: 232 additions & 0 deletions examples/visual-question-answering/run_openclip_vqa.py
@@ -0,0 +1,232 @@
# This script is based on https://huggingface.co/microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224/blob/main/biomed_clip_example.ipynb
import argparse
import json
import logging
import os
import time
from pathlib import Path
from pprint import pprint
from urllib.request import urlopen

import matplotlib.pyplot as plt
import numpy
import torch
from open_clip import create_model_from_pretrained, get_tokenizer, model
from PIL import Image

from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi


logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger = logging.getLogger(__name__)

DATASET_URL = "https://huggingface.co/microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224/resolve/main/example_data/biomed_image_classification_example_data/"
LABELS = [
"adenocarcinoma histopathology",
"brain MRI",
"covid line chart",
"squamous cell carcinoma histopathology",
"immunohistochemistry histopathology",
"bone X-ray",
"chest X-ray",
"pie chart",
"hematoxylin and eosin histopathology",
]

TEST_IMGS = [
"squamous_cell_carcinoma_histopathology.jpeg",
"H_and_E_histopathology.jpg",
"bone_X-ray.jpg",
"adenocarcinoma_histopathology.jpg",
"covid_line_chart.png",
"IHC_histopathology.jpg",
"chest_X-ray.jpg",
"brain_MRI.jpg",
"pie_chart.png",
]


def plot_images_with_metadata(images: list, metadata, output_dir: str, plot_name: str) -> None:
print(f"plottypes {type(images)} {type(metadata)} {type(output_dir)} {type(plot_name)}")

num_images = len(images)
fig, axes = plt.subplots(nrows=num_images, ncols=1, figsize=(5, 5 * num_images))

for i, (img_path, metadata) in enumerate(zip(images, metadata)):
img = Image.open(urlopen(img_path))
if isinstance(axes, list) or isinstance(axes, numpy.ndarray):
ax = axes[i]
else:
ax = axes
ax.imshow(img)
ax.axis("off")
ax.set_title(f"{metadata['filename']}\n{metadata['top_probs']}", fontsize=14)

plt.tight_layout()
plt.savefig(f"{output_dir}/{plot_name}.png")


def run_qa(model: model, images: torch.Tensor, texts: torch.Tensor, device: torch.device) -> tuple:
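    # CLIP-style zero-shot scoring: similarities between the (normalized) image and text features are
    # scaled by the learned logit_scale and softmax-normalized over the candidate labels;
    # label indices are returned sorted by descending probability for each image.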
with torch.no_grad():
image_features, text_features, logit_scale = model(images, texts)
logits = (logit_scale * image_features @ text_features.t()).detach().softmax(dim=-1)
sorted_indices = torch.argsort(logits, dim=-1, descending=True)
return sorted_indices, logits


def postprocess(args: argparse.Namespace, sorted_indices: torch.Tensor, logits: torch.Tensor, topk: int) -> list:
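    # For each image, format the top-k labels with their probabilities (in percent);
    # topk == -1 expands to all provided labels.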
logits = logits.float().cpu().numpy()
sorted_indices = sorted_indices.int().cpu().numpy()
metadata_list = []
for i, img in enumerate(args.image_path):
img_name = img.split("/")[-1]

top_probs = []
topk = len(args.labels) if topk == -1 else topk
for j in range(topk):
jth_index = sorted_indices[i][j]
top_probs.append(f"{args.labels[jth_index]}: {logits[i][jth_index] * 100:.1f}")

metadata = {"filename": img_name, "top_probs": "\n".join(top_probs)}
metadata_list.append(metadata)
return metadata_list


def main():
parser = argparse.ArgumentParser()

parser.add_argument(
"--model_name_or_path",
default="microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224",
type=str,
help="Path to pre-trained model",
)
parser.add_argument(
"--image_path",
default=[DATASET_URL + img for img in TEST_IMGS],
type=str,
nargs="*",
help='Path to image as input. Can be a single string (eg: --image_path "URL1"), or a list of space-separated strings (eg: --image_path "URL1" "URL2")',
)
parser.add_argument(
"--topk",
default=1,
type=int,
help="topk num. Provides top K probabilities for the labels provided.",
)
parser.add_argument(
"--prompt",
default="this is a picture of ",
type=str,
        help='Prompt prefix for classification, prepended to each label (eg: --prompt "a photo of ").',
)
parser.add_argument(
"--labels",
default=LABELS,
type=str,
nargs="*",
        help='Labels for classification. Can be a single string (eg: --labels "LABEL1") or a list of space-separated strings (eg: --labels "LABEL1" "LABEL2")',
)
parser.add_argument(
"--use_hpu_graphs",
action="store_true",
help="Whether to use HPU graphs or not. Using HPU graphs should give better latencies.",
)
parser.add_argument(
"--bf16",
action="store_true",
help="Whether to perform in bf16 precision.",
)
parser.add_argument(
"--output_dir",
default=os.getcwd(),
type=str,
help="Output directory to store results in.",
)
parser.add_argument("--warmup", type=int, default=3, help="Number of warmup iterations for benchmarking.")
parser.add_argument(
"--n_iterations", type=int, default=10, help="Number of inference iterations for benchmarking."
)
parser.add_argument("--plot_images", action="store_true", help="Plot images with metadata for verification")
parser.add_argument(
"--plot_name",
default="openclip_vqa_plot",
type=str,
help="Name of the plot generated with the image and corresponding top K results",
)
parser.add_argument(
"--print_result",
action="store_true",
help="Whether to print the zero shot classification results.",
)

args = parser.parse_args()

adapt_transformers_to_gaudi()

precision = "fp32"
dtype = torch.float32
if args.bf16:
precision = "bf16"
dtype = torch.bfloat16

model, preprocess = create_model_from_pretrained(f"hf-hub:{args.model_name_or_path}", precision=precision)
tokenizer = get_tokenizer(f"hf-hub:{args.model_name_or_path}")

device = torch.device("hpu") if torch.hpu.is_available() else torch.device("cpu")
device_type = "hpu" if torch.hpu.is_available() else "cpu"

# Initialize model
if args.use_hpu_graphs:
from habana_frameworks.torch.hpu import wrap_in_hpu_graph

model = wrap_in_hpu_graph(model)
model = model.to(device)
model.eval()

images = torch.stack([preprocess(Image.open(urlopen(img))) for img in args.image_path]).to(device)
texts = tokenizer([args.prompt + l for l in args.labels]).to(device)

# Warm up
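    # The first iterations include one-time graph capture/compilation on HPU, so they are kept out of the timed loop below.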
logger.info("Running warmup")
for i in range(args.warmup):
with torch.autocast(device_type=device_type, dtype=dtype, enabled=True):
_, _ = run_qa(model, images, texts, device=device)

logger.info("Running inference")
start = time.time()
for i in range(args.n_iterations):
logits = None
with torch.autocast(device_type=device_type, dtype=dtype, enabled=True):
sorted_indices, logits = run_qa(model, images, texts, device=device)
end = time.time()

# Results and metrics
metadata_list = []
metadata_list = postprocess(args, sorted_indices, logits, args.topk)
if args.print_result:
logger.info("Results from the last iteration:")
pprint(metadata_list)
inference_time_per_iteration = (end - start) * 1000 / args.n_iterations
logger.info(f"Inference Time per iteration = {inference_time_per_iteration:.4}ms")
throughput = len(args.image_path) * args.n_iterations / (end - start)
logger.info(f"Throughput = {throughput:.4} images/s")

# Store results if necessary
if args.output_dir is not None:
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)

results = {"throughput": throughput, "inference time per iteration ": inference_time_per_iteration}
with (output_dir / "results.json").open("w", encoding="utf-8") as f:
json.dump(results, f, ensure_ascii=False, indent=4)
if args.plot_images:
plot_images_with_metadata(args.image_path, metadata_list, args.output_dir, args.plot_name)


if __name__ == "__main__":
main()
81 changes: 81 additions & 0 deletions tests/test_openclip_vqa.py
@@ -0,0 +1,81 @@
import json
import os
import re
import subprocess
from pathlib import Path
from tempfile import TemporaryDirectory

import pytest

from .test_examples import TIME_PERF_FACTOR


if os.environ.get("GAUDI2_CI", "0") == "1":
# Gaudi2 CI baselines
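    # Each entry is (model name, baseline throughput in images/s as reported in results.json by the example script)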
MODELS_TO_TEST = {
"bf16": [
("laion/CLIP-ViT-g-14-laion2B-s12B-b42K", 1472),
("microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224", 1816),
],
}
else:
# Gaudi1 CI baselines
MODELS_TO_TEST = {
"bf16": [
("laion/CLIP-ViT-g-14-laion2B-s12B-b42K", 550),
("microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224", 1200),
],
}


def _install_requirements():
PATH_TO_EXAMPLE_DIR = Path(__file__).resolve().parent.parent / "examples"
cmd_line = (
f"pip install -r {PATH_TO_EXAMPLE_DIR / 'visual-question-answering' / 'openclip_requirements.txt'}".split()
)
p = subprocess.Popen(cmd_line)
return_code = p.wait()
assert return_code == 0


def _test_openclip_vqa(model_name: str, baseline: float):
_install_requirements()
command = ["python3"]
path_to_example_dir = Path(__file__).resolve().parent.parent / "examples"
env_variables = os.environ.copy()

command += [
f"{path_to_example_dir / 'visual-question-answering' / 'run_openclip_vqa.py'}",
f"--model_name_or_path {model_name}",
"--bf16",
"--use_hpu_graphs",
]

with TemporaryDirectory() as tmp_dir:
command.append(f"--output_dir {tmp_dir}")
print(f"\n\nCommand to test: {' '.join(command)}\n")

        # Split each command element on whitespace while keeping quoted substrings intact,
        # so quoted argument values end up as single argv entries.
        pattern = re.compile(r"([\"\'].+?[\"\'])|\s")
        command = [x for y in command for x in re.split(pattern, y) if x]

proc = subprocess.run(command, env=env_variables)

# Ensure the run finished without any issue
# Use try-except to avoid logging the token if used
try:
assert proc.returncode == 0
except AssertionError as e:
if "'--token', 'hf_" in e.args[0]:
e.args = (f"The following command failed:\n{' '.join(command[:-2])}",)
raise

with open(Path(tmp_dir) / "results.json") as fp:
results = json.load(fp)

# Ensure performance requirements (throughput) are met
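        # (2 - TIME_PERF_FACTOR) * baseline is the minimum acceptable throughput, leaving some
        # tolerance relative to the recorded baseline.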
assert results["throughput"] >= (2 - TIME_PERF_FACTOR) * baseline


@pytest.mark.parametrize("model_name, baseline", MODELS_TO_TEST["bf16"])
def test_openclip_vqa_bf16(model_name: str, baseline: float):
_test_openclip_vqa(model_name, baseline)