Inference time for tflite quantized model is high #980

Open
RubensZimbres opened this issue Jul 23, 2024 · 0 comments
I followed a MediaPipe tutorial, and their model, https://storage.googleapis.com/mediapipe-models/image_classifier/efficientnet_lite0/float32/1/efficientnet_lite0.tflite, has an inference time of a few milliseconds.

I used ai-edge-torch to convert a PyTorch efficientnet to tflite, but the inference time is 2-3 seconds.

Here's my code:

import torch
import torch.nn as nn
import torchvision
import ai_edge_torch

efficientnet = torchvision.models.efficientnet_b3(
    weights=torchvision.models.EfficientNet_B3_Weights.IMAGENET1K_V1
).eval()

class PermuteInput(nn.Module):
    def __init__(self):
        super(PermuteInput, self).__init__()

    def forward(self, x):
        # Permute from (batch, height, width, channels) to (batch, channels, height, width)
        return x.permute(0, 3, 1, 2)


import torch.nn.functional as F

class PermuteOutput(nn.Module):
    def __init__(self):
        super(PermuteOutput, self).__init__()

    def forward(self, x):
        return F.normalize(x)

efficientnet_with_reshape = nn.Sequential(
    PermuteInput(),
    efficientnet,
    PermuteOutput()
)


edge_model = efficientnet_with_reshape.eval()

sample_input = (torch.rand((1, 224, 224, 3), dtype=torch.float32),)

edge_model = ai_edge_torch.convert(edge_model.eval(), sample_input)

edge_model.export("/home/user/efficientnet.tflite")
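
# Optional sanity check (a sketch following the ai-edge-torch model-validation
# example, assuming numpy is available): the converted model is callable from
# Python, so its output can be compared against the original PyTorch model
# before quantizing.
import numpy as np

torch_output = efficientnet_with_reshape(*sample_input).detach().numpy()
edge_output = edge_model(*sample_input)
print("Outputs match:", np.allclose(torch_output, edge_output, atol=1e-5))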

# QUANTIZE TFLITE MODEL

# Imports for PT2E quantization (paths as used in the ai-edge-torch quantization example)
from ai_edge_torch.quantize.pt2e_quantizer import PT2EQuantizer, get_symmetric_quantization_config
from ai_edge_torch.quantize.quant_config import QuantConfig
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e

pt2e_quantizer = PT2EQuantizer().set_global(
    get_symmetric_quantization_config(is_per_channel=True, is_dynamic=True)
)

pt2e_torch_model = capture_pre_autograd_graph(efficientnet_with_reshape.eval(), sample_input)
pt2e_torch_model = prepare_pt2e(pt2e_torch_model, pt2e_quantizer)

# Run the prepared model with sample input data to ensure that internal observers are populated with correct values
pt2e_torch_model(*sample_input)

# Convert the prepared model to a quantized model
pt2e_torch_model = convert_pt2e(pt2e_torch_model, fold_quantize=False)

# Convert to an ai_edge_torch model
pt2e_drq_model = ai_edge_torch.convert(pt2e_torch_model, sample_input, quant_config=QuantConfig(pt2e_quantizer=pt2e_quantizer))


pt2e_drq_model.export("/home/user/efficientnet_quantized.tflite")

I properly added metadata and labels to the tflite model, and also added a CORS policy to the bucket.

Is this a quantization issue or a bucket bandwidth issue? With the supported model, inference is really fast.
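
To separate the model's own latency from any download/bandwidth effects, the raw interpreter time can be measured locally. Below is a minimal timing sketch using the TFLite Python interpreter (tf.lite.Interpreter); the model path and iteration count are just placeholders:

import time
import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="/home/user/efficientnet_quantized.tflite")
interpreter.allocate_tensors()
input_index = interpreter.get_input_details()[0]["index"]

# Feed a dummy NHWC image, warm up once, then time repeated invocations
interpreter.set_tensor(input_index, np.random.rand(1, 224, 224, 3).astype(np.float32))
interpreter.invoke()

runs = 20
start = time.perf_counter()
for _ in range(runs):
    interpreter.invoke()
print(f"Average invoke latency: {(time.perf_counter() - start) / runs * 1000:.1f} ms")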
