Concat scales not being grouped #2195

Open
basioli-k opened this issue Oct 13, 2023 · 3 comments

@basioli-k

I am trying to quantize a pytorch model using NNCF.
The output of my model is a concatenation of two tensors.

To quantize my outputs I set:
advanced_parameters = AdvancedQuantizationParameters(quantize_outputs=True)

When I quantize the model I get a separate quantizer for each input:

ModuleDict(
  (/nncf_model_input_0|OUTPUT): AsymmetricQuantizer(bit=8, ch=False)
  (/nncf_model_input_1|OUTPUT): AsymmetricQuantizer(bit=8, ch=False)
)

Based on what I saw in NNCF, I would expect to get something like this:

ModuleDict(
  (/nncf_model_input_0|OUTPUT;/nncf_model_input_1|OUTPUT): AsymmetricQuantizer(bit=8, ch=False)
)

I am guessing this is an edge case that comes up because of AdvancedQuantizationParameters(quantize_outputs=True).

NNCF version: 2.6.0

Run the following to reproduce:

import nncf
import torch
import numpy as np
from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters

class DummyDataset(torch.utils.data.Dataset):
    """ Generates random inputs matching the given shapes and names """
    def __init__(self, input_shapes, input_names):
        self.input_shapes = input_shapes
        self.input_names = input_names

    def __len__(self):
        return 1

    def __getitem__(self, index):
        return { self.input_names[i]: np.random.rand(*input_shape) for i, input_shape in enumerate(self.input_shapes) }

class DummyModel(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(6, 6, 3, 1, 1)
        
    def forward(self, x, y):
        x_cat_y = torch.cat((x,y), dim=1)
        
        return x_cat_y
        # return self.conv(x_cat_y) # use this to verify that quantizers get grouped if concat isn't the output
    
def quantize_model(model, input_shapes, input_names):
    def transform_fn(data):
        # [0] drops the batch dimension added by the DataLoader before the sample is passed to the model
        return tuple(data[key][0].to(torch.float32) for key in data)
    
    dummy_dl = torch.utils.data.DataLoader(DummyDataset(input_shapes, input_names))
    calibration_dataset = nncf.Dataset(dummy_dl, transform_fn)
    
    advanced_parameters = AdvancedQuantizationParameters(quantize_outputs=True)
    
    return nncf.quantize(model, calibration_dataset, subset_size=1, preset=nncf.QuantizationPreset.MIXED, advanced_parameters=advanced_parameters)

def main():
    quantized_model = quantize_model(DummyModel(), [(1, 3, 256, 256), (1, 3, 256, 256)], ["x", "y"])
    
    print()
    print(quantized_model._nncf.external_quantizers)
    
if __name__ == "__main__":
    main()
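
As a side note (not part of the original reproducer), here is a minimal check that could be appended to main(), assuming the unified-quantizer key format shown above (input IDs joined with ';'):

    # With the current behaviour this assertion fails, which illustrates the report.
    quantizers = quantized_model._nncf.external_quantizers
    assert any(";" in key for key in quantizers), "concat input quantizers were not unified"
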
@vshampor vshampor self-assigned this Oct 13, 2023
@vshampor
Contributor

Greetings, @basioli-k! Thanks for spotting this, and for the detailed reproducer that makes debugging this a breeze.

The unexpected behaviour seems to be due to some logic introduced in #1778. If I comment out the following lines:

if scales_unification_map is not None and metatype in scales_unification_map:
    unify_conditions.append(followed_by_weighted_types(curr_node_key, metatype))

the input quantizers in both your cases get unified. We added that logic (only unifying concat scales when the concat is followed by a weighted op) in response to low PTQ accuracy in DenseNet and Inception, but IMO the concat input quantizers should be unified in the per-tensor case regardless of the ops that follow the concat. Will investigate how best to fix this on the develop branch.
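
For illustration only, a rough sketch of the relaxation described above, reusing the names from the quoted snippet (is_per_tensor is a hypothetical flag, not an existing NNCF symbol):

if scales_unification_map is not None and metatype in scales_unification_map:
    # Sketch, not the actual fix: keep the DenseNet/Inception heuristic for
    # per-channel quantizers, but always unify per-tensor concat input scales.
    if not is_per_tensor:  # hypothetical flag describing the quantizer config
        unify_conditions.append(followed_by_weighted_types(curr_node_key, metatype))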

@basioli-k
Author

Thank you for the response.
Looking forward to the fix 😃

@avitial

avitial commented Apr 16, 2024

Ref. 138683
