From ef9cdd29825a22ae5916c0c0aa49219e7030f0df Mon Sep 17 00:00:00 2001 From: Alexander Suslov Date: Mon, 28 Oct 2024 13:36:27 +0400 Subject: [PATCH] [Torch] INT4 weight compression (#3014) ### Changes - Support INT4 weight compression in Torch and Torch.FX backends - Added `INT4SymmetricWeightsDecompressor` and `INT4ASymmetricWeightsDecompressor` ### Reason for changes Support INT4 weight model compression of PyTorch models in NNCF. ### Related tickets https://github.com/openvinotoolkit/nncf/issues/3005 ### Tests updated tests --- .../weight_compression/torch_backend.py | 54 +++--- .../weight_compression/torch_fx_backend.py | 56 +++--- nncf/quantization/quantize_model.py | 47 +++-- nncf/torch/quantization/layers.py | 166 +++++++++++++++++- nncf/torch/quantization/quantize_functions.py | 61 +++++++ .../data/wc_reference_data_2024.5.yaml | 4 + .../sparsify_activations/pipelines.py | 6 +- tests/post_training/model_scope.py | 10 ++ .../test_quantize_conformance.py | 2 +- ...ma_int8_sym_weights_sparse_activations.dot | 160 ++++++++--------- ...ar_int8_sym_weights_sparse_activations.dot | 10 +- ...e1_int8_sym_weights_sparse_activations.dot | 40 ++--- ...ar_int8_sym_weights_sparse_activations.dot | 40 ++--- tests/torch/fx/test_compress_weights.py | 64 +++++-- tests/torch/ptq/test_weights_compression.py | 105 ++++++++--- 15 files changed, 586 insertions(+), 239 deletions(-) diff --git a/nncf/quantization/algorithms/weight_compression/torch_backend.py b/nncf/quantization/algorithms/weight_compression/torch_backend.py index 6bfa6432748..52fae7531ec 100644 --- a/nncf/quantization/algorithms/weight_compression/torch_backend.py +++ b/nncf/quantization/algorithms/weight_compression/torch_backend.py @@ -43,8 +43,10 @@ from nncf.torch.model_graph_manager import split_const_name from nncf.torch.model_transformer import PTModelTransformer from nncf.torch.nncf_network import NNCFNetwork -from nncf.torch.quantization.layers import AsymmetricWeightsDecompressor -from nncf.torch.quantization.layers import SymmetricWeightsDecompressor +from nncf.torch.quantization.layers import INT4AsymmetricWeightsDecompressor +from nncf.torch.quantization.layers import INT4SymmetricWeightsDecompressor +from nncf.torch.quantization.layers import INT8AsymmetricWeightsDecompressor +from nncf.torch.quantization.layers import INT8SymmetricWeightsDecompressor class PTWeightCompressionAlgoBackend(WeightCompressionAlgoBackend): @@ -212,10 +214,9 @@ def transform_model( for wc_params in weight_compression_parameters: compression_config = wc_params.compression_config - if compression_config.mode not in [ - CompressWeightsMode.INT8_ASYM, - CompressWeightsMode.INT8_SYM, - CompressWeightsMode.INT8, + if compression_config.mode in [ + CompressWeightsMode.NF4, + CompressWeightsMode.E2M1, ]: raise nncf.ParameterNotSupportedError(f"{compression_config.mode.value} is not supported.") @@ -235,17 +236,35 @@ def transform_model( None if precomputed_scales is None else precomputed_scales.get(wc_params.weight_name), None if precomputed_zero_points is None else precomputed_zero_points.get(wc_params.weight_name), ) - compressed_weight.scale = compressed_weight.scale.astype(dtype=TensorDataType.float16) - # pack compressed tensor + # creates weight decompressor if compression_config.mode == CompressWeightsMode.INT8_SYM: - dtype = TensorDataType.int8 - else: - dtype = TensorDataType.uint8 - packed_tensor = compressed_weight.tensor.astype(dtype) + decompressor = INT8SymmetricWeightsDecompressor(compressed_weight.scale.data, result_dtype=weight.dtype) + 
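+            # The INT8 decompressors only need the scale (plus a zero point for the asymmetric
+            # case), while the INT4 decompressors below additionally take the compressed and
+            # original weight shapes: pack_weight() stores two 4-bit values per uint8 element,
+            # so forward() must unpack and restore the original shape.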
elif compression_config.mode == CompressWeightsMode.INT8_ASYM: + decompressor = INT8AsymmetricWeightsDecompressor( + compressed_weight.scale.data, compressed_weight.zero_point.data, result_dtype=weight.dtype + ) + elif compression_config.mode == CompressWeightsMode.INT4_SYM: + decompressor = INT4SymmetricWeightsDecompressor( + scale=compressed_weight.scale.data, + compressed_weight_shape=compressed_weight.tensor.shape, + result_shape=weight.shape, + result_dtype=weight.dtype, + ) + elif compression_config.mode == CompressWeightsMode.INT4_ASYM: + decompressor = INT4AsymmetricWeightsDecompressor( + scale=compressed_weight.scale.data, + zero_point=compressed_weight.zero_point.data, + compressed_weight_shape=compressed_weight.tensor.shape, + result_shape=weight.shape, + result_dtype=weight.dtype, + ) + + # pack tensor + packed_tensor = decompressor.pack_weight(compressed_weight.tensor.data) # sets compressed tensor - compressed_parameter = torch.nn.Parameter(packed_tensor.data, requires_grad=False) + compressed_parameter = torch.nn.Parameter(packed_tensor, requires_grad=False) setattr(module, weight_attr_name, compressed_parameter) consumer_nodes = graph.get_next_nodes(weight_node) @@ -256,15 +275,6 @@ def transform_model( if id(param) == id(weight): setattr(c_module, name, compressed_parameter) - # creates weight decompressor - if compression_config.mode == CompressWeightsMode.INT8_SYM: - decompressor = SymmetricWeightsDecompressor(compressed_weight.scale.data, result_dtype=weight.dtype) - else: - packed_zero_point = compressed_weight.zero_point.astype(dtype) - decompressor = AsymmetricWeightsDecompressor( - compressed_weight.scale.data, packed_zero_point.data, result_dtype=weight.dtype - ) - # registry weight decompression module in the model decompressor_name = f"weights_decompressor_{weight_node.node_name.replace('.', '_')}" diff --git a/nncf/quantization/algorithms/weight_compression/torch_fx_backend.py b/nncf/quantization/algorithms/weight_compression/torch_fx_backend.py index d9b6c70b7a7..2816f82a6e2 100644 --- a/nncf/quantization/algorithms/weight_compression/torch_fx_backend.py +++ b/nncf/quantization/algorithms/weight_compression/torch_fx_backend.py @@ -45,8 +45,10 @@ from nncf.torch.graph.transformations.commands import PTTargetPoint from nncf.torch.model_graph_manager import get_const_node from nncf.torch.model_graph_manager import get_weight_tensor_port_ids -from nncf.torch.quantization.layers import AsymmetricWeightsDecompressor -from nncf.torch.quantization.layers import SymmetricWeightsDecompressor +from nncf.torch.quantization.layers import INT4AsymmetricWeightsDecompressor +from nncf.torch.quantization.layers import INT4SymmetricWeightsDecompressor +from nncf.torch.quantization.layers import INT8AsymmetricWeightsDecompressor +from nncf.torch.quantization.layers import INT8SymmetricWeightsDecompressor class FXWeightCompressionAlgoBackend(WeightCompressionAlgoBackend): @@ -176,10 +178,9 @@ def transform_model( for wc_params in weight_compression_parameters: compression_config = wc_params.compression_config - if compression_config.mode not in [ - CompressWeightsMode.INT8_ASYM, - CompressWeightsMode.INT8_SYM, - CompressWeightsMode.INT8, + if compression_config.mode in [ + CompressWeightsMode.NF4, + CompressWeightsMode.E2M1, ]: raise nncf.ParameterNotSupportedError(f"{compression_config.mode.value} is not supported.") weight_node = get_const_node(wc_params.node_with_weight, wc_params.weight_port_id, graph) @@ -196,35 +197,44 @@ def transform_model( None if precomputed_scales is 
None else precomputed_scales.get(wc_params.weight_name), None if precomputed_zero_points is None else precomputed_zero_points.get(wc_params.weight_name), ) - compressed_weight.scale = compressed_weight.scale.astype(dtype=TensorDataType.float16) - # pack compressed tensor - if compression_config.mode == CompressWeightsMode.INT8_SYM: - dtype = TensorDataType.int8 - else: - dtype = TensorDataType.uint8 - packed_tensor = compressed_weight.tensor.astype(dtype) - - self.set_weight(wc_params.node_with_weight, wc_params.weight_port_id, model, graph, packed_tensor) # creates weight decompressor if compression_config.mode == CompressWeightsMode.INT8_SYM: - decompressor = SymmetricWeightsDecompressor( + decompressor = INT8SymmetricWeightsDecompressor( compressed_weight.scale.data, result_dtype=weight.data.dtype ) - decompressor_type = "symmetric" - else: - packed_zero_point = compressed_weight.zero_point.astype(dtype) - decompressor = AsymmetricWeightsDecompressor( - compressed_weight.scale.data, packed_zero_point.data, result_dtype=weight.data.dtype + elif compression_config.mode == CompressWeightsMode.INT8_ASYM: + decompressor = INT8AsymmetricWeightsDecompressor( + compressed_weight.scale.data, compressed_weight.zero_point.data, result_dtype=weight.data.dtype + ) + elif compression_config.mode == CompressWeightsMode.INT4_SYM: + decompressor = INT4SymmetricWeightsDecompressor( + scale=compressed_weight.scale.data, + compressed_weight_shape=compressed_weight.tensor.shape, + result_shape=weight.shape, + result_dtype=weight.data.dtype, ) - decompressor_type = "asymmetric" + elif compression_config.mode == CompressWeightsMode.INT4_ASYM: + decompressor = INT4AsymmetricWeightsDecompressor( + scale=compressed_weight.scale.data, + zero_point=compressed_weight.zero_point.data, + compressed_weight_shape=compressed_weight.tensor.shape, + result_shape=weight.shape, + result_dtype=weight.data.dtype, + ) + + # pack tensor + packed_tensor = decompressor.pack_weight(compressed_weight.tensor.data) + + # sets compressed tensor + self.set_weight(wc_params.node_with_weight, wc_params.weight_port_id, model, graph, packed_tensor) # register weight decompression module in the model graph_weight_node = get_graph_node_by_name(model.graph, wc_params.node_with_weight.node_name) compressed_weight_name = graph_weight_node.all_input_nodes[wc_params.weight_port_id].name decompressor_suffix = "_".join(compressed_weight_name.replace(".", "_").split("_")[:-2]) - decompressor_name = f"{decompressor_type}_weights_decompressor_{decompressor_suffix}" + decompressor_name = f"{decompressor.quantization_mode}_weights_decompressor_{decompressor_suffix}" # inserts the weight decompressor into the model as the post hook on the model weight transformation_layout.register( diff --git a/nncf/quantization/quantize_model.py b/nncf/quantization/quantize_model.py index f475d4e52ee..52d5bf07edb 100644 --- a/nncf/quantization/quantize_model.py +++ b/nncf/quantization/quantize_model.py @@ -505,20 +505,26 @@ def compress_weights( from nncf.torch.model_creation import wrap_model from nncf.torch.quantization.quantize_model import compress_weights_impl as pt_compression_weights_impl - if mode not in [CompressWeightsMode.INT8_ASYM, CompressWeightsMode.INT8_SYM]: + if mode in [CompressWeightsMode.NF4, CompressWeightsMode.E2M1]: raise nncf.ParameterNotSupportedError( - "Torch backend supports only INT8_ASYM, INT8_SYM modes for weight compression, " - f"but given {mode.value} mode." + "Torch backend does not support NF4 and E2M1 modes for weight compression." 
) - if True in [awq, scale_estimation, gptq, lora_correction]: + options = { + "sensitivity_metric": sensitivity_metric, + "awq": awq, + "scale_estimation": scale_estimation, + "gptq": gptq, + "lora_correction": lora_correction, + } + unsupported_options = [name for name, value in options.items() if value is not None] + if unsupported_options: raise nncf.ParameterNotSupportedError( - "Torch backend does not support 'awq', 'scale_estimation', 'gptq' and 'lora_correction' options. " - "Set them to None." + f"Torch backend does not support {', '.join(unsupported_options)} option(s). Set them to None." ) - if backup_mode is not None: - raise nncf.ParameterNotSupportedError("Torch backend does not support backup_mode option.") + if ratio is not None and ratio != 1: + raise nncf.ParameterNotSupportedError("Torch backend does not support ratio != 1.") if is_wrapped_model(model): if not model.nncf.trace_parameters: @@ -541,20 +547,27 @@ def compress_weights( compress_weights_impl as fx_compression_weights_impl, ) - if mode not in [CompressWeightsMode.INT8_ASYM, CompressWeightsMode.INT8_SYM]: + if mode in [CompressWeightsMode.NF4, CompressWeightsMode.E2M1]: raise nncf.ParameterNotSupportedError( - "TorchFX backend supports only INT8_ASYM, INT8_SYM modes for weight compression, " - f"but given {mode.value} mode." + "Torch backend does not support NF4 and E2M1 modes for weight compression." ) - if backup_mode is not None: - raise nncf.ParameterNotSupportedError("TorchFX backend does not support backup_mode option.") - - if any((awq, scale_estimation, gptq, lora_correction)): + options = { + "sensitivity_metric": sensitivity_metric, + "awq": awq, + "scale_estimation": scale_estimation, + "gptq": gptq, + "lora_correction": lora_correction, + } + unsupported_options = [name for name, value in options.items() if value is not None] + if unsupported_options: raise nncf.ParameterNotSupportedError( - "TorchFX backend does not support 'awq', 'scale_estimation', 'gptq'," - "and 'lora_correction' options. Set them to None." + f"TorchFX backend does not support {', '.join(unsupported_options)} option(s). Set them to None." 
) + + if ratio is not None and ratio != 1: + raise nncf.ParameterNotSupportedError("TorchFX backend does not support ratio != 1.") + if dataset: raise nncf.ParameterNotSupportedError( "TorchFX only supports data-free weights compression," "Set the 'dataset' option to None" diff --git a/nncf/torch/quantization/layers.py b/nncf/torch/quantization/layers.py index b84d325526f..85e41428846 100644 --- a/nncf/torch/quantization/layers.py +++ b/nncf/torch/quantization/layers.py @@ -49,7 +49,11 @@ from nncf.torch.quantization.quantize_functions import decompress_asymmetric from nncf.torch.quantization.quantize_functions import decompress_symmetric from nncf.torch.quantization.quantize_functions import get_scale_zp_from_input_low_input_high +from nncf.torch.quantization.quantize_functions import pack_int4 +from nncf.torch.quantization.quantize_functions import pack_uint4 from nncf.torch.quantization.quantize_functions import symmetric_quantize +from nncf.torch.quantization.quantize_functions import unpack_int4 +from nncf.torch.quantization.quantize_functions import unpack_uint4 from nncf.torch.return_types import maybe_get_values_from_torch_return_type from nncf.torch.return_types import maybe_wrap_to_torch_return_type from nncf.torch.utils import get_flat_tensor_contents_string @@ -1045,43 +1049,189 @@ def get_scale_shape(input_shape: List[int], is_weights: bool, per_channel: bool, return get_per_channel_scale_shape(input_shape, is_weights, channel_idx) -class AsymmetricWeightsDecompressor(nn.Module): +class BaseWeightsDecompressor(nn.Module, ABC): + """ + Base class for implementing weights decompression modules within NNCF. + + This class is intended to serve as the foundation for modules that handle the decompression + of quantized model weights. It provides an interface for defining the quantization mode and + packing the weights according to the specified quantization strategy. Classes inheriting from + this base class must implement the abstract methods for packing and handling the quantization mode. + """ + + @property + @abstractmethod + def quantization_mode(self) -> QuantizationMode: + """ + Property that specifies the quantization mode used for compressing weights. + + This method must be implemented to return the specific mode of quantization that + the decompressor is using, such as symmetric or asymmetric quantization. + + :return: The quantization mode as an instance of `QuantizationMode`. + """ + + @abstractmethod + def pack_weight(self, weight: torch.Tensor) -> torch.Tensor: + """ + Pack the given weight tensor according to the selected quantization mode. + + :param weight: The tensor containing the weight values to be packed. + :return: The packed tensor. 
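+
+        For the INT8 decompressors this is a plain cast to (u)int8; for the INT4 decompressors
+        two 4-bit values are packed into each uint8 element (see pack_uint4/pack_int4).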
+ """ + + +class INT8AsymmetricWeightsDecompressor(BaseWeightsDecompressor): """ Applies asymmetric decompression of compressed weights in the forward pass """ - def __init__(self, scale: torch.Tensor, zero_point: torch.Tensor, result_dtype: torch.dtype = None): + def __init__(self, scale: torch.Tensor, zero_point: torch.Tensor, result_dtype: Optional[torch.dtype] = None): """ :param scale: A scale in quantization scheme :param zero_point: A zero point in quantization scheme :param result_dtype: (Optional) A data type that result should be cast to """ super().__init__() - self.register_buffer("_scale", scale) - self.register_buffer("_zero_point", zero_point) + self.register_buffer("_scale", scale.type(dtype=torch.float16)) + self.register_buffer("_zero_point", self.pack_weight(zero_point)) self.result_dtype = result_dtype - def forward(self, x): + @property + def quantization_mode(self) -> QuantizationMode: + return QuantizationMode.ASYMMETRIC + + def pack_weight(self, weight: torch.Tensor) -> torch.Tensor: + if torch.is_floating_point(weight): + raise ValueError(f"Invalid weight dtype {weight.type}. Integer types are supported.") + if torch.any((weight < 0) | (weight > 255)): + raise ValueError("Weight values are not in [0, 255].") + return weight.type(dtype=torch.uint8) + + def forward(self, x) -> torch.Tensor: result = decompress_asymmetric(x, self._scale, self._zero_point) result = result.type(dtype=self.result_dtype) if self.result_dtype is not None else result return result -class SymmetricWeightsDecompressor(nn.Module): +class INT8SymmetricWeightsDecompressor(BaseWeightsDecompressor): """ Applies symmetric decompression of compressed weights in the forward pass """ - def __init__(self, scale: torch.Tensor, result_dtype: torch.dtype = None): + def __init__(self, scale: torch.Tensor, result_dtype: Optional[torch.dtype] = None): """ :param scale: A scale in quantization scheme :param result_dtype: (Optional) A data type that result should be cast to """ super().__init__() - self.register_buffer("_scale", scale) + self.register_buffer("_scale", scale.type(dtype=torch.float16)) self.result_dtype = result_dtype + @property + def quantization_mode(self) -> QuantizationMode: + return QuantizationMode.SYMMETRIC + + def pack_weight(self, weight: torch.Tensor) -> torch.Tensor: + if torch.is_floating_point(weight): + raise ValueError(f"Invalid weight dtype {weight.type}. 
Integer types are supported.") + if torch.any((weight < -128) | (weight > 127)): + raise ValueError("Weight values are not in [-128, 127].") + return weight.type(dtype=torch.int8) + + def forward(self, x) -> torch.Tensor: + result = decompress_symmetric(x, self._scale) + result = result.type(dtype=self.result_dtype) if self.result_dtype is not None else result + return result + + +class INT4AsymmetricWeightsDecompressor(BaseWeightsDecompressor): + def __init__( + self, + scale: torch.Tensor, + zero_point: torch.Tensor, + compressed_weight_shape: Tuple[int, ...], + result_shape: Optional[Tuple[int, ...]] = None, + result_dtype: Optional[torch.dtype] = None, + ): + """ + :param scale: A scale in quantization scheme + :param zero_point: A zero point in quantization scheme + :param compressed_weight_shape: A compressed weight shape + :param result_shape: (Optional) A shape that result should be reshaped + :param result_dtype: (Optional) A data type that result should be cast to + """ + super().__init__() + self.register_buffer("_scale", scale.type(dtype=torch.float16)) + + self.zero_point_shape = zero_point.shape + self.register_buffer("_zero_point", self.pack_weight(zero_point)) + + self.compressed_weight_shape = compressed_weight_shape + self.result_shape = result_shape + self.result_dtype = result_dtype + + @property + def quantization_mode(self) -> QuantizationMode: + return QuantizationMode.ASYMMETRIC + + def pack_weight(self, weight: torch.Tensor) -> torch.Tensor: + if torch.is_floating_point(weight): + raise ValueError(f"Invalid weight dtype {weight.type}. Integer types are supported.") + if torch.any((weight < 0) | (weight > 15)): + raise ValueError("Weight values are not in [0, 15].") + return pack_uint4(weight.type(dtype=torch.uint8)) + def forward(self, x): + x = unpack_uint4(x) + x = x.reshape(self.compressed_weight_shape) + + zero_point = unpack_uint4(self._zero_point) + zero_point = zero_point.reshape(self.zero_point_shape) + + result = decompress_asymmetric(x, self._scale, zero_point) + result = result.reshape(self.result_shape) if self.result_shape is not None else result + result = result.type(dtype=self.result_dtype) if self.result_dtype is not None else result + return result + + +class INT4SymmetricWeightsDecompressor(BaseWeightsDecompressor): + def __init__( + self, + scale: torch.Tensor, + compressed_weight_shape: Tuple[int, ...], + result_shape: Optional[Tuple[int, ...]] = None, + result_dtype: Optional[torch.dtype] = None, + ): + """ + :param scale: A scale in quantization scheme + :param compressed_weight_shape: A compressed weight shape + :param result_shape: (Optional) A shape that result should be reshaped + :param result_dtype: (Optional) A data type that result should be cast to + """ + super().__init__() + self.register_buffer("_scale", scale.type(dtype=torch.float16)) + + self.compressed_weight_shape = compressed_weight_shape + self.result_shape = result_shape + self.result_dtype = result_dtype + + @property + def quantization_mode(self) -> QuantizationMode: + return QuantizationMode.SYMMETRIC + + def pack_weight(self, weight: torch.Tensor) -> torch.Tensor: + if torch.is_floating_point(weight): + raise ValueError(f"Invalid weight dtype {weight.type}. 
Integer types are supported.") + if torch.any((weight < -8) | (weight > 7)): + raise ValueError("Tensor values are not in [-8, 7].") + return pack_int4(weight.type(dtype=torch.int8)) + + def forward(self, x): + x = unpack_int4(x) + x = x.reshape(self.compressed_weight_shape) + result = decompress_symmetric(x, self._scale) + result = result.reshape(self.result_shape) if self.result_shape is not None else result result = result.type(dtype=self.result_dtype) if self.result_dtype is not None else result return result diff --git a/nncf/torch/quantization/quantize_functions.py b/nncf/torch/quantization/quantize_functions.py index 9b4055c4586..b967bb57683 100644 --- a/nncf/torch/quantization/quantize_functions.py +++ b/nncf/torch/quantization/quantize_functions.py @@ -13,6 +13,7 @@ import torch from nncf.common.logging import nncf_logger +from nncf.errors import ValidationError from nncf.torch.dynamic_graph.patch_pytorch import register_operator from nncf.torch.functions import STRound from nncf.torch.functions import clamp @@ -259,6 +260,7 @@ def decompress_asymmetric(input: torch.Tensor, scale: torch.Tensor, zero_point: :return: The decompressed tensor """ input = input.type(dtype=scale.dtype) + zero_point = zero_point.type(dtype=scale.dtype) decompressed_input = (input - zero_point) * scale return decompressed_input @@ -275,3 +277,62 @@ def decompress_symmetric(input: torch.Tensor, scale: torch.Tensor) -> torch.Tens input = input.type(dtype=scale.dtype) decompressed_input = input * scale return decompressed_input + + +def pack_uint4(tensor: torch.Tensor) -> torch.Tensor: + """ + Packs a tensor containing uint4 values (in the range [0, 15]) into a tensor with uint8 values, + where each element stores two uint4 values. + + :param tensor: A tensor of dtype `torch.uint8` where each element represents a uint4 value. + The tensor should contain values in the range [0, 15]. + :return: A packed tensor of dtype `torch.uint8` where each element packs two uint4 values. + :raises nncf.errors.ValidationError: If the input tensor is not of type `torch.uint8`. + """ + if tensor.dtype != torch.uint8: + raise ValidationError(f"Invalid tensor dtype {tensor.type}. torch.uint8 type is supported.") + packed_tensor = tensor.contiguous() + packed_tensor = packed_tensor.reshape(-1, 2) + packed_tensor = torch.bitwise_and(packed_tensor[..., ::2], 15) | packed_tensor[..., 1::2] << 4 + return packed_tensor + + +@register_operator() +def unpack_uint4(packed_tensor: torch.Tensor) -> torch.Tensor: + """ + Unpacks a tensor, where each uint8 element stores two uint4 values, back into a tensor with + individual uint4 values. + + :param packed_tensor: A tensor of dtype `torch.uint8` where each element packs two uint4 values. + :return: A tensor of dtype `torch.uint8` where each element represents a uint4 value. + """ + return torch.stack((torch.bitwise_and(packed_tensor, 15), torch.bitwise_right_shift(packed_tensor, 4)), dim=-1) + + +def pack_int4(tensor: torch.Tensor) -> torch.Tensor: + """ + Packs a tensor containing int4 values (in the range [-8, 7]) into a tensor with uint8 values, + where each element stores two int4 values. + + :param tensor: A tensor of dtype `torch.int8` where each element represents an int4 value. + The tensor should contain values in the range [-8, 7]. + :return: A packed tensor of dtype `torch.uint8` where each element packs two int4 values. + :raises nncf.errors.ValidationError: If the input tensor is not of type `torch.int8`. 
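+
+    Example: a pair of int4 values (-8, 7) is packed as ((-8 + 8) & 0xF) | ((7 + 8) << 4) = 0xF0,
+    i.e. the even-indexed value occupies the low nibble and the odd-indexed value the high nibble.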
+ """ + if tensor.dtype != torch.int8: + raise ValidationError(f"Invalid tensor dtype {tensor.type}. torch.int8 type is supported.") + tensor = tensor + 8 + return pack_uint4(tensor.type(torch.uint8)) + + +@register_operator() +def unpack_int4(packed_tensor: torch.Tensor) -> torch.Tensor: + """ + Unpacks a tensor, where each uint8 element stores two int4 values, back into a tensor with + individual int4 values. + + :param packed_tensor: A tensor of dtype `torch.uint8` where each element packs two int4 values. + :return: A tensor of dtype `torch.int8` where each element represents an int4 value. + """ + t = unpack_uint4(packed_tensor) + return t.type(torch.int8) - 8 diff --git a/tests/post_training/data/wc_reference_data_2024.5.yaml b/tests/post_training/data/wc_reference_data_2024.5.yaml index c1cddd4ea3e..ee5b1ffdad4 100644 --- a/tests/post_training/data/wc_reference_data_2024.5.yaml +++ b/tests/post_training/data/wc_reference_data_2024.5.yaml @@ -7,3 +7,7 @@ tinyllama_NF4_scale_estimation_stateful_per_channel_backend_OV: metric_value: 0.87132 num_int4: 11 num_int8: 290 +tinyllama_int4_data_free_backend_TORCH: + metric_value: 0.73541 + num_int4: 308 + num_int8: 4 diff --git a/tests/post_training/experimental/sparsify_activations/pipelines.py b/tests/post_training/experimental/sparsify_activations/pipelines.py index ef2ecbd1847..cb12f3b08f2 100644 --- a/tests/post_training/experimental/sparsify_activations/pipelines.py +++ b/tests/post_training/experimental/sparsify_activations/pipelines.py @@ -30,8 +30,8 @@ from nncf.experimental.torch.sparsify_activations import sparsify_activations from nncf.experimental.torch.sparsify_activations.sparsify_activations_impl import SparsifyActivationsAlgoBackend from nncf.experimental.torch.sparsify_activations.torch_backend import PTSparsifyActivationsAlgoBackend -from nncf.torch.quantization.layers import AsymmetricWeightsDecompressor -from nncf.torch.quantization.layers import SymmetricWeightsDecompressor +from nncf.torch.quantization.layers import INT8AsymmetricWeightsDecompressor +from nncf.torch.quantization.layers import INT8SymmetricWeightsDecompressor from tests.post_training.pipelines.base import LIMIT_LENGTH_OF_STATUS from tests.post_training.pipelines.base import PT_BACKENDS from tests.post_training.pipelines.base import BackendType @@ -267,7 +267,7 @@ def save_compressed_model(self): if self.backend == BackendType.CUDA_TORCH: self.model_hf.float() for module in self.model_hf.nncf.modules(): - if isinstance(module, (AsymmetricWeightsDecompressor, SymmetricWeightsDecompressor)): + if isinstance(module, (INT8AsymmetricWeightsDecompressor, INT8SymmetricWeightsDecompressor)): module.result_dtype = torch.float32 export_from_model( self.model_hf, self.output_model_dir, stateful=False, compression_option="fp32", device="cuda" diff --git a/tests/post_training/model_scope.py b/tests/post_training/model_scope.py index 5f5080fee77..7ee651723af 100644 --- a/tests/post_training/model_scope.py +++ b/tests/post_training/model_scope.py @@ -441,6 +441,16 @@ }, "backends": [BackendType.TORCH], }, + { + "reported_name": "tinyllama_int4_data_free", + "model_id": "tinyllama/tinyllama-1.1b-step-50k-105b", + "pipeline_cls": LMWeightCompression, + "compression_params": { + "mode": CompressWeightsMode.INT4_ASYM, + "group_size": 64, + }, + "backends": [BackendType.TORCH], + }, { "reported_name": "tinyllama_data_aware_gptq", "model_id": "tinyllama/tinyllama-1.1b-step-50k-105b", diff --git a/tests/post_training/test_quantize_conformance.py 
b/tests/post_training/test_quantize_conformance.py index 5c4fa176ad6..2ea880fde31 100644 --- a/tests/post_training/test_quantize_conformance.py +++ b/tests/post_training/test_quantize_conformance.py @@ -343,7 +343,7 @@ def test_weight_compression( start_time = time.perf_counter() try: if test_case_name not in wc_reference_data: - raise RuntimeError(f"{test_case_name} is not defined in `wc_reference_data` fixture") + pytest.skip(f"{test_case_name} is not defined in `wc_reference_data` fixture") test_model_param = WC_TEST_CASES[test_case_name] maybe_skip_test_case(test_model_param, run_fp32_backend, run_torch_cuda_backend, batch_size) pipeline_cls = test_model_param["pipeline_cls"] diff --git a/tests/torch/data/experimental/sparsify_activations/dummy_llama_int8_sym_weights_sparse_activations.dot b/tests/torch/data/experimental/sparsify_activations/dummy_llama_int8_sym_weights_sparse_activations.dot index c3e5cf0d0c9..7c4f62b7994 100644 --- a/tests/torch/data/experimental/sparsify_activations/dummy_llama_int8_sym_weights_sparse_activations.dot +++ b/tests/torch/data/experimental/sparsify_activations/dummy_llama_int8_sym_weights_sparse_activations.dot @@ -1,8 +1,8 @@ strict digraph { "0 /nncf_model_input_0" [id=0, type=nncf_model_input]; "1 model.embed_tokens.weight" [id=1, type=nncf_model_const]; -"2 LlamaForCausalLM/LlamaModel[model]/Embedding[embed_tokens]/SymmetricWeightsDecompressor/decompress_symmetric_0" [id=2, type=decompress_symmetric]; -"3 LlamaForCausalLM/LlamaModel[model]/Embedding[embed_tokens]/SymmetricWeightsDecompressor/type_0" [id=3, type=type]; +"2 LlamaForCausalLM/LlamaModel[model]/Embedding[embed_tokens]/INT8SymmetricWeightsDecompressor/decompress_symmetric_0" [id=2, type=decompress_symmetric]; +"3 LlamaForCausalLM/LlamaModel[model]/Embedding[embed_tokens]/INT8SymmetricWeightsDecompressor/type_0" [id=3, type=type]; "4 LlamaForCausalLM/LlamaModel[model]/Embedding[embed_tokens]/embedding_0" [id=4, type=embedding]; "5 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[input_layernorm]/to_0" [id=5, type=to]; "6 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[input_layernorm]/pow_0" [id=6, type=pow]; @@ -14,16 +14,16 @@ strict digraph { "12 model.layers.0.input_layernorm.weight" [id=12, type=nncf_model_const]; "13 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[input_layernorm]/__mul___1" [id=13, type=__mul__]; "14 model.layers.0.self_attn.q_proj.weight" [id=14, type=nncf_model_const]; -"15 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[q_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0" [id=15, type=decompress_symmetric]; -"16 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[q_proj]/SymmetricWeightsDecompressor/type_0" [id=16, type=type]; +"15 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[q_proj]/INT8SymmetricWeightsDecompressor/decompress_symmetric_0" [id=15, type=decompress_symmetric]; +"16 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[q_proj]/INT8SymmetricWeightsDecompressor/type_0" [id=16, type=type]; "17 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[q_proj]/linear_0" [id=17, type=linear]; "18 model.layers.0.self_attn.k_proj.weight" [id=18, 
type=nncf_model_const]; -"19 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[k_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0" [id=19, type=decompress_symmetric]; -"20 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[k_proj]/SymmetricWeightsDecompressor/type_0" [id=20, type=type]; +"19 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[k_proj]/INT8SymmetricWeightsDecompressor/decompress_symmetric_0" [id=19, type=decompress_symmetric]; +"20 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[k_proj]/INT8SymmetricWeightsDecompressor/type_0" [id=20, type=type]; "21 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[k_proj]/linear_0" [id=21, type=linear]; "22 model.layers.0.self_attn.v_proj.weight" [id=22, type=nncf_model_const]; -"23 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[v_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0" [id=23, type=decompress_symmetric]; -"24 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[v_proj]/SymmetricWeightsDecompressor/type_0" [id=24, type=type]; +"23 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[v_proj]/INT8SymmetricWeightsDecompressor/decompress_symmetric_0" [id=23, type=decompress_symmetric]; +"24 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[v_proj]/INT8SymmetricWeightsDecompressor/type_0" [id=24, type=type]; "25 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[v_proj]/linear_0" [id=25, type=linear]; "26 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/view_0" [id=26, type=view]; "27 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/transpose_0" [id=27, type=transpose]; @@ -70,8 +70,8 @@ strict digraph { "68 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/contiguous_0" [id=68, type=contiguous]; "69 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/reshape_2" [id=69, type=reshape]; "70 model.layers.0.self_attn.o_proj.weight" [id=70, type=nncf_model_const]; -"71 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[o_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0" [id=71, type=decompress_symmetric]; -"72 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[o_proj]/SymmetricWeightsDecompressor/type_0" [id=72, type=type]; +"71 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[o_proj]/INT8SymmetricWeightsDecompressor/decompress_symmetric_0" [id=71, type=decompress_symmetric]; +"72 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[o_proj]/INT8SymmetricWeightsDecompressor/type_0" [id=72, type=type]; "73 
LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[o_proj]/linear_0" [id=73, type=linear]; "74 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/__add___0" [id=74, type=__add__]; "75 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[post_attention_layernorm]/to_0" [id=75, type=to]; @@ -84,24 +84,24 @@ strict digraph { "82 model.layers.0.post_attention_layernorm.weight" [id=82, type=nncf_model_const]; "83 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[post_attention_layernorm]/__mul___1" [id=83, type=__mul__]; "84 model.layers.0.mlp.gate_proj.weight" [id=84, type=nncf_model_const]; -"85 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[gate_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0" [id=85, type=decompress_symmetric]; -"86 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[gate_proj]/SymmetricWeightsDecompressor/type_0" [id=86, type=type]; +"85 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[gate_proj]/INT8SymmetricWeightsDecompressor/decompress_symmetric_0" [id=85, type=decompress_symmetric]; +"86 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[gate_proj]/INT8SymmetricWeightsDecompressor/type_0" [id=86, type=type]; "87 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[gate_proj]/ActivationsSparsifier/abs_0" [id=87, type=abs]; "88 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[gate_proj]/ActivationsSparsifier/le_0" [id=88, type=le]; "89 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[gate_proj]/ActivationsSparsifier/masked_fill_0" [id=89, type=masked_fill]; "90 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[gate_proj]/linear_0" [id=90, type=linear]; "91 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/SiLU[act_fn]/silu_0" [id=91, type=silu]; "92 model.layers.0.mlp.up_proj.weight" [id=92, type=nncf_model_const]; -"93 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[up_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0" [id=93, type=decompress_symmetric]; -"94 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[up_proj]/SymmetricWeightsDecompressor/type_0" [id=94, type=type]; +"93 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[up_proj]/INT8SymmetricWeightsDecompressor/decompress_symmetric_0" [id=93, type=decompress_symmetric]; +"94 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[up_proj]/INT8SymmetricWeightsDecompressor/type_0" [id=94, type=type]; "95 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[up_proj]/ActivationsSparsifier/abs_0" [id=95, type=abs]; "96 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[up_proj]/ActivationsSparsifier/le_0" [id=96, type=le]; "97 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[up_proj]/ActivationsSparsifier/masked_fill_0" [id=97, 
type=masked_fill]; "98 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[up_proj]/linear_0" [id=98, type=linear]; "99 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/__mul___0" [id=99, type=__mul__]; "100 model.layers.0.mlp.down_proj.weight" [id=100, type=nncf_model_const]; -"101 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[down_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0" [id=101, type=decompress_symmetric]; -"102 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[down_proj]/SymmetricWeightsDecompressor/type_0" [id=102, type=type]; +"101 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[down_proj]/INT8SymmetricWeightsDecompressor/decompress_symmetric_0" [id=101, type=decompress_symmetric]; +"102 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[down_proj]/INT8SymmetricWeightsDecompressor/type_0" [id=102, type=type]; "103 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[down_proj]/ActivationsSparsifier/abs_0" [id=103, type=abs]; "104 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[down_proj]/ActivationsSparsifier/le_0" [id=104, type=le]; "105 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[down_proj]/ActivationsSparsifier/masked_fill_0" [id=105, type=masked_fill]; @@ -117,16 +117,16 @@ strict digraph { "115 model.layers.1.input_layernorm.weight" [id=115, type=nncf_model_const]; "116 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[input_layernorm]/__mul___1" [id=116, type=__mul__]; "117 model.layers.1.self_attn.q_proj.weight" [id=117, type=nncf_model_const]; -"118 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[q_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0" [id=118, type=decompress_symmetric]; -"119 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[q_proj]/SymmetricWeightsDecompressor/type_0" [id=119, type=type]; +"118 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[q_proj]/INT8SymmetricWeightsDecompressor/decompress_symmetric_0" [id=118, type=decompress_symmetric]; +"119 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[q_proj]/INT8SymmetricWeightsDecompressor/type_0" [id=119, type=type]; "120 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[q_proj]/linear_0" [id=120, type=linear]; "121 model.layers.1.self_attn.k_proj.weight" [id=121, type=nncf_model_const]; -"122 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[k_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0" [id=122, type=decompress_symmetric]; -"123 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[k_proj]/SymmetricWeightsDecompressor/type_0" [id=123, type=type]; +"122 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[k_proj]/INT8SymmetricWeightsDecompressor/decompress_symmetric_0" 
[id=122, type=decompress_symmetric]; +"123 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[k_proj]/INT8SymmetricWeightsDecompressor/type_0" [id=123, type=type]; "124 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[k_proj]/linear_0" [id=124, type=linear]; "125 model.layers.1.self_attn.v_proj.weight" [id=125, type=nncf_model_const]; -"126 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[v_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0" [id=126, type=decompress_symmetric]; -"127 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[v_proj]/SymmetricWeightsDecompressor/type_0" [id=127, type=type]; +"126 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[v_proj]/INT8SymmetricWeightsDecompressor/decompress_symmetric_0" [id=126, type=decompress_symmetric]; +"127 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[v_proj]/INT8SymmetricWeightsDecompressor/type_0" [id=127, type=type]; "128 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[v_proj]/linear_0" [id=128, type=linear]; "129 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/view_0" [id=129, type=view]; "130 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/transpose_0" [id=130, type=transpose]; @@ -173,8 +173,8 @@ strict digraph { "171 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/contiguous_0" [id=171, type=contiguous]; "172 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/reshape_2" [id=172, type=reshape]; "173 model.layers.1.self_attn.o_proj.weight" [id=173, type=nncf_model_const]; -"174 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[o_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0" [id=174, type=decompress_symmetric]; -"175 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[o_proj]/SymmetricWeightsDecompressor/type_0" [id=175, type=type]; +"174 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[o_proj]/INT8SymmetricWeightsDecompressor/decompress_symmetric_0" [id=174, type=decompress_symmetric]; +"175 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[o_proj]/INT8SymmetricWeightsDecompressor/type_0" [id=175, type=type]; "176 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[o_proj]/linear_0" [id=176, type=linear]; "177 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/__add___0" [id=177, type=__add__]; "178 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[post_attention_layernorm]/to_0" [id=178, type=to]; @@ -187,24 +187,24 @@ strict digraph { "185 model.layers.1.post_attention_layernorm.weight" [id=185, type=nncf_model_const]; "186 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[post_attention_layernorm]/__mul___1" [id=186, 
type=__mul__]; "187 model.layers.1.mlp.gate_proj.weight" [id=187, type=nncf_model_const]; -"188 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[gate_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0" [id=188, type=decompress_symmetric]; -"189 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[gate_proj]/SymmetricWeightsDecompressor/type_0" [id=189, type=type]; +"188 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[gate_proj]/INT8SymmetricWeightsDecompressor/decompress_symmetric_0" [id=188, type=decompress_symmetric]; +"189 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[gate_proj]/INT8SymmetricWeightsDecompressor/type_0" [id=189, type=type]; "190 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[gate_proj]/ActivationsSparsifier/abs_0" [id=190, type=abs]; "191 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[gate_proj]/ActivationsSparsifier/le_0" [id=191, type=le]; "192 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[gate_proj]/ActivationsSparsifier/masked_fill_0" [id=192, type=masked_fill]; "193 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[gate_proj]/linear_0" [id=193, type=linear]; "194 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/SiLU[act_fn]/silu_0" [id=194, type=silu]; "195 model.layers.1.mlp.up_proj.weight" [id=195, type=nncf_model_const]; -"196 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[up_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0" [id=196, type=decompress_symmetric]; -"197 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[up_proj]/SymmetricWeightsDecompressor/type_0" [id=197, type=type]; +"196 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[up_proj]/INT8SymmetricWeightsDecompressor/decompress_symmetric_0" [id=196, type=decompress_symmetric]; +"197 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[up_proj]/INT8SymmetricWeightsDecompressor/type_0" [id=197, type=type]; "198 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[up_proj]/ActivationsSparsifier/abs_0" [id=198, type=abs]; "199 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[up_proj]/ActivationsSparsifier/le_0" [id=199, type=le]; "200 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[up_proj]/ActivationsSparsifier/masked_fill_0" [id=200, type=masked_fill]; "201 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[up_proj]/linear_0" [id=201, type=linear]; "202 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/__mul___0" [id=202, type=__mul__]; "203 model.layers.1.mlp.down_proj.weight" [id=203, type=nncf_model_const]; -"204 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[down_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0" [id=204, type=decompress_symmetric]; -"205 
LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[down_proj]/SymmetricWeightsDecompressor/type_0" [id=205, type=type]; +"204 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[down_proj]/INT8SymmetricWeightsDecompressor/decompress_symmetric_0" [id=204, type=decompress_symmetric]; +"205 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[down_proj]/INT8SymmetricWeightsDecompressor/type_0" [id=205, type=type]; "206 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[down_proj]/ActivationsSparsifier/abs_0" [id=206, type=abs]; "207 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[down_proj]/ActivationsSparsifier/le_0" [id=207, type=le]; "208 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[down_proj]/ActivationsSparsifier/masked_fill_0" [id=208, type=masked_fill]; @@ -220,15 +220,15 @@ strict digraph { "218 model.norm.weight" [id=218, type=nncf_model_const]; "219 LlamaForCausalLM/LlamaModel[model]/LlamaRMSNorm[norm]/__mul___1" [id=219, type=__mul__]; "220 lm_head.weight" [id=220, type=nncf_model_const]; -"221 LlamaForCausalLM/Linear[lm_head]/SymmetricWeightsDecompressor/decompress_symmetric_0" [id=221, type=decompress_symmetric]; -"222 LlamaForCausalLM/Linear[lm_head]/SymmetricWeightsDecompressor/type_0" [id=222, type=type]; +"221 LlamaForCausalLM/Linear[lm_head]/INT8SymmetricWeightsDecompressor/decompress_symmetric_0" [id=221, type=decompress_symmetric]; +"222 LlamaForCausalLM/Linear[lm_head]/INT8SymmetricWeightsDecompressor/type_0" [id=222, type=type]; "223 LlamaForCausalLM/Linear[lm_head]/linear_0" [id=223, type=linear]; "224 LlamaForCausalLM/float_0" [id=224, type=float]; "225 /nncf_model_output_0" [id=225, type=nncf_model_output]; "0 /nncf_model_input_0" -> "4 LlamaForCausalLM/LlamaModel[model]/Embedding[embed_tokens]/embedding_0"; -"1 model.embed_tokens.weight" -> "2 LlamaForCausalLM/LlamaModel[model]/Embedding[embed_tokens]/SymmetricWeightsDecompressor/decompress_symmetric_0"; -"2 LlamaForCausalLM/LlamaModel[model]/Embedding[embed_tokens]/SymmetricWeightsDecompressor/decompress_symmetric_0" -> "3 LlamaForCausalLM/LlamaModel[model]/Embedding[embed_tokens]/SymmetricWeightsDecompressor/type_0"; -"3 LlamaForCausalLM/LlamaModel[model]/Embedding[embed_tokens]/SymmetricWeightsDecompressor/type_0" -> "4 LlamaForCausalLM/LlamaModel[model]/Embedding[embed_tokens]/embedding_0"; +"1 model.embed_tokens.weight" -> "2 LlamaForCausalLM/LlamaModel[model]/Embedding[embed_tokens]/INT8SymmetricWeightsDecompressor/decompress_symmetric_0"; +"2 LlamaForCausalLM/LlamaModel[model]/Embedding[embed_tokens]/INT8SymmetricWeightsDecompressor/decompress_symmetric_0" -> "3 LlamaForCausalLM/LlamaModel[model]/Embedding[embed_tokens]/INT8SymmetricWeightsDecompressor/type_0"; +"3 LlamaForCausalLM/LlamaModel[model]/Embedding[embed_tokens]/INT8SymmetricWeightsDecompressor/type_0" -> "4 LlamaForCausalLM/LlamaModel[model]/Embedding[embed_tokens]/embedding_0"; "4 LlamaForCausalLM/LlamaModel[model]/Embedding[embed_tokens]/embedding_0" -> "5 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[input_layernorm]/to_0"; "5 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[input_layernorm]/to_0" -> "6 
LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[input_layernorm]/pow_0"; "5 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[input_layernorm]/to_0" -> "10 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[input_layernorm]/__mul___0"; @@ -243,17 +243,17 @@ strict digraph { "13 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[input_layernorm]/__mul___1" -> "17 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[q_proj]/linear_0"; "13 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[input_layernorm]/__mul___1" -> "21 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[k_proj]/linear_0"; "13 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[input_layernorm]/__mul___1" -> "25 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[v_proj]/linear_0"; -"14 model.layers.0.self_attn.q_proj.weight" -> "15 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[q_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0"; -"15 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[q_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0" -> "16 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[q_proj]/SymmetricWeightsDecompressor/type_0"; -"16 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[q_proj]/SymmetricWeightsDecompressor/type_0" -> "17 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[q_proj]/linear_0"; +"14 model.layers.0.self_attn.q_proj.weight" -> "15 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[q_proj]/INT8SymmetricWeightsDecompressor/decompress_symmetric_0"; +"15 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[q_proj]/INT8SymmetricWeightsDecompressor/decompress_symmetric_0" -> "16 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[q_proj]/INT8SymmetricWeightsDecompressor/type_0"; +"16 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[q_proj]/INT8SymmetricWeightsDecompressor/type_0" -> "17 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[q_proj]/linear_0"; "17 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[q_proj]/linear_0" -> "26 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/view_0"; -"18 model.layers.0.self_attn.k_proj.weight" -> "19 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[k_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0"; -"19 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[k_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0" -> "20 
LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[k_proj]/SymmetricWeightsDecompressor/type_0"; -"20 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[k_proj]/SymmetricWeightsDecompressor/type_0" -> "21 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[k_proj]/linear_0"; +"18 model.layers.0.self_attn.k_proj.weight" -> "19 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[k_proj]/INT8SymmetricWeightsDecompressor/decompress_symmetric_0"; +"19 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[k_proj]/INT8SymmetricWeightsDecompressor/decompress_symmetric_0" -> "20 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[k_proj]/INT8SymmetricWeightsDecompressor/type_0"; +"20 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[k_proj]/INT8SymmetricWeightsDecompressor/type_0" -> "21 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[k_proj]/linear_0"; "21 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[k_proj]/linear_0" -> "28 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/view_1"; -"22 model.layers.0.self_attn.v_proj.weight" -> "23 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[v_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0"; -"23 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[v_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0" -> "24 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[v_proj]/SymmetricWeightsDecompressor/type_0"; -"24 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[v_proj]/SymmetricWeightsDecompressor/type_0" -> "25 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[v_proj]/linear_0"; +"22 model.layers.0.self_attn.v_proj.weight" -> "23 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[v_proj]/INT8SymmetricWeightsDecompressor/decompress_symmetric_0"; +"23 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[v_proj]/INT8SymmetricWeightsDecompressor/decompress_symmetric_0" -> "24 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[v_proj]/INT8SymmetricWeightsDecompressor/type_0"; +"24 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[v_proj]/INT8SymmetricWeightsDecompressor/type_0" -> "25 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[v_proj]/linear_0"; "25 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[v_proj]/linear_0" -> "30 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/view_2"; "26 
LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/view_0" -> "27 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/transpose_0"; "27 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/transpose_0" -> "39 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/__mul___0"; @@ -306,9 +306,9 @@ strict digraph { "67 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/transpose_4" -> "68 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/contiguous_0"; "68 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/contiguous_0" -> "69 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/reshape_2"; "69 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/reshape_2" -> "73 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[o_proj]/linear_0"; -"70 model.layers.0.self_attn.o_proj.weight" -> "71 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[o_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0"; -"71 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[o_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0" -> "72 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[o_proj]/SymmetricWeightsDecompressor/type_0"; -"72 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[o_proj]/SymmetricWeightsDecompressor/type_0" -> "73 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[o_proj]/linear_0"; +"70 model.layers.0.self_attn.o_proj.weight" -> "71 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[o_proj]/INT8SymmetricWeightsDecompressor/decompress_symmetric_0"; +"71 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[o_proj]/INT8SymmetricWeightsDecompressor/decompress_symmetric_0" -> "72 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[o_proj]/INT8SymmetricWeightsDecompressor/type_0"; +"72 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[o_proj]/INT8SymmetricWeightsDecompressor/type_0" -> "73 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[o_proj]/linear_0"; "73 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaAttention[self_attn]/Linear[o_proj]/linear_0" -> "74 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/__add___0"; "74 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/__add___0" -> "75 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[post_attention_layernorm]/to_0"; "75 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[post_attention_layernorm]/to_0" -> "76 
LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[post_attention_layernorm]/pow_0"; @@ -325,26 +325,26 @@ strict digraph { "83 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[post_attention_layernorm]/__mul___1" -> "89 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[gate_proj]/ActivationsSparsifier/masked_fill_0"; "83 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[post_attention_layernorm]/__mul___1" -> "95 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[up_proj]/ActivationsSparsifier/abs_0"; "83 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaRMSNorm[post_attention_layernorm]/__mul___1" -> "97 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[up_proj]/ActivationsSparsifier/masked_fill_0"; -"84 model.layers.0.mlp.gate_proj.weight" -> "85 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[gate_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0"; -"85 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[gate_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0" -> "86 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[gate_proj]/SymmetricWeightsDecompressor/type_0"; -"86 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[gate_proj]/SymmetricWeightsDecompressor/type_0" -> "90 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[gate_proj]/linear_0"; +"84 model.layers.0.mlp.gate_proj.weight" -> "85 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[gate_proj]/INT8SymmetricWeightsDecompressor/decompress_symmetric_0"; +"85 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[gate_proj]/INT8SymmetricWeightsDecompressor/decompress_symmetric_0" -> "86 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[gate_proj]/INT8SymmetricWeightsDecompressor/type_0"; +"86 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[gate_proj]/INT8SymmetricWeightsDecompressor/type_0" -> "90 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[gate_proj]/linear_0"; "87 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[gate_proj]/ActivationsSparsifier/abs_0" -> "88 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[gate_proj]/ActivationsSparsifier/le_0"; "88 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[gate_proj]/ActivationsSparsifier/le_0" -> "89 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[gate_proj]/ActivationsSparsifier/masked_fill_0"; "89 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[gate_proj]/ActivationsSparsifier/masked_fill_0" -> "90 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[gate_proj]/linear_0"; "90 
LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[gate_proj]/linear_0" -> "91 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/SiLU[act_fn]/silu_0"; "91 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/SiLU[act_fn]/silu_0" -> "99 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/__mul___0"; -"92 model.layers.0.mlp.up_proj.weight" -> "93 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[up_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0"; -"93 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[up_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0" -> "94 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[up_proj]/SymmetricWeightsDecompressor/type_0"; -"94 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[up_proj]/SymmetricWeightsDecompressor/type_0" -> "98 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[up_proj]/linear_0"; +"92 model.layers.0.mlp.up_proj.weight" -> "93 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[up_proj]/INT8SymmetricWeightsDecompressor/decompress_symmetric_0"; +"93 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[up_proj]/INT8SymmetricWeightsDecompressor/decompress_symmetric_0" -> "94 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[up_proj]/INT8SymmetricWeightsDecompressor/type_0"; +"94 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[up_proj]/INT8SymmetricWeightsDecompressor/type_0" -> "98 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[up_proj]/linear_0"; "95 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[up_proj]/ActivationsSparsifier/abs_0" -> "96 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[up_proj]/ActivationsSparsifier/le_0"; "96 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[up_proj]/ActivationsSparsifier/le_0" -> "97 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[up_proj]/ActivationsSparsifier/masked_fill_0"; "97 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[up_proj]/ActivationsSparsifier/masked_fill_0" -> "98 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[up_proj]/linear_0"; "98 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[up_proj]/linear_0" -> "99 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/__mul___0"; "99 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/__mul___0" -> "103 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[down_proj]/ActivationsSparsifier/abs_0"; "99 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/__mul___0" -> "105 
LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[down_proj]/ActivationsSparsifier/masked_fill_0"; -"100 model.layers.0.mlp.down_proj.weight" -> "101 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[down_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0"; -"101 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[down_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0" -> "102 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[down_proj]/SymmetricWeightsDecompressor/type_0"; -"102 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[down_proj]/SymmetricWeightsDecompressor/type_0" -> "106 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[down_proj]/linear_0"; +"100 model.layers.0.mlp.down_proj.weight" -> "101 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[down_proj]/INT8SymmetricWeightsDecompressor/decompress_symmetric_0"; +"101 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[down_proj]/INT8SymmetricWeightsDecompressor/decompress_symmetric_0" -> "102 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[down_proj]/INT8SymmetricWeightsDecompressor/type_0"; +"102 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[down_proj]/INT8SymmetricWeightsDecompressor/type_0" -> "106 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[down_proj]/linear_0"; "103 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[down_proj]/ActivationsSparsifier/abs_0" -> "104 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[down_proj]/ActivationsSparsifier/le_0"; "104 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[down_proj]/ActivationsSparsifier/le_0" -> "105 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[down_proj]/ActivationsSparsifier/masked_fill_0"; "105 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[down_proj]/ActivationsSparsifier/masked_fill_0" -> "106 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[0]/LlamaMLP[mlp]/Linear[down_proj]/linear_0"; @@ -363,17 +363,17 @@ strict digraph { "116 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[input_layernorm]/__mul___1" -> "120 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[q_proj]/linear_0"; "116 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[input_layernorm]/__mul___1" -> "124 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[k_proj]/linear_0"; "116 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[input_layernorm]/__mul___1" -> "128 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[v_proj]/linear_0"; -"117 model.layers.1.self_attn.q_proj.weight" -> "118 
LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[q_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0"; -"118 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[q_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0" -> "119 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[q_proj]/SymmetricWeightsDecompressor/type_0"; -"119 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[q_proj]/SymmetricWeightsDecompressor/type_0" -> "120 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[q_proj]/linear_0"; +"117 model.layers.1.self_attn.q_proj.weight" -> "118 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[q_proj]/INT8SymmetricWeightsDecompressor/decompress_symmetric_0"; +"118 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[q_proj]/INT8SymmetricWeightsDecompressor/decompress_symmetric_0" -> "119 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[q_proj]/INT8SymmetricWeightsDecompressor/type_0"; +"119 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[q_proj]/INT8SymmetricWeightsDecompressor/type_0" -> "120 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[q_proj]/linear_0"; "120 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[q_proj]/linear_0" -> "129 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/view_0"; -"121 model.layers.1.self_attn.k_proj.weight" -> "122 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[k_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0"; -"122 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[k_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0" -> "123 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[k_proj]/SymmetricWeightsDecompressor/type_0"; -"123 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[k_proj]/SymmetricWeightsDecompressor/type_0" -> "124 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[k_proj]/linear_0"; +"121 model.layers.1.self_attn.k_proj.weight" -> "122 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[k_proj]/INT8SymmetricWeightsDecompressor/decompress_symmetric_0"; +"122 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[k_proj]/INT8SymmetricWeightsDecompressor/decompress_symmetric_0" -> "123 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[k_proj]/INT8SymmetricWeightsDecompressor/type_0"; +"123 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[k_proj]/INT8SymmetricWeightsDecompressor/type_0" -> "124 
LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[k_proj]/linear_0"; "124 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[k_proj]/linear_0" -> "131 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/view_1"; -"125 model.layers.1.self_attn.v_proj.weight" -> "126 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[v_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0"; -"126 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[v_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0" -> "127 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[v_proj]/SymmetricWeightsDecompressor/type_0"; -"127 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[v_proj]/SymmetricWeightsDecompressor/type_0" -> "128 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[v_proj]/linear_0"; +"125 model.layers.1.self_attn.v_proj.weight" -> "126 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[v_proj]/INT8SymmetricWeightsDecompressor/decompress_symmetric_0"; +"126 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[v_proj]/INT8SymmetricWeightsDecompressor/decompress_symmetric_0" -> "127 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[v_proj]/INT8SymmetricWeightsDecompressor/type_0"; +"127 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[v_proj]/INT8SymmetricWeightsDecompressor/type_0" -> "128 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[v_proj]/linear_0"; "128 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[v_proj]/linear_0" -> "133 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/view_2"; "129 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/view_0" -> "130 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/transpose_0"; "130 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/transpose_0" -> "142 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/__mul___0"; @@ -426,9 +426,9 @@ strict digraph { "170 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/transpose_4" -> "171 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/contiguous_0"; "171 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/contiguous_0" -> "172 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/reshape_2"; "172 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/reshape_2" -> "176 
LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[o_proj]/linear_0"; -"173 model.layers.1.self_attn.o_proj.weight" -> "174 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[o_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0"; -"174 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[o_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0" -> "175 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[o_proj]/SymmetricWeightsDecompressor/type_0"; -"175 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[o_proj]/SymmetricWeightsDecompressor/type_0" -> "176 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[o_proj]/linear_0"; +"173 model.layers.1.self_attn.o_proj.weight" -> "174 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[o_proj]/INT8SymmetricWeightsDecompressor/decompress_symmetric_0"; +"174 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[o_proj]/INT8SymmetricWeightsDecompressor/decompress_symmetric_0" -> "175 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[o_proj]/INT8SymmetricWeightsDecompressor/type_0"; +"175 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[o_proj]/INT8SymmetricWeightsDecompressor/type_0" -> "176 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[o_proj]/linear_0"; "176 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaAttention[self_attn]/Linear[o_proj]/linear_0" -> "177 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/__add___0"; "177 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/__add___0" -> "178 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[post_attention_layernorm]/to_0"; "178 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[post_attention_layernorm]/to_0" -> "179 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[post_attention_layernorm]/pow_0"; @@ -445,26 +445,26 @@ strict digraph { "186 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[post_attention_layernorm]/__mul___1" -> "192 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[gate_proj]/ActivationsSparsifier/masked_fill_0"; "186 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[post_attention_layernorm]/__mul___1" -> "198 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[up_proj]/ActivationsSparsifier/abs_0"; "186 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaRMSNorm[post_attention_layernorm]/__mul___1" -> "200 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[up_proj]/ActivationsSparsifier/masked_fill_0"; -"187 model.layers.1.mlp.gate_proj.weight" -> "188 
LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[gate_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0"; -"188 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[gate_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0" -> "189 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[gate_proj]/SymmetricWeightsDecompressor/type_0"; -"189 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[gate_proj]/SymmetricWeightsDecompressor/type_0" -> "193 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[gate_proj]/linear_0"; +"187 model.layers.1.mlp.gate_proj.weight" -> "188 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[gate_proj]/INT8SymmetricWeightsDecompressor/decompress_symmetric_0"; +"188 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[gate_proj]/INT8SymmetricWeightsDecompressor/decompress_symmetric_0" -> "189 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[gate_proj]/INT8SymmetricWeightsDecompressor/type_0"; +"189 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[gate_proj]/INT8SymmetricWeightsDecompressor/type_0" -> "193 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[gate_proj]/linear_0"; "190 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[gate_proj]/ActivationsSparsifier/abs_0" -> "191 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[gate_proj]/ActivationsSparsifier/le_0"; "191 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[gate_proj]/ActivationsSparsifier/le_0" -> "192 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[gate_proj]/ActivationsSparsifier/masked_fill_0"; "192 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[gate_proj]/ActivationsSparsifier/masked_fill_0" -> "193 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[gate_proj]/linear_0"; "193 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[gate_proj]/linear_0" -> "194 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/SiLU[act_fn]/silu_0"; "194 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/SiLU[act_fn]/silu_0" -> "202 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/__mul___0"; -"195 model.layers.1.mlp.up_proj.weight" -> "196 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[up_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0"; -"196 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[up_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0" -> "197 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[up_proj]/SymmetricWeightsDecompressor/type_0"; -"197 
LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[up_proj]/SymmetricWeightsDecompressor/type_0" -> "201 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[up_proj]/linear_0"; +"195 model.layers.1.mlp.up_proj.weight" -> "196 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[up_proj]/INT8SymmetricWeightsDecompressor/decompress_symmetric_0"; +"196 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[up_proj]/INT8SymmetricWeightsDecompressor/decompress_symmetric_0" -> "197 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[up_proj]/INT8SymmetricWeightsDecompressor/type_0"; +"197 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[up_proj]/INT8SymmetricWeightsDecompressor/type_0" -> "201 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[up_proj]/linear_0"; "198 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[up_proj]/ActivationsSparsifier/abs_0" -> "199 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[up_proj]/ActivationsSparsifier/le_0"; "199 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[up_proj]/ActivationsSparsifier/le_0" -> "200 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[up_proj]/ActivationsSparsifier/masked_fill_0"; "200 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[up_proj]/ActivationsSparsifier/masked_fill_0" -> "201 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[up_proj]/linear_0"; "201 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[up_proj]/linear_0" -> "202 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/__mul___0"; "202 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/__mul___0" -> "206 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[down_proj]/ActivationsSparsifier/abs_0"; "202 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/__mul___0" -> "208 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[down_proj]/ActivationsSparsifier/masked_fill_0"; -"203 model.layers.1.mlp.down_proj.weight" -> "204 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[down_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0"; -"204 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[down_proj]/SymmetricWeightsDecompressor/decompress_symmetric_0" -> "205 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[down_proj]/SymmetricWeightsDecompressor/type_0"; -"205 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[down_proj]/SymmetricWeightsDecompressor/type_0" -> "209 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[down_proj]/linear_0"; +"203 model.layers.1.mlp.down_proj.weight" -> "204 
LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[down_proj]/INT8SymmetricWeightsDecompressor/decompress_symmetric_0"; +"204 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[down_proj]/INT8SymmetricWeightsDecompressor/decompress_symmetric_0" -> "205 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[down_proj]/INT8SymmetricWeightsDecompressor/type_0"; +"205 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[down_proj]/INT8SymmetricWeightsDecompressor/type_0" -> "209 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[down_proj]/linear_0"; "206 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[down_proj]/ActivationsSparsifier/abs_0" -> "207 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[down_proj]/ActivationsSparsifier/le_0"; "207 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[down_proj]/ActivationsSparsifier/le_0" -> "208 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[down_proj]/ActivationsSparsifier/masked_fill_0"; "208 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[down_proj]/ActivationsSparsifier/masked_fill_0" -> "209 LlamaForCausalLM/LlamaModel[model]/ModuleList[layers]/LlamaDecoderLayer[1]/LlamaMLP[mlp]/Linear[down_proj]/linear_0"; @@ -480,9 +480,9 @@ strict digraph { "217 LlamaForCausalLM/LlamaModel[model]/LlamaRMSNorm[norm]/to_1" -> "219 LlamaForCausalLM/LlamaModel[model]/LlamaRMSNorm[norm]/__mul___1"; "218 model.norm.weight" -> "219 LlamaForCausalLM/LlamaModel[model]/LlamaRMSNorm[norm]/__mul___1"; "219 LlamaForCausalLM/LlamaModel[model]/LlamaRMSNorm[norm]/__mul___1" -> "223 LlamaForCausalLM/Linear[lm_head]/linear_0"; -"220 lm_head.weight" -> "221 LlamaForCausalLM/Linear[lm_head]/SymmetricWeightsDecompressor/decompress_symmetric_0"; -"221 LlamaForCausalLM/Linear[lm_head]/SymmetricWeightsDecompressor/decompress_symmetric_0" -> "222 LlamaForCausalLM/Linear[lm_head]/SymmetricWeightsDecompressor/type_0"; -"222 LlamaForCausalLM/Linear[lm_head]/SymmetricWeightsDecompressor/type_0" -> "223 LlamaForCausalLM/Linear[lm_head]/linear_0"; +"220 lm_head.weight" -> "221 LlamaForCausalLM/Linear[lm_head]/INT8SymmetricWeightsDecompressor/decompress_symmetric_0"; +"221 LlamaForCausalLM/Linear[lm_head]/INT8SymmetricWeightsDecompressor/decompress_symmetric_0" -> "222 LlamaForCausalLM/Linear[lm_head]/INT8SymmetricWeightsDecompressor/type_0"; +"222 LlamaForCausalLM/Linear[lm_head]/INT8SymmetricWeightsDecompressor/type_0" -> "223 LlamaForCausalLM/Linear[lm_head]/linear_0"; "223 LlamaForCausalLM/Linear[lm_head]/linear_0" -> "224 LlamaForCausalLM/float_0"; "224 LlamaForCausalLM/float_0" -> "225 /nncf_model_output_0"; } diff --git a/tests/torch/data/experimental/sparsify_activations/linear_int8_sym_weights_sparse_activations.dot b/tests/torch/data/experimental/sparsify_activations/linear_int8_sym_weights_sparse_activations.dot index aa24d54a2e0..c433410067b 100644 --- a/tests/torch/data/experimental/sparsify_activations/linear_int8_sym_weights_sparse_activations.dot +++ b/tests/torch/data/experimental/sparsify_activations/linear_int8_sym_weights_sparse_activations.dot @@ -1,8 +1,8 @@ strict digraph { "0 /nncf_model_input_0" [id=0, 
type=nncf_model_input]; "1 weight" [id=1, type=nncf_model_const]; -"2 Linear/NNCFNetworkInterface[_nncf]/ModuleDict[external_op]/SymmetricWeightsDecompressor[weights_decompressor_weight]/decompress_symmetric_0" [id=2, type=decompress_symmetric]; -"3 Linear/NNCFNetworkInterface[_nncf]/ModuleDict[external_op]/SymmetricWeightsDecompressor[weights_decompressor_weight]/type_0" [id=3, type=type]; +"2 Linear/NNCFNetworkInterface[_nncf]/ModuleDict[external_op]/INT8SymmetricWeightsDecompressor[weights_decompressor_weight]/decompress_symmetric_0" [id=2, type=decompress_symmetric]; +"3 Linear/NNCFNetworkInterface[_nncf]/ModuleDict[external_op]/INT8SymmetricWeightsDecompressor[weights_decompressor_weight]/type_0" [id=3, type=type]; "4 bias" [id=4, type=nncf_model_const]; "5 Linear/NNCFNetworkInterface[_nncf]/ModuleDict[external_op]/ActivationsSparsifier[activations_sparsifier_Linear/linear_0]/abs_0" [id=5, type=abs]; "6 Linear/NNCFNetworkInterface[_nncf]/ModuleDict[external_op]/ActivationsSparsifier[activations_sparsifier_Linear/linear_0]/le_0" [id=6, type=le]; @@ -11,9 +11,9 @@ strict digraph { "9 /nncf_model_output_0" [id=9, type=nncf_model_output]; "0 /nncf_model_input_0" -> "5 Linear/NNCFNetworkInterface[_nncf]/ModuleDict[external_op]/ActivationsSparsifier[activations_sparsifier_Linear/linear_0]/abs_0"; "0 /nncf_model_input_0" -> "7 Linear/NNCFNetworkInterface[_nncf]/ModuleDict[external_op]/ActivationsSparsifier[activations_sparsifier_Linear/linear_0]/masked_fill_0"; -"1 weight" -> "2 Linear/NNCFNetworkInterface[_nncf]/ModuleDict[external_op]/SymmetricWeightsDecompressor[weights_decompressor_weight]/decompress_symmetric_0"; -"2 Linear/NNCFNetworkInterface[_nncf]/ModuleDict[external_op]/SymmetricWeightsDecompressor[weights_decompressor_weight]/decompress_symmetric_0" -> "3 Linear/NNCFNetworkInterface[_nncf]/ModuleDict[external_op]/SymmetricWeightsDecompressor[weights_decompressor_weight]/type_0"; -"3 Linear/NNCFNetworkInterface[_nncf]/ModuleDict[external_op]/SymmetricWeightsDecompressor[weights_decompressor_weight]/type_0" -> "8 Linear/linear_0"; +"1 weight" -> "2 Linear/NNCFNetworkInterface[_nncf]/ModuleDict[external_op]/INT8SymmetricWeightsDecompressor[weights_decompressor_weight]/decompress_symmetric_0"; +"2 Linear/NNCFNetworkInterface[_nncf]/ModuleDict[external_op]/INT8SymmetricWeightsDecompressor[weights_decompressor_weight]/decompress_symmetric_0" -> "3 Linear/NNCFNetworkInterface[_nncf]/ModuleDict[external_op]/INT8SymmetricWeightsDecompressor[weights_decompressor_weight]/type_0"; +"3 Linear/NNCFNetworkInterface[_nncf]/ModuleDict[external_op]/INT8SymmetricWeightsDecompressor[weights_decompressor_weight]/type_0" -> "8 Linear/linear_0"; "4 bias" -> "8 Linear/linear_0"; "5 Linear/NNCFNetworkInterface[_nncf]/ModuleDict[external_op]/ActivationsSparsifier[activations_sparsifier_Linear/linear_0]/abs_0" -> "6 Linear/NNCFNetworkInterface[_nncf]/ModuleDict[external_op]/ActivationsSparsifier[activations_sparsifier_Linear/linear_0]/le_0"; "6 Linear/NNCFNetworkInterface[_nncf]/ModuleDict[external_op]/ActivationsSparsifier[activations_sparsifier_Linear/linear_0]/le_0" -> "7 Linear/NNCFNetworkInterface[_nncf]/ModuleDict[external_op]/ActivationsSparsifier[activations_sparsifier_Linear/linear_0]/masked_fill_0"; diff --git a/tests/torch/data/experimental/sparsify_activations/three_linear_ignore1_int8_sym_weights_sparse_activations.dot b/tests/torch/data/experimental/sparsify_activations/three_linear_ignore1_int8_sym_weights_sparse_activations.dot index ae3f667ff3a..5d2d15ae64d 100644 --- 
a/tests/torch/data/experimental/sparsify_activations/three_linear_ignore1_int8_sym_weights_sparse_activations.dot +++ b/tests/torch/data/experimental/sparsify_activations/three_linear_ignore1_int8_sym_weights_sparse_activations.dot @@ -1,25 +1,25 @@ strict digraph { "0 /nncf_model_input_0" [id=0, type=nncf_model_input]; "1 embedding.weight" [id=1, type=nncf_model_const]; -"2 ThreeLinearModel/Embedding[embedding]/SymmetricWeightsDecompressor/decompress_symmetric_0" [id=2, type=decompress_symmetric]; -"3 ThreeLinearModel/Embedding[embedding]/SymmetricWeightsDecompressor/type_0" [id=3, type=type]; +"2 ThreeLinearModel/Embedding[embedding]/INT8SymmetricWeightsDecompressor/decompress_symmetric_0" [id=2, type=decompress_symmetric]; +"3 ThreeLinearModel/Embedding[embedding]/INT8SymmetricWeightsDecompressor/type_0" [id=3, type=type]; "4 ThreeLinearModel/Embedding[embedding]/embedding_0" [id=4, type=embedding]; "5 linear1.weight" [id=5, type=nncf_model_const]; -"6 ThreeLinearModel/Linear[linear1]/SymmetricWeightsDecompressor/decompress_symmetric_0" [id=6, type=decompress_symmetric]; -"7 ThreeLinearModel/Linear[linear1]/SymmetricWeightsDecompressor/type_0" [id=7, type=type]; +"6 ThreeLinearModel/Linear[linear1]/INT8SymmetricWeightsDecompressor/decompress_symmetric_0" [id=6, type=decompress_symmetric]; +"7 ThreeLinearModel/Linear[linear1]/INT8SymmetricWeightsDecompressor/type_0" [id=7, type=type]; "8 linear1.bias" [id=8, type=nncf_model_const]; "9 ThreeLinearModel/Linear[linear1]/linear_0" [id=9, type=linear]; "10 linear3.weight" [id=10, type=nncf_model_const]; -"11 ThreeLinearModel/Linear[linear3]/SymmetricWeightsDecompressor/decompress_symmetric_0" [id=11, type=decompress_symmetric]; -"12 ThreeLinearModel/Linear[linear3]/SymmetricWeightsDecompressor/type_0" [id=12, type=type]; +"11 ThreeLinearModel/Linear[linear3]/INT8SymmetricWeightsDecompressor/decompress_symmetric_0" [id=11, type=decompress_symmetric]; +"12 ThreeLinearModel/Linear[linear3]/INT8SymmetricWeightsDecompressor/type_0" [id=12, type=type]; "13 linear3.bias" [id=13, type=nncf_model_const]; "14 ThreeLinearModel/Linear[linear3]/ActivationsSparsifier/abs_0" [id=14, type=abs]; "15 ThreeLinearModel/Linear[linear3]/ActivationsSparsifier/le_0" [id=15, type=le]; "16 ThreeLinearModel/Linear[linear3]/ActivationsSparsifier/masked_fill_0" [id=16, type=masked_fill]; "17 ThreeLinearModel/Linear[linear3]/linear_0" [id=17, type=linear]; "18 linear2.weight" [id=18, type=nncf_model_const]; -"19 ThreeLinearModel/Linear[linear2]/SymmetricWeightsDecompressor/decompress_symmetric_0" [id=19, type=decompress_symmetric]; -"20 ThreeLinearModel/Linear[linear2]/SymmetricWeightsDecompressor/type_0" [id=20, type=type]; +"19 ThreeLinearModel/Linear[linear2]/INT8SymmetricWeightsDecompressor/decompress_symmetric_0" [id=19, type=decompress_symmetric]; +"20 ThreeLinearModel/Linear[linear2]/INT8SymmetricWeightsDecompressor/type_0" [id=20, type=type]; "21 ThreeLinearModel/Linear[linear2]/ActivationsSparsifier/abs_0" [id=21, type=abs]; "22 ThreeLinearModel/Linear[linear2]/ActivationsSparsifier/le_0" [id=22, type=le]; "23 ThreeLinearModel/Linear[linear2]/ActivationsSparsifier/masked_fill_0" [id=23, type=masked_fill]; @@ -27,29 +27,29 @@ strict digraph { "25 /nncf_model_output_0" [id=25, type=nncf_model_output]; "26 /nncf_model_output_1" [id=26, type=nncf_model_output]; "0 /nncf_model_input_0" -> "4 ThreeLinearModel/Embedding[embedding]/embedding_0"; -"1 embedding.weight" -> "2 ThreeLinearModel/Embedding[embedding]/SymmetricWeightsDecompressor/decompress_symmetric_0"; -"2 
ThreeLinearModel/Embedding[embedding]/SymmetricWeightsDecompressor/decompress_symmetric_0" -> "3 ThreeLinearModel/Embedding[embedding]/SymmetricWeightsDecompressor/type_0"; -"3 ThreeLinearModel/Embedding[embedding]/SymmetricWeightsDecompressor/type_0" -> "4 ThreeLinearModel/Embedding[embedding]/embedding_0"; +"1 embedding.weight" -> "2 ThreeLinearModel/Embedding[embedding]/INT8SymmetricWeightsDecompressor/decompress_symmetric_0"; +"2 ThreeLinearModel/Embedding[embedding]/INT8SymmetricWeightsDecompressor/decompress_symmetric_0" -> "3 ThreeLinearModel/Embedding[embedding]/INT8SymmetricWeightsDecompressor/type_0"; +"3 ThreeLinearModel/Embedding[embedding]/INT8SymmetricWeightsDecompressor/type_0" -> "4 ThreeLinearModel/Embedding[embedding]/embedding_0"; "4 ThreeLinearModel/Embedding[embedding]/embedding_0" -> "9 ThreeLinearModel/Linear[linear1]/linear_0"; "4 ThreeLinearModel/Embedding[embedding]/embedding_0" -> "21 ThreeLinearModel/Linear[linear2]/ActivationsSparsifier/abs_0"; "4 ThreeLinearModel/Embedding[embedding]/embedding_0" -> "23 ThreeLinearModel/Linear[linear2]/ActivationsSparsifier/masked_fill_0"; -"5 linear1.weight" -> "6 ThreeLinearModel/Linear[linear1]/SymmetricWeightsDecompressor/decompress_symmetric_0"; -"6 ThreeLinearModel/Linear[linear1]/SymmetricWeightsDecompressor/decompress_symmetric_0" -> "7 ThreeLinearModel/Linear[linear1]/SymmetricWeightsDecompressor/type_0"; -"7 ThreeLinearModel/Linear[linear1]/SymmetricWeightsDecompressor/type_0" -> "9 ThreeLinearModel/Linear[linear1]/linear_0"; +"5 linear1.weight" -> "6 ThreeLinearModel/Linear[linear1]/INT8SymmetricWeightsDecompressor/decompress_symmetric_0"; +"6 ThreeLinearModel/Linear[linear1]/INT8SymmetricWeightsDecompressor/decompress_symmetric_0" -> "7 ThreeLinearModel/Linear[linear1]/INT8SymmetricWeightsDecompressor/type_0"; +"7 ThreeLinearModel/Linear[linear1]/INT8SymmetricWeightsDecompressor/type_0" -> "9 ThreeLinearModel/Linear[linear1]/linear_0"; "8 linear1.bias" -> "9 ThreeLinearModel/Linear[linear1]/linear_0"; "9 ThreeLinearModel/Linear[linear1]/linear_0" -> "14 ThreeLinearModel/Linear[linear3]/ActivationsSparsifier/abs_0"; "9 ThreeLinearModel/Linear[linear1]/linear_0" -> "16 ThreeLinearModel/Linear[linear3]/ActivationsSparsifier/masked_fill_0"; -"10 linear3.weight" -> "11 ThreeLinearModel/Linear[linear3]/SymmetricWeightsDecompressor/decompress_symmetric_0"; -"11 ThreeLinearModel/Linear[linear3]/SymmetricWeightsDecompressor/decompress_symmetric_0" -> "12 ThreeLinearModel/Linear[linear3]/SymmetricWeightsDecompressor/type_0"; -"12 ThreeLinearModel/Linear[linear3]/SymmetricWeightsDecompressor/type_0" -> "17 ThreeLinearModel/Linear[linear3]/linear_0"; +"10 linear3.weight" -> "11 ThreeLinearModel/Linear[linear3]/INT8SymmetricWeightsDecompressor/decompress_symmetric_0"; +"11 ThreeLinearModel/Linear[linear3]/INT8SymmetricWeightsDecompressor/decompress_symmetric_0" -> "12 ThreeLinearModel/Linear[linear3]/INT8SymmetricWeightsDecompressor/type_0"; +"12 ThreeLinearModel/Linear[linear3]/INT8SymmetricWeightsDecompressor/type_0" -> "17 ThreeLinearModel/Linear[linear3]/linear_0"; "13 linear3.bias" -> "17 ThreeLinearModel/Linear[linear3]/linear_0"; "14 ThreeLinearModel/Linear[linear3]/ActivationsSparsifier/abs_0" -> "15 ThreeLinearModel/Linear[linear3]/ActivationsSparsifier/le_0"; "15 ThreeLinearModel/Linear[linear3]/ActivationsSparsifier/le_0" -> "16 ThreeLinearModel/Linear[linear3]/ActivationsSparsifier/masked_fill_0"; "16 ThreeLinearModel/Linear[linear3]/ActivationsSparsifier/masked_fill_0" -> "17 
ThreeLinearModel/Linear[linear3]/linear_0"; "17 ThreeLinearModel/Linear[linear3]/linear_0" -> "25 /nncf_model_output_0"; -"18 linear2.weight" -> "19 ThreeLinearModel/Linear[linear2]/SymmetricWeightsDecompressor/decompress_symmetric_0"; -"19 ThreeLinearModel/Linear[linear2]/SymmetricWeightsDecompressor/decompress_symmetric_0" -> "20 ThreeLinearModel/Linear[linear2]/SymmetricWeightsDecompressor/type_0"; -"20 ThreeLinearModel/Linear[linear2]/SymmetricWeightsDecompressor/type_0" -> "24 ThreeLinearModel/Linear[linear2]/linear_0"; +"18 linear2.weight" -> "19 ThreeLinearModel/Linear[linear2]/INT8SymmetricWeightsDecompressor/decompress_symmetric_0"; +"19 ThreeLinearModel/Linear[linear2]/INT8SymmetricWeightsDecompressor/decompress_symmetric_0" -> "20 ThreeLinearModel/Linear[linear2]/INT8SymmetricWeightsDecompressor/type_0"; +"20 ThreeLinearModel/Linear[linear2]/INT8SymmetricWeightsDecompressor/type_0" -> "24 ThreeLinearModel/Linear[linear2]/linear_0"; "21 ThreeLinearModel/Linear[linear2]/ActivationsSparsifier/abs_0" -> "22 ThreeLinearModel/Linear[linear2]/ActivationsSparsifier/le_0"; "22 ThreeLinearModel/Linear[linear2]/ActivationsSparsifier/le_0" -> "23 ThreeLinearModel/Linear[linear2]/ActivationsSparsifier/masked_fill_0"; "23 ThreeLinearModel/Linear[linear2]/ActivationsSparsifier/masked_fill_0" -> "24 ThreeLinearModel/Linear[linear2]/linear_0"; diff --git a/tests/torch/data/experimental/sparsify_activations/three_linear_int8_sym_weights_sparse_activations.dot b/tests/torch/data/experimental/sparsify_activations/three_linear_int8_sym_weights_sparse_activations.dot index c6488f1131b..437a8407d37 100644 --- a/tests/torch/data/experimental/sparsify_activations/three_linear_int8_sym_weights_sparse_activations.dot +++ b/tests/torch/data/experimental/sparsify_activations/three_linear_int8_sym_weights_sparse_activations.dot @@ -1,28 +1,28 @@ strict digraph { "0 /nncf_model_input_0" [id=0, type=nncf_model_input]; "1 embedding.weight" [id=1, type=nncf_model_const]; -"2 ThreeLinearModel/Embedding[embedding]/SymmetricWeightsDecompressor/decompress_symmetric_0" [id=2, type=decompress_symmetric]; -"3 ThreeLinearModel/Embedding[embedding]/SymmetricWeightsDecompressor/type_0" [id=3, type=type]; +"2 ThreeLinearModel/Embedding[embedding]/INT8SymmetricWeightsDecompressor/decompress_symmetric_0" [id=2, type=decompress_symmetric]; +"3 ThreeLinearModel/Embedding[embedding]/INT8SymmetricWeightsDecompressor/type_0" [id=3, type=type]; "4 ThreeLinearModel/Embedding[embedding]/embedding_0" [id=4, type=embedding]; "5 linear1.weight" [id=5, type=nncf_model_const]; -"6 ThreeLinearModel/Linear[linear1]/SymmetricWeightsDecompressor/decompress_symmetric_0" [id=6, type=decompress_symmetric]; -"7 ThreeLinearModel/Linear[linear1]/SymmetricWeightsDecompressor/type_0" [id=7, type=type]; +"6 ThreeLinearModel/Linear[linear1]/INT8SymmetricWeightsDecompressor/decompress_symmetric_0" [id=6, type=decompress_symmetric]; +"7 ThreeLinearModel/Linear[linear1]/INT8SymmetricWeightsDecompressor/type_0" [id=7, type=type]; "8 linear1.bias" [id=8, type=nncf_model_const]; "9 ThreeLinearModel/Linear[linear1]/ActivationsSparsifier/abs_0" [id=9, type=abs]; "10 ThreeLinearModel/Linear[linear1]/ActivationsSparsifier/le_0" [id=10, type=le]; "11 ThreeLinearModel/Linear[linear1]/ActivationsSparsifier/masked_fill_0" [id=11, type=masked_fill]; "12 ThreeLinearModel/Linear[linear1]/linear_0" [id=12, type=linear]; "13 linear3.weight" [id=13, type=nncf_model_const]; -"14 ThreeLinearModel/Linear[linear3]/SymmetricWeightsDecompressor/decompress_symmetric_0" [id=14, 
type=decompress_symmetric]; -"15 ThreeLinearModel/Linear[linear3]/SymmetricWeightsDecompressor/type_0" [id=15, type=type]; +"14 ThreeLinearModel/Linear[linear3]/INT8SymmetricWeightsDecompressor/decompress_symmetric_0" [id=14, type=decompress_symmetric]; +"15 ThreeLinearModel/Linear[linear3]/INT8SymmetricWeightsDecompressor/type_0" [id=15, type=type]; "16 linear3.bias" [id=16, type=nncf_model_const]; "17 ThreeLinearModel/Linear[linear3]/ActivationsSparsifier/abs_0" [id=17, type=abs]; "18 ThreeLinearModel/Linear[linear3]/ActivationsSparsifier/le_0" [id=18, type=le]; "19 ThreeLinearModel/Linear[linear3]/ActivationsSparsifier/masked_fill_0" [id=19, type=masked_fill]; "20 ThreeLinearModel/Linear[linear3]/linear_0" [id=20, type=linear]; "21 linear2.weight" [id=21, type=nncf_model_const]; -"22 ThreeLinearModel/Linear[linear2]/SymmetricWeightsDecompressor/decompress_symmetric_0" [id=22, type=decompress_symmetric]; -"23 ThreeLinearModel/Linear[linear2]/SymmetricWeightsDecompressor/type_0" [id=23, type=type]; +"22 ThreeLinearModel/Linear[linear2]/INT8SymmetricWeightsDecompressor/decompress_symmetric_0" [id=22, type=decompress_symmetric]; +"23 ThreeLinearModel/Linear[linear2]/INT8SymmetricWeightsDecompressor/type_0" [id=23, type=type]; "24 ThreeLinearModel/Linear[linear2]/ActivationsSparsifier/abs_0" [id=24, type=abs]; "25 ThreeLinearModel/Linear[linear2]/ActivationsSparsifier/le_0" [id=25, type=le]; "26 ThreeLinearModel/Linear[linear2]/ActivationsSparsifier/masked_fill_0" [id=26, type=masked_fill]; @@ -30,33 +30,33 @@ strict digraph { "28 /nncf_model_output_0" [id=28, type=nncf_model_output]; "29 /nncf_model_output_1" [id=29, type=nncf_model_output]; "0 /nncf_model_input_0" -> "4 ThreeLinearModel/Embedding[embedding]/embedding_0"; -"1 embedding.weight" -> "2 ThreeLinearModel/Embedding[embedding]/SymmetricWeightsDecompressor/decompress_symmetric_0"; -"2 ThreeLinearModel/Embedding[embedding]/SymmetricWeightsDecompressor/decompress_symmetric_0" -> "3 ThreeLinearModel/Embedding[embedding]/SymmetricWeightsDecompressor/type_0"; -"3 ThreeLinearModel/Embedding[embedding]/SymmetricWeightsDecompressor/type_0" -> "4 ThreeLinearModel/Embedding[embedding]/embedding_0"; +"1 embedding.weight" -> "2 ThreeLinearModel/Embedding[embedding]/INT8SymmetricWeightsDecompressor/decompress_symmetric_0"; +"2 ThreeLinearModel/Embedding[embedding]/INT8SymmetricWeightsDecompressor/decompress_symmetric_0" -> "3 ThreeLinearModel/Embedding[embedding]/INT8SymmetricWeightsDecompressor/type_0"; +"3 ThreeLinearModel/Embedding[embedding]/INT8SymmetricWeightsDecompressor/type_0" -> "4 ThreeLinearModel/Embedding[embedding]/embedding_0"; "4 ThreeLinearModel/Embedding[embedding]/embedding_0" -> "9 ThreeLinearModel/Linear[linear1]/ActivationsSparsifier/abs_0"; "4 ThreeLinearModel/Embedding[embedding]/embedding_0" -> "11 ThreeLinearModel/Linear[linear1]/ActivationsSparsifier/masked_fill_0"; "4 ThreeLinearModel/Embedding[embedding]/embedding_0" -> "24 ThreeLinearModel/Linear[linear2]/ActivationsSparsifier/abs_0"; "4 ThreeLinearModel/Embedding[embedding]/embedding_0" -> "26 ThreeLinearModel/Linear[linear2]/ActivationsSparsifier/masked_fill_0"; -"5 linear1.weight" -> "6 ThreeLinearModel/Linear[linear1]/SymmetricWeightsDecompressor/decompress_symmetric_0"; -"6 ThreeLinearModel/Linear[linear1]/SymmetricWeightsDecompressor/decompress_symmetric_0" -> "7 ThreeLinearModel/Linear[linear1]/SymmetricWeightsDecompressor/type_0"; -"7 ThreeLinearModel/Linear[linear1]/SymmetricWeightsDecompressor/type_0" -> "12 ThreeLinearModel/Linear[linear1]/linear_0"; +"5 
linear1.weight" -> "6 ThreeLinearModel/Linear[linear1]/INT8SymmetricWeightsDecompressor/decompress_symmetric_0"; +"6 ThreeLinearModel/Linear[linear1]/INT8SymmetricWeightsDecompressor/decompress_symmetric_0" -> "7 ThreeLinearModel/Linear[linear1]/INT8SymmetricWeightsDecompressor/type_0"; +"7 ThreeLinearModel/Linear[linear1]/INT8SymmetricWeightsDecompressor/type_0" -> "12 ThreeLinearModel/Linear[linear1]/linear_0"; "8 linear1.bias" -> "12 ThreeLinearModel/Linear[linear1]/linear_0"; "9 ThreeLinearModel/Linear[linear1]/ActivationsSparsifier/abs_0" -> "10 ThreeLinearModel/Linear[linear1]/ActivationsSparsifier/le_0"; "10 ThreeLinearModel/Linear[linear1]/ActivationsSparsifier/le_0" -> "11 ThreeLinearModel/Linear[linear1]/ActivationsSparsifier/masked_fill_0"; "11 ThreeLinearModel/Linear[linear1]/ActivationsSparsifier/masked_fill_0" -> "12 ThreeLinearModel/Linear[linear1]/linear_0"; "12 ThreeLinearModel/Linear[linear1]/linear_0" -> "17 ThreeLinearModel/Linear[linear3]/ActivationsSparsifier/abs_0"; "12 ThreeLinearModel/Linear[linear1]/linear_0" -> "19 ThreeLinearModel/Linear[linear3]/ActivationsSparsifier/masked_fill_0"; -"13 linear3.weight" -> "14 ThreeLinearModel/Linear[linear3]/SymmetricWeightsDecompressor/decompress_symmetric_0"; -"14 ThreeLinearModel/Linear[linear3]/SymmetricWeightsDecompressor/decompress_symmetric_0" -> "15 ThreeLinearModel/Linear[linear3]/SymmetricWeightsDecompressor/type_0"; -"15 ThreeLinearModel/Linear[linear3]/SymmetricWeightsDecompressor/type_0" -> "20 ThreeLinearModel/Linear[linear3]/linear_0"; +"13 linear3.weight" -> "14 ThreeLinearModel/Linear[linear3]/INT8SymmetricWeightsDecompressor/decompress_symmetric_0"; +"14 ThreeLinearModel/Linear[linear3]/INT8SymmetricWeightsDecompressor/decompress_symmetric_0" -> "15 ThreeLinearModel/Linear[linear3]/INT8SymmetricWeightsDecompressor/type_0"; +"15 ThreeLinearModel/Linear[linear3]/INT8SymmetricWeightsDecompressor/type_0" -> "20 ThreeLinearModel/Linear[linear3]/linear_0"; "16 linear3.bias" -> "20 ThreeLinearModel/Linear[linear3]/linear_0"; "17 ThreeLinearModel/Linear[linear3]/ActivationsSparsifier/abs_0" -> "18 ThreeLinearModel/Linear[linear3]/ActivationsSparsifier/le_0"; "18 ThreeLinearModel/Linear[linear3]/ActivationsSparsifier/le_0" -> "19 ThreeLinearModel/Linear[linear3]/ActivationsSparsifier/masked_fill_0"; "19 ThreeLinearModel/Linear[linear3]/ActivationsSparsifier/masked_fill_0" -> "20 ThreeLinearModel/Linear[linear3]/linear_0"; "20 ThreeLinearModel/Linear[linear3]/linear_0" -> "28 /nncf_model_output_0"; -"21 linear2.weight" -> "22 ThreeLinearModel/Linear[linear2]/SymmetricWeightsDecompressor/decompress_symmetric_0"; -"22 ThreeLinearModel/Linear[linear2]/SymmetricWeightsDecompressor/decompress_symmetric_0" -> "23 ThreeLinearModel/Linear[linear2]/SymmetricWeightsDecompressor/type_0"; -"23 ThreeLinearModel/Linear[linear2]/SymmetricWeightsDecompressor/type_0" -> "27 ThreeLinearModel/Linear[linear2]/linear_0"; +"21 linear2.weight" -> "22 ThreeLinearModel/Linear[linear2]/INT8SymmetricWeightsDecompressor/decompress_symmetric_0"; +"22 ThreeLinearModel/Linear[linear2]/INT8SymmetricWeightsDecompressor/decompress_symmetric_0" -> "23 ThreeLinearModel/Linear[linear2]/INT8SymmetricWeightsDecompressor/type_0"; +"23 ThreeLinearModel/Linear[linear2]/INT8SymmetricWeightsDecompressor/type_0" -> "27 ThreeLinearModel/Linear[linear2]/linear_0"; "24 ThreeLinearModel/Linear[linear2]/ActivationsSparsifier/abs_0" -> "25 ThreeLinearModel/Linear[linear2]/ActivationsSparsifier/le_0"; "25 ThreeLinearModel/Linear[linear2]/ActivationsSparsifier/le_0" -> 
"26 ThreeLinearModel/Linear[linear2]/ActivationsSparsifier/masked_fill_0"; "26 ThreeLinearModel/Linear[linear2]/ActivationsSparsifier/masked_fill_0" -> "27 ThreeLinearModel/Linear[linear2]/linear_0"; diff --git a/tests/torch/fx/test_compress_weights.py b/tests/torch/fx/test_compress_weights.py index fea9e0ce501..519fcfd654e 100644 --- a/tests/torch/fx/test_compress_weights.py +++ b/tests/torch/fx/test_compress_weights.py @@ -24,6 +24,8 @@ from nncf.quantization import compress_weights from nncf.torch.dynamic_graph.patch_pytorch import disable_patching from tests.torch.ptq.test_weights_compression import ALL_SENSITIVITY_METRICS +from tests.torch.ptq.test_weights_compression import INT4_MODES +from tests.torch.ptq.test_weights_compression import INT8_MODES from tests.torch.ptq.test_weights_compression import SUPPORTED_MODES from tests.torch.ptq.test_weights_compression import UNSUPPORTED_MODES from tests.torch.ptq.test_weights_compression import ConvolutionModel @@ -77,10 +79,13 @@ def _capture_model(model, inputs): @pytest.mark.parametrize("mode", SUPPORTED_MODES) def test_compress_weights(mode): - model = ShortTransformer(5, 10) - input_ids = torch.randint(0, 10, (5,)) + model = ShortTransformer(8, 16) + input_ids = torch.randint(0, 10, (8,)) exported_model = _capture_model(model, input_ids) - compressed_model = compress_weights(exported_model, mode=mode) + kwargs = {} + if mode in [CompressWeightsMode.INT4_SYM, CompressWeightsMode.INT4_ASYM]: + kwargs["group_size"] = 4 + compressed_model = compress_weights(exported_model, mode=mode, **kwargs) dtype = torch.int8 if mode == CompressWeightsMode.INT8_SYM else torch.uint8 n_compressed_weights = 0 n_target_modules = 0 @@ -92,7 +97,7 @@ def test_compress_weights(mode): assert n_target_modules == n_compressed_weights -@pytest.mark.parametrize("mode", SUPPORTED_MODES) +@pytest.mark.parametrize("mode", INT8_MODES) def test_compress_weights_graph_edge(mode): model = ShortTransformer(5, 10) input_ids = torch.randint(0, 10, (5,)) @@ -109,10 +114,13 @@ def test_compress_weights_graph_edge(mode): @pytest.mark.parametrize("mode", SUPPORTED_MODES) def test_compress_weights_shared_weights(mocker, mode): with disable_patching(): - model = ShortTransformer(5, 10, share_weights=True) - input_ids = torch.randint(0, 10, (5,)) + model = ShortTransformer(8, 16, share_weights=True) + input_ids = torch.randint(0, 10, (8,)) exported_model = _capture_model(model, input_ids) - compressed_model = compress_weights(exported_model, mode=mode) + kwargs = {} + if mode in [CompressWeightsMode.INT4_SYM, CompressWeightsMode.INT4_ASYM]: + kwargs["group_size"] = 4 + compressed_model = compress_weights(exported_model, mode=mode, **kwargs) dtype = torch.int8 if mode == CompressWeightsMode.INT8_SYM else torch.uint8 n_compressed_weights = 0 n_target_modules = 0 @@ -142,11 +150,14 @@ def test_compress_weights_shared_weights(mocker, mode): @pytest.mark.parametrize("mode", SUPPORTED_MODES) def test_compressed_model_inference(mode): torch.manual_seed(42) - model = ShortTransformer(5, 10, share_weights=True) - input_ids = torch.randint(0, 10, (5,)) + model = ShortTransformer(8, 16, share_weights=True) + input_ids = torch.randint(0, 10, (8,)) exported_model = _capture_model(model, input_ids) exported_model_output = exported_model(input_ids) - compressed_model = compress_weights(exported_model, mode=mode) + kwargs = {} + if mode in [CompressWeightsMode.INT4_SYM, CompressWeightsMode.INT4_ASYM]: + kwargs["group_size"] = 4 + compressed_model = compress_weights(exported_model, mode=mode, 
**kwargs) compressed_model_outputs = compressed_model(input_ids) assert ( @@ -161,7 +172,7 @@ def test_compress_weights_model_size_conv(mode): dtype = torch.int8 if mode == CompressWeightsMode.INT8_SYM else torch.uint8 model = ConvolutionModel() - input_ids = torch.randint(0, 10, [1, 3, 300, 300]) + input_ids = torch.randint(0, 10, [1, 3, 256, 256]) exported_model = _capture_model(model, input_ids) model_size = get_model_size(exported_model) compressed_model = compress_weights(exported_model, mode=mode) @@ -182,9 +193,11 @@ def test_compress_weights_model_size_conv(mode): @pytest.mark.parametrize("mode", SUPPORTED_MODES) def test_compress_weights_functional_model(mode): model = FunctionalModel() - decompressor_type = "symmetric" if mode == CompressWeightsMode.INT8_SYM else "asymmetric" + decompressor_type = ( + "symmetric" if mode in (CompressWeightsMode.INT8_SYM, CompressWeightsMode.INT4_SYM) else "asymmetric" + ) - input_ids = torch.randint(0, 10, [1, 3, 300, 300]) + input_ids = torch.randint(0, 10, [1, 3, 256, 256]) exported_model = _capture_model(model, input_ids) compressed_model = compress_weights(exported_model, mode=mode) @@ -196,7 +209,7 @@ def test_compress_weights_functional_model(mode): assert n_compressed_weights == 4 -@pytest.mark.parametrize("mode", SUPPORTED_MODES) +@pytest.mark.parametrize("mode", INT8_MODES) @pytest.mark.parametrize( "params", ( @@ -223,6 +236,27 @@ def test_raise_error_with_unsupported_params_for_int8(mode, params): compress_weights(exported_model, mode=mode, **params) +@pytest.mark.parametrize("mode", INT4_MODES) +@pytest.mark.parametrize( + "params", + ( + {"ratio": 0.5}, + *({"sensitivity_metric": metric} for metric in ALL_SENSITIVITY_METRICS), + {"gptq": True}, + {"awq": True}, + {"scale_estimation": True}, + {"lora_correction": True}, + {"dataset": Dataset([1])}, + ), +) +def test_raise_error_with_unsupported_params_for_int4(mode, params): + dummy_torch_model = EmptyModel() + dummy_input = torch.Tensor() + exported_model = _capture_model(dummy_torch_model, dummy_input) + with pytest.raises(nncf.ParameterNotSupportedError): + compress_weights(exported_model, mode=mode, **params) + + @pytest.mark.parametrize("mode", UNSUPPORTED_MODES) def test_raise_error_with_not_int8(mode): dummy_torch_model = EmptyModel() @@ -252,7 +286,7 @@ def test_model_devices_and_precisions(use_cuda, dtype): model = MatMulModel().to(device) if dtype == torch.float16: model.half() - dummy_input = torch.rand((1, 300), dtype=dtype, device=device) + dummy_input = torch.rand((1, 256), dtype=dtype, device=device) exported_model = _capture_model(model, dummy_input) compressed_model = compress_weights(exported_model) result = compressed_model(dummy_input) diff --git a/tests/torch/ptq/test_weights_compression.py b/tests/torch/ptq/test_weights_compression.py index f982a8375d1..88373ac308f 100644 --- a/tests/torch/ptq/test_weights_compression.py +++ b/tests/torch/ptq/test_weights_compression.py @@ -19,8 +19,14 @@ from nncf import SensitivityMetric from nncf.quantization import compress_weights from nncf.torch import wrap_model -from nncf.torch.quantization.layers import AsymmetricWeightsDecompressor -from nncf.torch.quantization.layers import SymmetricWeightsDecompressor +from nncf.torch.quantization.layers import INT4AsymmetricWeightsDecompressor +from nncf.torch.quantization.layers import INT4SymmetricWeightsDecompressor +from nncf.torch.quantization.layers import INT8AsymmetricWeightsDecompressor +from nncf.torch.quantization.layers import INT8SymmetricWeightsDecompressor +from 
nncf.torch.quantization.quantize_functions import pack_int4 +from nncf.torch.quantization.quantize_functions import pack_uint4 +from nncf.torch.quantization.quantize_functions import unpack_int4 +from nncf.torch.quantization.quantize_functions import unpack_uint4 DATA_BASED_SENSITIVITY_METRICS = ( SensitivityMetric.HESSIAN_INPUT_ACTIVATION, @@ -31,12 +37,10 @@ ALL_SENSITIVITY_METRICS = DATA_BASED_SENSITIVITY_METRICS + (SensitivityMetric.WEIGHT_QUANTIZATION_ERROR,) -SUPPORTED_MODES = (CompressWeightsMode.INT8, CompressWeightsMode.INT8_ASYM, CompressWeightsMode.INT8_SYM) -UNSUPPORTED_MODES = ( - CompressWeightsMode.INT4_SYM, - CompressWeightsMode.INT4_ASYM, - CompressWeightsMode.NF4, -) +INT8_MODES = (CompressWeightsMode.INT8_ASYM, CompressWeightsMode.INT8_SYM) +INT4_MODES = (CompressWeightsMode.INT4_SYM, CompressWeightsMode.INT4_ASYM) +SUPPORTED_MODES = INT8_MODES + INT4_MODES +UNSUPPORTED_MODES = (CompressWeightsMode.NF4, CompressWeightsMode.E2M1) class ShortTransformer(torch.nn.Module): @@ -59,7 +63,7 @@ def forward(self, input_ids): class MatMulModel(torch.nn.Module): def __init__(self): super().__init__() - self.w = torch.nn.Parameter(torch.ones(size=(300, 300), dtype=torch.float32)) + self.w = torch.nn.Parameter(torch.ones(size=(256, 256), dtype=torch.float32)) def forward(self, input): return input @ self.w @@ -69,7 +73,7 @@ class FunctionalModel(torch.nn.Module): def __init__(self): super().__init__() self.conv_w = torch.nn.Parameter(torch.ones(size=(5, 3, 3, 3), dtype=torch.float32)) - self.matmul_w = torch.nn.Parameter(torch.ones(size=(1, 3, 300, 300), dtype=torch.float32)) + self.matmul_w = torch.nn.Parameter(torch.ones(size=(1, 3, 256, 256), dtype=torch.float32)) self.conv_tr_w = torch.nn.Parameter(torch.rand(size=(5, 4, 3, 3))) self.nested_matmul = MatMulModel() @@ -109,14 +113,18 @@ def forward(self, input_): return x -@pytest.mark.parametrize("mode", (CompressWeightsMode.INT8_SYM, CompressWeightsMode.INT8_ASYM)) +@pytest.mark.parametrize("mode", SUPPORTED_MODES) def test_compress_weights(mode): - model = ShortTransformer(5, 10) + model = ShortTransformer(8, 16) dtype = torch.int8 if mode == CompressWeightsMode.INT8_SYM else torch.uint8 - input_ids = torch.randint(0, 10, (5,)) + input_ids = torch.randint(0, 10, (8,)) wrapped_model = wrap_model(model, example_input=input_ids, trace_parameters=True) - compressed_model = compress_weights(wrapped_model, mode=mode) + + kwargs = {} + if mode in [CompressWeightsMode.INT4_SYM, CompressWeightsMode.INT4_ASYM]: + kwargs["group_size"] = 4 + compressed_model = compress_weights(wrapped_model, mode=mode, **kwargs) n_compressed_weights = 0 n_target_modules = 0 @@ -130,14 +138,19 @@ def test_compress_weights(mode): assert n_compressed_weights == n_target_modules -@pytest.mark.parametrize("mode", (CompressWeightsMode.INT8_SYM, CompressWeightsMode.INT8_ASYM)) +@pytest.mark.parametrize("mode", SUPPORTED_MODES) def test_compress_weights_functional_model(mode): model = FunctionalModel() - decompressor_type = ( - SymmetricWeightsDecompressor if mode == CompressWeightsMode.INT8_SYM else AsymmetricWeightsDecompressor - ) + decompressor_map = { + CompressWeightsMode.INT8_SYM: (INT8SymmetricWeightsDecompressor,), + CompressWeightsMode.INT8_ASYM: (INT8AsymmetricWeightsDecompressor,), + CompressWeightsMode.INT4_SYM: (INT4SymmetricWeightsDecompressor, INT8AsymmetricWeightsDecompressor), + CompressWeightsMode.INT4_ASYM: (INT4AsymmetricWeightsDecompressor, INT8AsymmetricWeightsDecompressor), + } - input_ids = torch.randint(0, 10, [1, 3, 300, 300]) + 
decompressor_type = decompressor_map[mode] + + input_ids = torch.randint(0, 10, [1, 3, 256, 256]) wrapped_model = wrap_model(model, example_input=input_ids, trace_parameters=True) compressed_model = compress_weights(wrapped_model, mode=mode) @@ -167,14 +180,18 @@ def test_compress_weights_conv(): assert n_compressed_weights == n_target_modules -@pytest.mark.parametrize("mode", (CompressWeightsMode.INT8_SYM, CompressWeightsMode.INT8_ASYM)) +@pytest.mark.parametrize("mode", SUPPORTED_MODES) def test_compress_shared_weights(mocker, mode): - model = ShortTransformer(5, 10, share_weights=True) + model = ShortTransformer(8, 16, share_weights=True) dtype = torch.int8 if mode == CompressWeightsMode.INT8_SYM else torch.uint8 - input_ids = torch.randint(0, 10, (5,)) + input_ids = torch.randint(0, 10, (8,)) wrapped_model = wrap_model(model, example_input=input_ids, trace_parameters=True) - compressed_model = compress_weights(wrapped_model, mode=mode) + + kwargs = {} + if mode in [CompressWeightsMode.INT4_SYM, CompressWeightsMode.INT4_ASYM]: + kwargs["group_size"] = 4 + compressed_model = compress_weights(wrapped_model, mode=mode, **kwargs) n_compressed_weights = 0 n_target_modules = 0 @@ -203,7 +220,7 @@ def forward(self, input): return input -@pytest.mark.parametrize("mode", SUPPORTED_MODES) +@pytest.mark.parametrize("mode", INT8_MODES) @pytest.mark.parametrize( "params", ( @@ -229,6 +246,26 @@ def test_raise_error_with_unsupported_params_for_int8(mode, params): compress_weights(wrapped_model, mode=mode, **params) +@pytest.mark.parametrize("mode", INT4_MODES) +@pytest.mark.parametrize( + "params", + ( + {"ratio": 0.5}, + *({"sensitivity_metric": metric} for metric in ALL_SENSITIVITY_METRICS), + {"gptq": True}, + {"awq": True}, + {"scale_estimation": True}, + {"lora_correction": True}, + ), +) +def test_raise_error_with_unsupported_params_for_int4(mode, params): + dummy_torch_model = EmptyModel() + dummy_input = torch.Tensor() + wrapped_model = wrap_model(dummy_torch_model, example_input=dummy_input, trace_parameters=True) + with pytest.raises(nncf.ParameterNotSupportedError): + compress_weights(wrapped_model, mode=mode, **params) + + @pytest.mark.parametrize("mode", UNSUPPORTED_MODES) def test_raise_error_with_not_int8(mode): dummy_torch_model = EmptyModel() @@ -270,7 +307,7 @@ def test_model_devices_and_precisions(use_cuda, dtype): if dtype == torch.float16: model.half() - dummy_input = torch.rand((1, 300), dtype=dtype, device=device) + dummy_input = torch.rand((1, 256), dtype=dtype, device=device) wrapped_model = wrap_model(model, example_input=dummy_input, trace_parameters=True) compressed_model = compress_weights(wrapped_model) result = compressed_model(dummy_input) @@ -279,3 +316,21 @@ def test_model_devices_and_precisions(use_cuda, dtype): assert compressed_model.state_dict()["_nncf.external_op.weights_decompressor_w._scale"].dtype == torch.float16 # Result should be in the precision of the model assert result.dtype == dtype + + +def test_pack_uint4(): + w_uint8 = torch.randint(0, 15, (4, 4), dtype=torch.uint8) + packed_w = pack_uint4(w_uint8) + assert packed_w.dtype == torch.uint8 + assert packed_w.numel() * 2 == w_uint8.numel() + unpacked_w = unpack_uint4(packed_w).reshape(w_uint8.shape) + assert torch.all(unpacked_w == w_uint8) + + +def test_pack_int4(): + w_int8 = torch.randint(-8, 7, (4, 4), dtype=torch.int8) + packed_w = pack_int4(w_int8) + assert packed_w.dtype == torch.uint8 + assert packed_w.numel() * 2 == w_int8.numel() + unpacked_w = unpack_int4(packed_w).reshape(w_int8.shape) + 
assert torch.all(unpacked_w == w_int8)
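
---

Note on the new packing helpers exercised above: `test_pack_uint4` / `test_pack_int4` check that two 4-bit values are stored per `torch.uint8` byte and that unpacking round-trips exactly. The snippet below is a minimal, self-contained sketch of *one possible* nibble-packing scheme consistent with those assertions; the helper names `pack_uint4_sketch` / `unpack_uint4_sketch` are illustrative only and are **not** the implementations added in `nncf/torch/quantization/quantize_functions.py` (it also assumes an even number of elements).

```python
import torch


def pack_uint4_sketch(w: torch.Tensor) -> torch.Tensor:
    # Pair neighbouring 4-bit values into one byte:
    # even index -> low nibble, odd index -> high nibble.
    flat = w.contiguous().view(-1)
    return ((flat[::2] & 0x0F) | ((flat[1::2] & 0x0F) << 4)).to(torch.uint8)


def unpack_uint4_sketch(packed: torch.Tensor) -> torch.Tensor:
    # Recover low and high nibbles and interleave them back in the original order.
    low = packed & 0x0F
    high = (packed >> 4) & 0x0F
    return torch.stack((low, high), dim=-1).view(-1)


w_uint8 = torch.randint(0, 16, (4, 4), dtype=torch.uint8)
packed = pack_uint4_sketch(w_uint8)
assert packed.dtype == torch.uint8 and packed.numel() * 2 == w_uint8.numel()
assert torch.all(unpack_uint4_sketch(packed).reshape(w_uint8.shape) == w_uint8)
```

For the signed case, a common approach is to add an offset of 8 before packing and subtract it after unpacking, which would be consistent with `test_pack_int4` expecting a `torch.uint8` packed tensor; the exact scheme used by `pack_int4` is not shown in this hunk. As the updated tests also reflect, the INT4 modes are invoked the same way as INT8 (`compress_weights(wrapped_model, mode=CompressWeightsMode.INT4_SYM, group_size=4)`), while `ratio`, `gptq`, `awq`, `scale_estimation`, `lora_correction`, sensitivity metrics, and datasets still raise `nncf.ParameterNotSupportedError` for the Torch and Torch.FX backends.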