Model conversion with ai-edge-torch + metadata results in long inference time #982

Open
RubensZimbres opened this issue Jul 24, 2024 · 0 comments


RubensZimbres commented Jul 24, 2024

I developed a customized ResNet in PyTorch, converted it with ai-edge-torch to TFLite, added metadata, and put the TFLite model in a Google Cloud bucket for inference in a MediaPipe Image Classification example. Explanation and code here.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import ai_edge_torch
from ai_edge_torch.quantize.pt2e_quantizer import PT2EQuantizer
from ai_edge_torch.quantize.pt2e_quantizer import get_symmetric_quantization_config
from ai_edge_torch.quantize.quant_config import QuantConfig
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e

# weights must be passed by keyword in torchvision >= 0.13
resnet50 = torchvision.models.resnet50(
    weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1).eval()

class PermuteInput(nn.Module):
    def __init__(self):
        super(PermuteInput, self).__init__()

    def forward(self, x):
        # Permute from (batch, height, width, channels) to (batch, channels, height, width)
        return x.permute(0, 3, 1, 2)

class HandleOutput(nn.Module):
    def __init__(self):
        super(HandleOutput, self).__init__()

    def forward(self, x):
        return F.normalize(x) 

# Wrap the model with a Sequential container so it accepts NHWC input
# (the TFLite convention) and L2-normalizes the output
resnet50_with_reshape = nn.Sequential(
    PermuteInput(),
    resnet50,
    HandleOutput()
)

edge_model = resnet50_with_reshape.eval()
sample_input = (torch.rand((1, 224, 224, 3), dtype=torch.float32),)

edge_model = ai_edge_torch.convert(edge_model.eval(), sample_input)

pt2e_quantizer = PT2EQuantizer().set_global(
    get_symmetric_quantization_config(is_per_channel=True, is_dynamic=True))

pt2e_torch_model = capture_pre_autograd_graph(resnet50_with_reshape.eval(), sample_input)
pt2e_torch_model = prepare_pt2e(pt2e_torch_model, pt2e_quantizer)

# Run the prepared model with sample input data to ensure that internal observers are populated with correct values
pt2e_torch_model(*sample_input)

# Convert the prepared model to a quantized model
pt2e_torch_model = convert_pt2e(pt2e_torch_model, fold_quantize=False)

# Convert to an ai_edge_torch model
pt2e_drq_model = ai_edge_torch.convert(
    pt2e_torch_model, sample_input,
    quant_config=QuantConfig(pt2e_quantizer=pt2e_quantizer))
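
To get .tflite files that the metadata tooling below can consume, the converted models still need to be serialized. A minimal sketch using the converted model's export method (file names are illustrative; the quantized one matches the file used in the metadata script below):

# Serialize the converted models to .tflite files (names are illustrative)
edge_model.export("resnet50_float.tflite")
pt2e_drq_model.export("efficientnet_quantized.tflite")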

However, when the default (and supported) model is used, inference time is a few milliseconds. With my customized TFLite model, inference takes about 2 seconds.
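
A quick way to quantify that gap outside the browser is to time a single invocation with the standard tf.lite.Interpreter API (a minimal sketch; the model path is illustrative):

import time
import numpy as np
import tensorflow as tf

# Load the converted model and time one forward pass (path is illustrative)
interpreter = tf.lite.Interpreter(model_path="efficientnet_quantized.tflite")
interpreter.allocate_tensors()

input_index = interpreter.get_input_details()[0]["index"]
interpreter.set_tensor(input_index, np.random.rand(1, 224, 224, 3).astype(np.float32))

start = time.perf_counter()
interpreter.invoke()
print(f"Inference time: {time.perf_counter() - start:.3f} s")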

The customized model was developed and converted with TensorFlow 2.17.0. However, the tflite-support library only works successfully with TensorFlow 2.13.0. I am using the latest version of tflite-support in an Anaconda environment on Ubuntu 22.04.

When I add metadata, the only TensorFlow version that runs is 2.13.0. This gap between TensorFlow versions is generating an issue: the "Created TensorFlow Lite XNNPACK delegate for CPU." step takes a long time, and inference does not run in real time in the browser, unlike the supported TFLite model at:

https://storage.googleapis.com/mediapipe-models/image_classifier/efficientnet_lite0/float32/1/efficientnet_lite0.tflite

With my model in the web page, the log shows "Graph successfully started running", but inference takes too long.
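
For anyone trying to reproduce this outside the browser, the same model can be exercised with the MediaPipe Python Tasks API (a minimal sketch; the model and image paths are illustrative):

import mediapipe as mp
from mediapipe.tasks import python
from mediapipe.tasks.python import vision

# Load the custom .tflite model into a MediaPipe image classifier
base_options = python.BaseOptions(model_asset_path="efficientnet_quantized.tflite")
options = vision.ImageClassifierOptions(base_options=base_options, max_results=5)
classifier = vision.ImageClassifier.create_from_options(options)

# Classify a test image and print the top category
image = mp.Image.create_from_file("test.jpg")
result = classifier.classify(image)
print(result.classifications[0].categories[0])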

By using tflite-support 0.4.4 I get this error:

ImportError: generic_type: cannot initialize type "StatusCode": an object with that name is already defined

It only works with tflite-support==0.1.0a1, but then inference takes too long and I get this error in the Debug Console:

W0724 16:08:28.765000 1880752 inference_feedback_manager.cc:121] Feedback manager requires a model with a single signature inference. Disabling support for feedback tensors.
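
The number of signatures a .tflite file carries can be checked directly with the standard tf.lite.Interpreter API (a minimal sketch; the path is illustrative):

import tensorflow as tf

# List the signatures embedded in the .tflite file
interpreter = tf.lite.Interpreter(model_path="efficientnet_quantized.tflite")
print(interpreter.get_signature_list())
# A model that MediaPipe can run in real time should expose exactly one
# signature, e.g. {'serving_default': {...}}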

I noticed that I successfully create a TFLite model with a serving_default signature. However, when I add metadata, this signature disappears. Here's the code I'm using to add metadata:

from tflite_support import flatbuffers
from tflite_support import metadata as _metadata
from tflite_support import metadata_schema_py_generated as _metadata_fb
import os
import tensorflow as tf

model_file = "efficientnet_quantized.tflite"
os.chdir("/folder")


"""Creates the metadata for an image classifier."""

# Creates model info.
model_meta = _metadata_fb.ModelMetadataT()
model_meta.name = "Resnet image classifier"
model_meta.description = ("Identify the most prominent object in the "
                          "image from a set of 1,000 categories such as "
                          "trees, animals, food, vehicles, person etc.")
model_meta.version = "v1"
model_meta.author = "Rubens Zimbres"
model_meta.license = ("Apache License. Version 2.0 "
                      "http://www.apache.org/licenses/LICENSE-2.0.")

# Creates input info.
input_meta = _metadata_fb.TensorMetadataT()

input_meta.name = "Image"
input_meta.description = (
    "Input image to be classified. The expected image is {0} x {1}, with "
    "three channels (red, green, and blue) per pixel. Each value in the "
    "tensor is a single byte between 0 and 255.".format(224, 224))
input_meta.content = _metadata_fb.ContentT()
input_meta.content.contentProperties = _metadata_fb.ImagePropertiesT()
input_meta.content.contentProperties.colorSpace = (
    _metadata_fb.ColorSpaceType.RGB)
input_meta.content.contentPropertiesType = (
    _metadata_fb.ContentProperties.ImageProperties)
input_normalization = _metadata_fb.ProcessUnitT()
input_normalization.optionsType = (
    _metadata_fb.ProcessUnitOptions.NormalizationOptions)
input_normalization.options = _metadata_fb.NormalizationOptionsT()
input_normalization.options.mean = [127.5]
input_normalization.options.std = [127.5]
input_meta.processUnits = [input_normalization]
input_stats = _metadata_fb.StatsT()
input_stats.max = [1]
input_stats.min = [0]
# Note: width, height and num_classes are not fields of the Stats table in the
# metadata schema, so the next three assignments are silently dropped when packing
input_stats.width = [224]
input_stats.height = [224]
input_stats.num_classes = [1000]
input_meta.stats = input_stats

# Creates output info.
output_meta = _metadata_fb.TensorMetadataT()
output_meta.name = "probability"
output_meta.description = "Probabilities of the 1000 labels respectively."
output_meta.content = _metadata_fb.ContentT()
output_meta.content.contentProperties = _metadata_fb.FeaturePropertiesT()
output_meta.content.contentPropertiesType = (
    _metadata_fb.ContentProperties.FeatureProperties)
output_stats = _metadata_fb.StatsT()
output_stats.max = [1.0]
output_stats.min = [0.0]
output_meta.stats = output_stats
label_file = _metadata_fb.AssociatedFileT()
label_file.name = os.path.basename("resnet_labels2.txt")
label_file.description = "Labels for objects that the model can recognize."
label_file.type = _metadata_fb.AssociatedFileType.TENSOR_AXIS_LABELS
output_meta.associatedFiles = [label_file]

# Creates subgraph info.
subgraph = _metadata_fb.SubGraphMetadataT()
subgraph.inputTensorMetadata = [input_meta]
subgraph.outputTensorMetadata = [output_meta]
model_meta.subgraphMetadata = [subgraph]

b = flatbuffers.Builder(0)
b.Finish(
    model_meta.Pack(b),
    _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
metadata_buf = b.Output()

populator = _metadata.MetadataPopulator.with_model_file(model_file)
populator.load_metadata_buffer(metadata_buf)
populator.load_associated_files(["resnet_labels2.txt"])
populator.populate()

displayer = _metadata.MetadataDisplayer.with_model_file('/folder/efficientnet_quantized.tflite')
export_json_file = "/folder/metadata.json"
json_file = displayer.get_metadata_json()
# Optional: write out the metadata as a json file
with open(export_json_file, "w") as f:
  f.write(json_file)
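
After populating, it is worth confirming that the label file was actually packed. A short check with the MetadataDisplayer created above:

# Confirm the associated label file was packed into the model
print(displayer.get_packed_associated_file_list())
# The output should include 'resnet_labels2.txt'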

Any ideas on how to solve this issue?

UPDATE

I found the problem: the TFLite quantization is generating another signature, and the web page debugger then shows:

Feedback manager requires a model with a single signature inference. Disabling support for feedback tensors.

This makes the XNNPACK delegate take longer to load and delays inference.
