From c7068a2f0dc826844a07b239ed956e0b6a4b4330 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Wed, 3 Jan 2024 12:04:00 -0600 Subject: [PATCH] Adds an example for resnet50 in half precision --- examples/{resnet-18 => resnet}/README.md | 0 .../{resnet-18 => resnet}/requirements.txt | 0 examples/{resnet-18 => resnet}/resnet-18.py | 0 examples/resnet/resnet-50-fp16.py | 85 +++++++++++++++++++ 4 files changed, 85 insertions(+) rename examples/{resnet-18 => resnet}/README.md (100%) rename examples/{resnet-18 => resnet}/requirements.txt (100%) rename examples/{resnet-18 => resnet}/resnet-18.py (100%) create mode 100644 examples/resnet/resnet-50-fp16.py diff --git a/examples/resnet-18/README.md b/examples/resnet/README.md similarity index 100% rename from examples/resnet-18/README.md rename to examples/resnet/README.md diff --git a/examples/resnet-18/requirements.txt b/examples/resnet/requirements.txt similarity index 100% rename from examples/resnet-18/requirements.txt rename to examples/resnet/requirements.txt diff --git a/examples/resnet-18/resnet-18.py b/examples/resnet/resnet-18.py similarity index 100% rename from examples/resnet-18/resnet-18.py rename to examples/resnet/resnet-18.py diff --git a/examples/resnet/resnet-50-fp16.py b/examples/resnet/resnet-50-fp16.py new file mode 100644 index 000000000..8524099d7 --- /dev/null +++ b/examples/resnet/resnet-50-fp16.py @@ -0,0 +1,85 @@ +from torchvision.models import resnet50, ResNet50_Weights +import torch +import numpy as np +from shark_turbine.aot import * +import iree.runtime as rt + +# Loading feature extractor and pretrained model from huggingface +# extractor = AutoFeatureExtractor.from_pretrained("microsoft/resnet-18") +model = resnet50(weights="DEFAULT") +float_model = model.eval().float() +model = model.eval().half() + + +# define a function to do inference +# this will get passed to the compiled module as a jittable function +def vision_forward(pixel_values_tensor: torch.Tensor): + with torch.no_grad(): + logits = model.forward(pixel_values_tensor) + predicted_id = torch.argmax(logits, -1) + return predicted_id + + +def vision_forward_float(pixel_values_tensor: torch.Tensor): + with torch.no_grad(): + logits = float_model.forward(pixel_values_tensor) + predicted_id = torch.argmax(logits, -1) + return predicted_id + + +# a dynamic module for doing inference +# this will be compiled AOT to a memory buffer +class Resnet50_f16(CompiledModule): + params = export_parameters(model) + + def forward(self, x=AbstractTensor(None, 3, 224, 224, dtype=torch.float16)): + # set a constraint for the dynamic number of batches + # interestingly enough, it doesn't seem to limit BATCH_SIZE + const = [x.dynamic_dim(0) < 16] + return jittable(vision_forward)(x, constraints=const) + + +# build an mlir module with 1-shot exporter +exported = export(Resnet50_f16) +# compile exported module to a memory buffer +compiled_binary = exported.compile(save_to=None) + + +# return type is rt.array_interop.DeviceArray +# np.array of outputs can be accessed via to_host() method +def shark_infer(x): + config = rt.Config("local-task") + vmm = rt.load_vm_module( + rt.VmModule.wrap_buffer(config.vm_instance, compiled_binary.map_memory()), + config, + ) + y = vmm.forward(x) + return y + + +# prints the text corresponding to output label codes +def print_labels(class_id): + weights = ResNet50_Weights.DEFAULT + for l in class_id: + print(weights.meta["categories"][l]) + + +# finds discrepancies between id0 and id1 +def largest_error(array1, array2): + absolute_diff = np.abs(array1 - array2) + max_error = np.max(absolute_diff) + return max_error + + +# load some examples and check for discrepancies between +# compiled module and standard inference (forward function) + +x = torch.randn((10, 3, 224, 224), dtype=torch.float16) +x_float = torch.randn((10, 3, 224, 224), dtype=torch.float32) +y0 = shark_infer(x).to_host() +float_model = float_model.float() +y1 = np.asarray(vision_forward_float(x_float)) +print_labels(y0) +print( + f"Largest error between turbine (fp16) and pytorch (fp32) baseline is {largest_error(y0,y1)}" +)