Skip to content

Commit

Permalink
[Quantization] Fix bug with removing PropagatingQuantizer (#3008)
Browse files Browse the repository at this point in the history
### Changes

Remove the quantizer from
`_pqs_after_weight_dependent_output_quantized_nodes` inside
`remove_propagating_quantizer()`.

### Reason for changes

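Fix a bug: a propagating quantizer removed via `remove_propagating_quantizer()` was left behind in `_pqs_after_weight_dependent_output_quantized_nodes`.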

### Related tickets

154366

### Tests

Tested on the previously failing scenario.
  • Loading branch information
kshpv authored Oct 28, 2024
1 parent 4501233 commit dca2cad
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 43 deletions.
1 change: 1 addition & 0 deletions nncf/common/quantization/quantizer_propagation/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -719,6 +719,7 @@ def remove_propagating_quantizer(
if prop_quantizer.unified_scale_type is not None:
gid = self._unified_scale_group_manager.get_group_id_by_propagating_quantizer_id(prop_quantizer.id)
self._unified_scale_group_manager.remove_from_group(gid, prop_quantizer)
self._pqs_after_weight_dependent_output_quantized_nodes.pop(prop_quantizer, None)

def propagate_quantizer_via_path(
self, prop_quantizer: PropagatingQuantizer, path: PropagationPath
Expand Down
111 changes: 68 additions & 43 deletions tests/common/quantization/test_quantizer_propagation_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -972,7 +972,8 @@ def _setup_and_propagate_quantizers(self, qpsg: QPSG) -> QPSG:
[pq_1, pq_2], [QuantizerConfig()], [None, None], InsertionPointGraph.get_post_hook_node_key("1 /B_0")
)
pq_3 = qpsg.add_propagating_quantizer(
[QuantizerConfig(per_channel=True)], InsertionPointGraph.get_pre_hook_node_key("4 /E_0") # sic!
[QuantizerConfig(per_channel=True)],
InsertionPointGraph.get_pre_hook_node_key("4 /E_0"), # sic!
)
# pq_3 should be considered redundant w.r.t the upstream per-tensor quantizer
paths = get_edge_paths_for_propagation(
Expand Down Expand Up @@ -1311,55 +1312,55 @@ def create_graph_for_output_quant_as_weights() -> NNCFGraph:
MODEL_GRAPH: NNCFGraph = create_graph_for_output_quant_as_weights()


class TestOutputQuantAsWeightsSetup:
class OutputQuantAsWeightsSetupTestStruct(ABC):
operator_node_key_vs_trait_dict: Dict[str, QuantizationTrait]
quantizable_module_node_names_vs_qconfigs: Dict[NNCFNodeName, List[QuantizerConfig]]
class OutputQuantAsWeightsSetupTestStruct(ABC):
operator_node_key_vs_trait_dict: Dict[str, QuantizationTrait]
quantizable_module_node_names_vs_qconfigs: Dict[NNCFNodeName, List[QuantizerConfig]]

def prepare_qpsg_state(self, qpsg: QPSG) -> QPSG:
qpsg = TestQuantizerPropagationStateGraph.mark_nodes_with_traits(qpsg, self.operator_node_key_vs_trait_dict)
return self._setup_and_propagate_quantizers(qpsg)
def prepare_qpsg_state(self, qpsg: QPSG) -> QPSG:
qpsg = TestQuantizerPropagationStateGraph.mark_nodes_with_traits(qpsg, self.operator_node_key_vs_trait_dict)
return self._setup_and_propagate_quantizers(qpsg)

@abstractmethod
def _setup_and_propagate_quantizers(self, qpsg: QPSG) -> QPSG:
pass
@abstractmethod
def _setup_and_propagate_quantizers(self, qpsg: QPSG) -> QPSG:
pass

@abstractmethod
def ref_quantizer_setup(self) -> MultiConfigQuantizerSetup:
pass
@abstractmethod
def ref_quantizer_setup(self) -> MultiConfigQuantizerSetup:
pass

class LinearPropagation(OutputQuantAsWeightsSetupTestStruct):
operator_node_key_vs_trait_dict = {
"5 F/F_0": QuantizationTrait.QUANTIZATION_AGNOSTIC,
"8 G/G_0": QuantizationTrait.INPUTS_QUANTIZABLE,
"2 C/C_0": QuantizationTrait.OUTPUT_QUANTIZATION_AS_WEIGHTS,
}
quantizable_module_node_names_vs_qconfigs = {"C/C_0": [QuantizerConfig()]}

def ref_quantizer_setup(self) -> MultiConfigQuantizerSetup:
setup = MultiConfigQuantizerSetup()
setup.quantization_points[0] = MultiConfigQuantizationPoint(
WeightQuantizationInsertionPoint(target_node_name="C/C_0"),
possible_qconfigs=[QuantizerConfig()],
directly_quantized_operator_node_names=["C/C_0", "G/G_0"],
)
setup.shared_input_operation_set_groups[0] = {0}
return setup
class LinearPropagation(OutputQuantAsWeightsSetupTestStruct):
operator_node_key_vs_trait_dict = {
"5 F/F_0": QuantizationTrait.QUANTIZATION_AGNOSTIC,
"8 G/G_0": QuantizationTrait.INPUTS_QUANTIZABLE,
"2 C/C_0": QuantizationTrait.OUTPUT_QUANTIZATION_AS_WEIGHTS,
}
quantizable_module_node_names_vs_qconfigs = {"C/C_0": [QuantizerConfig()]}

def ref_quantizer_setup(self) -> MultiConfigQuantizerSetup:
setup = MultiConfigQuantizerSetup()
setup.quantization_points[0] = MultiConfigQuantizationPoint(
WeightQuantizationInsertionPoint(target_node_name="C/C_0"),
possible_qconfigs=[QuantizerConfig()],
directly_quantized_operator_node_names=["C/C_0", "G/G_0"],
)
setup.shared_input_operation_set_groups[0] = {0}
return setup

def _setup_and_propagate_quantizers(self, qpsg: QPSG) -> QPSG:
pq_1 = qpsg.add_propagating_quantizer([QuantizerConfig()], InsertionPointGraph.get_pre_hook_node_key("6 G/G_0"))
paths = get_edge_paths_for_propagation(
qpsg,
InsertionPointGraph.get_post_hook_node_key("2 C/C_0"),
InsertionPointGraph.get_pre_hook_node_key("6 G/G_0"),
)
path = paths[0]
qpsg.propagate_quantizer_via_path(pq_1, path)
qpsg.mark_act_quantizer_as_dependent_on_weights(pq_1, "2 C/C_0")
return qpsg

def _setup_and_propagate_quantizers(self, qpsg: QPSG) -> QPSG:
pq_1 = qpsg.add_propagating_quantizer(
[QuantizerConfig()], InsertionPointGraph.get_pre_hook_node_key("6 G/G_0")
)
paths = get_edge_paths_for_propagation(
qpsg,
InsertionPointGraph.get_post_hook_node_key("2 C/C_0"),
InsertionPointGraph.get_pre_hook_node_key("6 G/G_0"),
)
path = paths[0]
qpsg.propagate_quantizer_via_path(pq_1, path)
qpsg.mark_act_quantizer_as_dependent_on_weights(pq_1, "2 C/C_0")
return qpsg

class TestOutputQuantAsWeightsSetup:
class LinearPropagationWithConfigSubspaceSelection(OutputQuantAsWeightsSetupTestStruct):
operator_node_key_vs_trait_dict = {
"4 D/D_0": QuantizationTrait.QUANTIZATION_AGNOSTIC,
Expand Down Expand Up @@ -1835,3 +1836,27 @@ def test_create_quantizer_setup_with_output_quant_as_weights_ops(
def test_get_weight_and_activation_qconfig_list_intersection(weight_configs, activation_configs, reference_configs):
resulted_configs = QPSG._get_weight_and_activation_qconfig_list_intersection(weight_configs, activation_configs)
assert resulted_configs == reference_configs


class LinearPropagationForRemovalTest(LinearPropagation):
    """
    Variant of ``LinearPropagation`` whose setup step also hands back the
    propagating quantizer it created, so a test can later remove that exact
    quantizer from the graph.

    NOTE: unlike the base-class contract, ``_setup_and_propagate_quantizers``
    (and therefore ``prepare_qpsg_state``, which returns its result) yields a
    ``(qpsg, pq)`` pair here rather than the bare graph.
    """

    def _setup_and_propagate_quantizers(self, qpsg: QPSG) -> "Tuple[QPSG, PropagatingQuantizer]":
        """
        Add one propagating quantizer at the pre-hook of "6 G/G_0", move it along the
        first propagation path found between the post-hook of "2 C/C_0" and that
        pre-hook, and mark it as dependent on the weights of "2 C/C_0".

        :param qpsg: The quantizer propagation state graph to modify.
        :return: The modified graph and the propagating quantizer that was added.
        """
        pq_1 = qpsg.add_propagating_quantizer(
            [QuantizerConfig()], InsertionPointGraph.get_pre_hook_node_key("6 G/G_0")
        )
        paths = get_edge_paths_for_propagation(
            qpsg,
            InsertionPointGraph.get_post_hook_node_key("2 C/C_0"),
            InsertionPointGraph.get_pre_hook_node_key("6 G/G_0"),
        )
        path = paths[0]
        qpsg.propagate_quantizer_via_path(pq_1, path)
        qpsg.mark_act_quantizer_as_dependent_on_weights(pq_1, "2 C/C_0")
        # The quantizer is returned alongside the graph so the caller can remove it.
        return qpsg, pq_1


def test_remove_pq_from_pqs_after_weight_dependent_output_quantized_nodes():
    """
    Removing a propagating quantizer from the graph must also drop its entry from
    ``_pqs_after_weight_dependent_output_quantized_nodes``.
    """
    ip_graph = get_ip_graph_for_test(MODEL_GRAPH)
    linear_propagation = LinearPropagationForRemovalTest()
    # prepare_qpsg_state() marks the node traits and then runs
    # _setup_and_propagate_quantizers() itself, so the setup must not be invoked a
    # second time here: that would add another quantizer and discard the first one.
    qpsg, pq_1 = linear_propagation.prepare_qpsg_state(QPSG(ip_graph))

    # Precondition: the quantizer is expected to be registered as weight-dependent by
    # the setup step; without this check the final assertion could pass vacuously.
    assert pq_1 in qpsg._pqs_after_weight_dependent_output_quantized_nodes

    qpsg.remove_propagating_quantizer(pq_1)
    assert pq_1 not in qpsg._pqs_after_weight_dependent_output_quantized_nodes

0 comments on commit dca2cad

Please sign in to comment.