From cf228d92f7bb51db83eb924069117d9a0227eb33 Mon Sep 17 00:00:00 2001 From: Nikita Savelyev Date: Mon, 16 Dec 2024 19:10:03 +0100 Subject: [PATCH] Fix import --- .../torch/sparsify_activations/torch_backend.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/nncf/experimental/torch/sparsify_activations/torch_backend.py b/nncf/experimental/torch/sparsify_activations/torch_backend.py index 388bc52ecae..87478db0ede 100644 --- a/nncf/experimental/torch/sparsify_activations/torch_backend.py +++ b/nncf/experimental/torch/sparsify_activations/torch_backend.py @@ -20,6 +20,7 @@ from nncf.common.graph.operator_metatypes import CONST_NOOP_METATYPES from nncf.common.graph.operator_metatypes import OperatorMetatype from nncf.common.graph.transformations.commands import TargetType +from nncf.experimental.common.tensor_statistics.collectors import AbsQuantileReducer from nncf.experimental.torch.sparsify_activations.sparsify_activations_impl import SparsifyActivationsAlgoBackend from nncf.torch.graph import operator_metatypes as om from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand @@ -27,7 +28,6 @@ from nncf.torch.graph.transformations.layout import PTTransformationLayout from nncf.torch.model_transformer import PTModelTransformer from nncf.torch.nncf_network import NNCFNetwork -from nncf.torch.tensor_statistics.collectors import PTAbsQuantileReducer ACTIVATIONS_SPARSIFIER_PREFIX = "activations_sparsifier" @@ -61,8 +61,8 @@ class PTSparsifyActivationsAlgoBackend(SparsifyActivationsAlgoBackend): def supported_metatypes(self) -> List[Type[OperatorMetatype]]: return [om.PTLinearMetatype] - def abs_quantile_reducer(self, quantile: Optional[Union[float, List[float]]] = None) -> PTAbsQuantileReducer: - return PTAbsQuantileReducer(quantile=quantile) + def abs_quantile_reducer(self, quantile: Optional[Union[float, List[float]]] = None) -> AbsQuantileReducer: + return AbsQuantileReducer(quantile=quantile) def target_point(self, target_type: TargetType, target_node_name: str, port_id: int) -> PTTargetPoint: return PTTargetPoint(TargetType.PRE_LAYER_OPERATION, target_node_name, input_port_id=port_id)