Skip to content

Commit

Permalink
#10439: ttnn implementation of vgg model
Browse files Browse the repository at this point in the history
  • Loading branch information
vigneshkeerthivasanx committed Sep 16, 2024
1 parent 23b968f commit 8b63ee5
Show file tree
Hide file tree
Showing 8 changed files with 982 additions and 0 deletions.
20 changes: 20 additions & 0 deletions models/demos/functional_vgg/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Introduction

The VGG model is a popular convolutional neural network architecture introduced by the Visual Geometry Group at Oxford in their paper "Very Deep Convolutional Networks for Large-Scale Image Recognition" (2014). It is widely used for image classification and feature extraction tasks.

# Model Architectures
- VGG11
- VGG16

# How to Run
To run the demo for image classification of the VGG model using ImageNet-1k Validation Dataset, follow these instructions

- Use the following command to run the model using ttnn_vgg
-VGG11
```
pytest models/experimental/functional_vgg/demo/demo.py::test_demo_imagenet_vgg11
```
- VGG16
```
pytest models/demos/functional_vgg/demo/demo.py::test_demo_imagenet_vgg16
```
153 changes: 153 additions & 0 deletions models/demos/functional_vgg/demo/demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0


import torch
from loguru import logger
from torchvision import models
from transformers import AutoImageProcessor
import pytest
import tt_lib
import torch.nn as nn

from models.utility_functions import (
disable_compilation_reports,
disable_persistent_kernel_cache,
enable_persistent_kernel_cache,
profiler,
)
import ttnn

from models.demos.functional_vgg.demo_utils import get_data, get_data_loader, get_batch, preprocess
from loguru import logger
from ttnn.model_preprocessing import preprocess_model_parameters
from models.demos.functional_vgg.tt import ttnn_vgg

vgg_model_config = {
"MATH_FIDELITY": ttnn.MathFidelity.LoFi,
"WEIGHTS_DTYPE": ttnn.bfloat16,
"ACTIVATIONS_DTYPE": ttnn.bfloat16,
}


def run_vgg_imagenet_inference_vgg16(
batch_size, iterations, imagenet_label_dict, model_location_generator, device, model_config=vgg_model_config
):
disable_persistent_kernel_cache()
disable_compilation_reports()
profiler.clear()

# Setup model
torch_model = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
torch_model.to(torch.bfloat16)
torch_model.eval()

parameters = preprocess_model_parameters(
initialize_model=lambda: torch_model,
device=device,
convert_to_ttnn=lambda *_: True,
custom_preprocessor=ttnn_vgg.custom_preprocessor,
)

# load inputs
logger.info("ImageNet-1k validation Dataset")
input_loc = str(model_location_generator("ImageNet_data"))
data_loader = get_data_loader(input_loc, batch_size, iterations)

# load ImageNet batch by batch
# and run inference
correct = 0
for iter in range(iterations):
predictions = []
torch_predictions = []
inputs, labels = get_batch(data_loader)
torch_outputs = torch_model(inputs)
permuted_inputs = torch.permute(inputs, (0, 2, 3, 1))
tt_batched_input_tensor = ttnn.from_torch(permuted_inputs, ttnn.bfloat16)
tt_output = ttnn_vgg.ttnn_vgg16(device, tt_batched_input_tensor, parameters, batch_size, model_config)
tt_output = ttnn.to_torch(tt_output)
prediction = tt_output[:, 0, 0, :].argmax(dim=-1)
torch_prediction = torch_outputs[:, :].argmax(dim=-1)
for i in range(batch_size):
predictions.append(imagenet_label_dict[prediction[i].item()])
torch_predictions.append(imagenet_label_dict[torch_prediction[i].item()])
logger.info(
f"Iter: {iter} Sample: {i} - Expected Label: {imagenet_label_dict[labels[i]]} -- \n Torch Predicted label:{predictions[-1]} \tPredicted Label: {predictions[-1]}"
)
if imagenet_label_dict[labels[i]] == predictions[-1]:
correct += 1

del tt_output, tt_batched_input_tensor, inputs, labels, predictions
accuracy = correct / (batch_size * iterations)
logger.info(f"Accuracy for {batch_size}x{iterations} inputs: {accuracy}")


def run_vgg_imagenet_inference_vgg11(
batch_size, iterations, imagenet_label_dict, model_location_generator, device, model_config=vgg_model_config
):
disable_persistent_kernel_cache()
disable_compilation_reports()
profiler.clear()

# Setup model
torch_model = models.vgg11(weights=models.VGG11_Weights.IMAGENET1K_V1)
torch_model.to(torch.bfloat16)
torch_model.eval()

parameters = preprocess_model_parameters(
initialize_model=lambda: torch_model,
device=device,
convert_to_ttnn=lambda *_: True,
custom_preprocessor=ttnn_vgg.custom_preprocessor,
)

# load inputs
logger.info("ImageNet-1k validation Dataset")
input_loc = str(model_location_generator("ImageNet_data"))
data_loader = get_data_loader(input_loc, batch_size, iterations)

# load ImageNet batch by batch
# and run inference
correct = 0
for iter in range(iterations):
predictions = []
torch_predictions = []
inputs, labels = get_batch(data_loader)
torch_outputs = torch_model(inputs)
permuted_inputs = torch.permute(inputs, (0, 2, 3, 1))
tt_batched_input_tensor = ttnn.from_torch(permuted_inputs, ttnn.bfloat16)
tt_output = ttnn_vgg.ttnn_vgg11(device, tt_batched_input_tensor, parameters, batch_size, model_config)
tt_output = ttnn.to_torch(tt_output)
prediction = tt_output[:, 0, 0, :].argmax(dim=-1)
torch_prediction = torch_outputs[:, :].argmax(dim=-1)
for i in range(batch_size):
predictions.append(imagenet_label_dict[prediction[i].item()])
torch_predictions.append(imagenet_label_dict[torch_prediction[i].item()])
logger.info(
f"Iter: {iter} Sample: {i} - Expected Label: {imagenet_label_dict[labels[i]]} -- \n Torch Predicted label:{predictions[-1]} \tPredicted Label: {predictions[-1]}"
)
if imagenet_label_dict[labels[i]] == predictions[-1]:
correct += 1

del tt_output, tt_batched_input_tensor, inputs, labels, predictions
accuracy = correct / (batch_size * iterations)
logger.info(f"Accuracy for {batch_size}x{iterations} inputs: {accuracy}")


@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576}], indirect=True)
@pytest.mark.parametrize(
"batch_size, iterations",
((1, 1),),
)
def test_demo_imagenet_vgg11(batch_size, iterations, imagenet_label_dict, model_location_generator, device):
run_vgg_imagenet_inference_vgg11(batch_size, iterations, imagenet_label_dict, model_location_generator, device)


@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576}], indirect=True)
@pytest.mark.parametrize(
"batch_size, iterations",
((1, 1),),
)
def test_demo_imagenet_vgg16(batch_size, iterations, imagenet_label_dict, model_location_generator, device):
run_vgg_imagenet_inference_vgg16(batch_size, iterations, imagenet_label_dict, model_location_generator, device)
130 changes: 130 additions & 0 deletions models/demos/functional_vgg/demo_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

from PIL import Image
import torch
import os
import glob
from models.sample_data.huggingface_imagenet_classes import IMAGENET2012_CLASSES
from datasets import load_dataset
from torchvision import models
from PIL import Image
import torchvision.transforms as transforms
import torch


class InputExample(object):
def __init__(self, image, label=None):
self.image = image
self.label = label


def get_input(image_path):
img = Image.open(image_path)
return img


def get_label(image_path):
_, image_name = image_path.rsplit("/", 1)
image_name_exact, _ = image_name.rsplit(".", 1)
_, label_id = image_name_exact.rsplit("_", 1)
label = list(IMAGENET2012_CLASSES).index(label_id)
return label


preprocess = transforms.Compose(
[
transforms.Resize(256), # Resize the shorter side to 256 pixels
transforms.CenterCrop(224), # Crop the center to 224x224 pixels
transforms.ToTensor(), # Convert the image to a tensor
transforms.Normalize( # Normalize using ImageNet's mean and std
mean=[0.485, 0.456, 0.406], # These are the mean values for each channel
std=[0.229, 0.224, 0.225], # These are the std values for each channel
),
]
)


def get_batch(data_loader):
loaded_images = next(data_loader)
images = None
labels = []
transform = transforms.ToTensor()
resize_transform = transforms.Resize((224, 224))
for image in loaded_images:
img = image.image
labels.append(image.label)
if img.mode == "L":
img = img.convert(mode="RGB")

img = preprocess(img)
img = img.to(torch.bfloat16)
img = img.unsqueeze(0)
if images is None:
images = img
else:
images = torch.cat((images, img), dim=0)
return images, labels


def get_data_loader(input_loc, batch_size, iterations):
img_dir = input_loc + "/"
data_path = os.path.join(img_dir, "*G")
files = glob.glob(data_path)

def loader():
examples = []
for f1 in files:
examples.append(
InputExample(
image=get_input(f1),
label=get_label(f1),
)
)
if len(examples) == batch_size:
yield examples
del examples
examples = []

def loader_hf():
examples = []
for f1 in files:
examples.append(
InputExample(
image=f1["image"],
label=f1["label"],
)
)
if len(examples) == batch_size:
yield examples
del examples
examples = []

if len(files) == 0:
files_raw = iter(load_dataset("imagenet-1k", split="validation", use_auth_token=True, streaming=True))
files = []
sample_count = batch_size * iterations
for _ in range(sample_count):
files.append(next(files_raw))
del files_raw
return loader_hf()

return loader()


def get_data(input_loc):
img_dir = input_loc + "/"
data_path = os.path.join(img_dir, "*G")
files = sorted(glob.glob(data_path))
examples = []
for f1 in files:
examples.append(
InputExample(
image=get_input(f1),
label=get_label(f1),
)
)
image_examples = examples

return image_examples
Loading

0 comments on commit 8b63ee5

Please sign in to comment.