Pr/2197 #17

Open · wants to merge 101 commits into base: develop

Commits (101, all by kshpv)
1ce23c5 (Oct 13, 2023) draft
8184df6 (Oct 13, 2023) check on Nones
2e3f507 (Oct 18, 2023) update aggregator with keep_dims=True
b5d15cd (Oct 18, 2023) typhints
8cb2391 (Oct 18, 2023) Merge remote-tracking branch 'remote/develop' into torch_batch_size
1034acd (Oct 19, 2023) fix OV tests; update collectors
8b526c5 (Oct 20, 2023) fix tests
e51bdb8 (Nov 6, 2023) Merge remote-tracking branch 'remote/develop' into torch_batch_size
37684bd (Nov 7, 2023) add aggregation axes for OV; comment input check
18d931d (Nov 8, 2023) add test for OV and Torch
605a325 (Nov 9, 2023) add batch_size param to conformance test
fb16b99 (Nov 9, 2023) hardcode for CI run
cd60fa3 (Nov 10, 2023) hardcode batch size = 10 for calibrate.py
f3bda28 (Dec 18, 2023) Merge remote-tracking branch 'remote/develop' into torch_batch_size
cc621ab (Dec 18, 2023) merge
d2a9b00 (Dec 20, 2023) update aggregator
5ffdf10 (Dec 20, 2023) Merge remote-tracking branch 'remote/develop' into HEAD
d95be5d (Dec 20, 2023) revert unneseccary changes
cd68684 (Dec 20, 2023) add logging; add torch data for OVEngine
4a009f3 (Dec 21, 2023) refactor method get axes
c2659b3 (Dec 21, 2023) fix OV tests
3a13f00 (Jan 4, 2024) fix Torch tests
2347170 (Jan 15, 2024) Merge remote-tracking branch 'remote/develop' into HEAD
880073b (Jan 15, 2024) logic of warning message inside StatisticsAggregator
e9062a5 (Jan 15, 2024) remove _check_input_data_format in OVEngine
8770ca4 (Jan 15, 2024) get_channel_agnostic_reduction_axes to common
9556c49 (Jan 15, 2024) use get_channel_agnostic_reduction_axes for Torch
cb90e77 (Jan 15, 2024) use get_channel_agnostic_reduction_axes for ONNX
d9167e5 (Jan 17, 2024) Merge remote-tracking branch 'remote/develop' into HEAD
cd10c57 (Jan 17, 2024) draft
16cc9db (Jan 17, 2024) Merge remote-tracking branch 'remote/develop' into HEAD
11b538a (Jan 17, 2024) Merge remote-tracking branch 'remote/develop' into HEAD
426ec04 (Jan 18, 2024) fix test
21b0963 (Jan 18, 2024) align reduction shape and aggregation shape
e90ca32 (Jan 18, 2024) get_channel_agnostic_reduction_axes -> get_reduction_axes
f078a78 (Jan 18, 2024) upd get_reduction_aggregation_axes
e4c57cd (Jan 18, 2024) upd aggregator
d226074 (Jan 18, 2024) fix OV test
7d8ecd4 (Jan 18, 2024) Merge remote-tracking branch 'remote/develop' into HEAD
f502de5 (Jan 18, 2024) fix ONNX test
83d03cb (Jan 18, 2024) tests
fbfe587 (Jan 18, 2024) fix torch tests
0ae6ac4 (Jan 18, 2024) fix tests
496339f (Jan 18, 2024) common tests
bcce584 (Jan 18, 2024) add docs
e5950e0 (Jan 18, 2024) comment
1d9ac7a (Jan 19, 2024) Merge remote-tracking branch 'remote/develop' into HEAD
41f27b5 (Jan 19, 2024) rollback changes for torch possible impact qat
51f3dd9 (Jan 19, 2024) upd conformance
3a8de2f (Jan 19, 2024) upd calibrate.py
946523d (Jan 19, 2024) add get_reduction_aggregation_axes for PTRangeInitCollectorParams
1732d70 (Jan 19, 2024) non returning None for get_reduction_aggregation_axes
1e96318 (Jan 19, 2024) comments
03afe91 (Jan 19, 2024) comments
bf792fb (Jan 19, 2024) describe comment
f98aea2 (Jan 19, 2024) description x2
fbd05f9 (Jan 19, 2024) description x3
e80bab1 (Jan 23, 2024) apply suggestion
9c1648d (Jan 24, 2024) comments
df8ad03 (Jan 25, 2024) add default scenario when batch_size=1 or None
f4db2bb (Jan 25, 2024) Merge remote-tracking branch 'remote/develop' into HEAD
f4dfd1c (Jan 26, 2024) rollback scales changes
4a44a1c (Jan 26, 2024) fix tests
d4bfaca (Jan 26, 2024) Merge remote-tracking branch 'remote/develop' into HEAD
f77f59b (Jan 26, 2024) fix OV test
43fd729 (Jan 29, 2024) add warning for model_type=transformer
c20f7d3 (Jan 29, 2024) fix torch test
52203f0 (Jan 29, 2024) fix torch tests
9dd02b9 (Jan 29, 2024) Merge remote-tracking branch 'remote/develop' into HEAD
48c8426 (Jan 30, 2024) final fix torch test
3fe8a37 (Jan 30, 2024) comments
d228589 (Jan 30, 2024) comments x2
b7de564 (Jan 30, 2024) comments x3
67e4c7d (Jan 30, 2024) Merge remote-tracking branch 'remote/develop' into HEAD
489d603 (Jan 30, 2024) fix tests after merge
3b9fb6f (Jan 30, 2024) Merge remote-tracking branch 'remote/develop' into HEAD
120ee1a (Jan 31, 2024) Merge remote-tracking branch 'remote/develop' into HEAD
1f0cb94 (Jan 31, 2024) improve test
532e8eb (Feb 6, 2024) fix test
38d71b8 (Feb 6, 2024) upd fbs method calculations
c490362 (Feb 7, 2024) revert changes with statistics collection
b778c0c (Feb 13, 2024) updates aggregators, reducers for BC and FBC
1a96012 (Feb 14, 2024) upd torch mean_per_channel
2f89913 (Feb 14, 2024) fix BC
f69acbd (Feb 14, 2024) Merge remote-tracking branch 'remote/develop' into HEAD
74594c7 (Feb 14, 2024) fixes after merge
d760caf (Feb 15, 2024) Fix BC calculations
50ac6b4 (Feb 20, 2024) revert FBC and BC changes
00f7979 (Feb 20, 2024) Merge remote-tracking branch 'remote/develop' into HEAD
532ca55 (Feb 20, 2024) fix merge
54f8ca3 (Feb 20, 2024) fix revert typo
7637d4b (Feb 21, 2024) fix export of torch model
976255f (Feb 23, 2024) comments
8951e3c (Feb 26, 2024) more comments
0d72557 (Feb 26, 2024) make bs=128 for Torch sample
0f8a438 (Feb 27, 2024) fix channel alighnment + comments
78d4d6c (Feb 28, 2024) comments
34c9960 (Feb 28, 2024) update typehints; revert changes in OV sample and apply to Torch
354505a (Feb 28, 2024) typo
97cb07f (Feb 28, 2024) some code improvements
2cc8b81 (Feb 28, 2024) logging
Files changed
@@ -107,7 +107,7 @@ def get_model_size(ir_path: str, m_type: str = "Mb", verbose: bool = True) -> float:
]
),
)
-    val_data_loader = torch.utils.data.DataLoader(val_dataset)
+    val_data_loader = torch.utils.data.DataLoader(val_dataset, batch_size=128)

torch_model = models.mobilenet_v2(num_classes=DATASET_CLASSES)
torch_model = load_checkpoint(torch_model)
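The effect of this one-line change on calibration can be sketched as follows. This is a hedged stand-in, not the sample itself: FakeData, the sizes, and the 300-sample subset are illustrative assumptions.

```python
import torch
from torchvision import datasets, transforms

# Hypothetical stand-in for the sample's val_dataset.
val_dataset = datasets.FakeData(size=512, transform=transforms.ToTensor())

# With batch_size=128, each engine.infer() call during statistics
# collection consumes 128 samples, so e.g. 300 calibration samples take
# 300 // 128 = 2 full iterations instead of 300 single-sample ones.
val_data_loader = torch.utils.data.DataLoader(val_dataset, batch_size=128)

images, _ = next(iter(val_data_loader))
assert images.shape == (128, 3, 224, 224)
```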
16 changes: 15 additions & 1 deletion nncf/common/graph/utils.py
@@ -10,7 +10,7 @@
# limitations under the License.

from functools import partial
-from typing import List, Set
+from typing import List, Set, Tuple, Union

from nncf.common.graph import NNCFGraph
from nncf.common.graph import NNCFNode
@@ -114,3 +114,17 @@ def get_number_of_quantized_ops(
        else:
            nodes_to_see.extend(graph.get_next_nodes(node))
    return len(quantized_ops)


def get_reduction_axes(channel_axes: Union[List[int], Tuple[int]], shape: Union[List[int], Tuple[int]]) -> Tuple[int]:
    """
    Returns the reduction axes, excluding the axes that correspond to channels.

    :param channel_axes: Channel axes.
    :param shape: Shape to be filtered.
    :return: Reduction axes.
    """
    reduction_axes = list(range(len(shape)))
    # Delete in descending order so that earlier deletions do not shift later indices.
    for channel_axis in sorted(channel_axes, reverse=True):
        del reduction_axes[channel_axis]
    return tuple(reduction_axes)
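A quick illustration of the helper above (standalone copy; the NCHW shape is an assumed example):

```python
from typing import List, Tuple, Union


def get_reduction_axes(channel_axes: Union[List[int], Tuple[int]], shape: Union[List[int], Tuple[int]]) -> Tuple[int]:
    # Start from all axes of the shape, then drop the channel axes.
    reduction_axes = list(range(len(shape)))
    for channel_axis in sorted(channel_axes, reverse=True):
        del reduction_axes[channel_axis]
    return tuple(reduction_axes)


# For an NCHW activation with channels on axis 1, statistics are
# reduced over the batch and spatial axes only.
assert get_reduction_axes(channel_axes=[1], shape=(8, 16, 32, 32)) == (0, 2, 3)
```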
51 changes: 50 additions & 1 deletion nncf/common/quantization/initialization/range.py
@@ -9,12 +9,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Tuple, Union

from nncf.common.graph.utils import get_reduction_axes
from nncf.common.initialization.dataloader import NNCFDataLoader
from nncf.common.quantization.structs import QuantizationScheme
from nncf.common.quantization.structs import QuantizerGroup
from nncf.common.tensor_statistics.collectors import ReductionAxes
from nncf.config.schemata.defaults import NUM_INIT_SAMPLES
from nncf.experimental.common.tensor_statistics.collectors import AggregationAxes


class RangeInitConfig:
@@ -204,3 +207,49 @@ def use_means_of_mins(self) -> bool:
    @property
    def use_means_of_maxs(self) -> bool:
        return not self._is_weights and not self._is_per_channel

    def _get_reduction_axes(
        self,
        shape_to_reduce: Union[Tuple[int], List[int]],
        quantization_axes: Union[Tuple[int], List[int]],
        aggregation_axes: Union[Tuple[int], List[int]],
    ):
        """
        Returns axes for a reducer, taking the aggregation axes into account. The aggregator counts axes
        from the stacked tensor, so only the tensor-related axes among them should be used for the reducer.

        :param shape_to_reduce: Shape of a reduced tensor.
        :param quantization_axes: Axes of quantization.
        :param aggregation_axes: Axes of the aggregator which is applied onto the reduced tensor.
        :return: Axes for the reducer.
        """
        axes_to_keep = set(el - 1 for el in aggregation_axes if el != 0)
        axes_to_keep.update(quantization_axes)
        return get_reduction_axes(axes_to_keep, shape_to_reduce)

    def _get_aggregation_axes(self, is_per_sample: bool) -> Tuple[int]:
        """
        Returns axes for the aggregator.

        :param is_per_sample: Whether to aggregate tensor statistics per batch axis.
        :return: Aggregation axes.
        """
        return (0, 1) if is_per_sample else (0,)

    def get_reduction_aggregation_axes(
        self,
        shape_to_reduce: Union[Tuple[int], List[int]],
        quantization_axes: Union[Tuple[int], List[int]],
        is_per_sample: bool,
    ) -> Tuple[ReductionAxes, AggregationAxes]:
        """
        Calculates the reduction axes and aggregation axes for the tensor.

        :param shape_to_reduce: Shape of the tensor.
        :param quantization_axes: Quantization axes if per-channel quantization.
        :param is_per_sample: Whether to calculate statistics per sample (aggregate over the batch axis).
        :return: Reduction axes and aggregation axes.
        """
        aggregation_axes = self._get_aggregation_axes(is_per_sample)
        reduction_axes = self._get_reduction_axes(shape_to_reduce, quantization_axes, aggregation_axes)
        return reduction_axes, aggregation_axes
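The axis bookkeeping above can be traced with a standalone sketch; the shape and the per-tensor setup are assumptions for illustration, and the helper names are local to this snippet:

```python
from typing import Tuple


def get_aggregation_axes(is_per_sample: bool) -> Tuple[int, ...]:
    # Axis 0 is the stacking axis of collected statistics; axis 1 is the batch axis.
    return (0, 1) if is_per_sample else (0,)


def get_reducer_axes(shape, quantization_axes, aggregation_axes):
    # Aggregation axes count from the stacked tensor, so stacked axis k
    # corresponds to axis k - 1 of a single collected tensor.
    axes_to_keep = {el - 1 for el in aggregation_axes if el != 0}
    axes_to_keep.update(quantization_axes)
    return tuple(ax for ax in range(len(shape)) if ax not in axes_to_keep)


shape = (8, 16, 32, 32)  # NCHW activation, per-tensor quantization
assert get_reducer_axes(shape, (), get_aggregation_axes(True)) == (1, 2, 3)
assert get_reducer_axes(shape, (), get_aggregation_axes(False)) == (0, 1, 2, 3)
```

With per-sample statistics the reducer keeps the batch axis, and the aggregator later folds both the stacking axis and the batch axis; otherwise the reducer folds everything and the aggregator only folds the stacking axis.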
110 changes: 90 additions & 20 deletions nncf/common/tensor_statistics/aggregator.py
@@ -11,12 +11,14 @@
from abc import ABC
from abc import abstractmethod
from itertools import islice
-from typing import Any, Dict, TypeVar
+from typing import Any, Dict, List, Optional, TypeVar

import nncf
from nncf.common import factory
from nncf.common.graph.graph import NNCFGraph
from nncf.common.graph.operator_metatypes import OperatorMetatype
from nncf.common.graph.transformations.layout import TransformationLayout
from nncf.common.logging.logger import nncf_logger
from nncf.common.logging.track_progress import track
from nncf.common.tensor import NNCFTensor
from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer
@@ -25,6 +27,23 @@
TensorType = TypeVar("TensorType")
TModel = TypeVar("TModel")

EMPTY_DATASET_MESSAGE = (
    "Calibration dataset must not be empty. Please provide a calibration dataset with at least one sample."
)
BATCH_SIZE_IS_BIGGER_THAN_SUBSET_SIZE_MESSAGE = (
    "The batch size of the provided dataset is bigger than the subset size for statistics collection. "
    "Please increase the number of samples for statistics collection "
    "or decrease the batch size of the dataset."
)
BATCH_SIZE_MODEL_WARNING = (
    "For this particular model, a batch size > 1 can lead to inaccurate collected statistics. "
    "The recommendation is to provide a dataloader instance with batch_size = 1."
)
DECREASING_SAMPLES_NUMBER_MESSAGE = (
    "The number of samples for statistics collection is decreased "
    "to align with the batch size of the provided dataset."
)


class StatisticsAggregator(ABC):
"""
@@ -34,8 +53,38 @@ class StatisticsAggregator(ABC):
    def __init__(self, dataset: Dataset):
        self.dataset = dataset
        self.stat_subset_size = None
        self.batch_size = self.dataset.get_batch_size() or 1
        dataset_len = self.dataset.get_length()
        self.dataset_sample_size = (
            dataset_len * self.batch_size if dataset_len is not None else dataset_len
        )  # Number of samples in the dataset
        if self.dataset_sample_size == 0:
            raise nncf.ValidationError(EMPTY_DATASET_MESSAGE)
        self.statistic_points = StatisticPointsContainer()

    def _get_number_samples_for_statistics(self) -> Optional[int]:
        """
        Returns the number of samples for statistics collection.

        :return: Number of samples for statistics collection.
        """
        return (
            min(self.dataset_sample_size or self.stat_subset_size, self.stat_subset_size)
            if self.stat_subset_size is not None
            else None
        )

    def _get_iterations_num(self, total_statistics_samples: int) -> int:
        """
        Returns the number of iterations needed to collect statistics.

        :param total_statistics_samples: Number of statistics samples to use.
        :return: Number of statistics collection iterations.
        """
        return total_statistics_samples // self.batch_size

def collect_statistics(self, model: TModel, graph: NNCFGraph) -> None:
"""
Collects statistics for registered StatisticPoints.
@@ -46,34 +95,35 @@ def collect_statistics(self, model: TModel, graph: NNCFGraph) -> None:
"""
        if not self.statistic_points:
            return

        if self.batch_size > 1 and self.is_model_has_no_batch_axis(graph):
            nncf_logger.warning(BATCH_SIZE_MODEL_WARNING)
        model_transformer = factory.ModelTransformerFactory.create(model)

        merged_statistics = self._get_merged_statistic_points(self.statistic_points, model, graph)
        transformation_layout = self._get_transformation_layout_extra_outputs(merged_statistics)
        model_with_outputs = model_transformer.transform(transformation_layout)
        engine = factory.EngineFactory.create(model_with_outputs)

-        dataset_length = self.dataset.get_length()
-        total = (
-            min(dataset_length or self.stat_subset_size, self.stat_subset_size)
-            if self.stat_subset_size is not None
-            else None
-        )
+        statistics_samples_num = self._get_number_samples_for_statistics()
+        iterations_num = (
+            self._get_iterations_num(statistics_samples_num) if statistics_samples_num is not None else None
+        )
+        if iterations_num is not None:
+            if iterations_num == 0:
+                raise nncf.ValidationError(BATCH_SIZE_IS_BIGGER_THAN_SUBSET_SIZE_MESSAGE)
+            samples_num = iterations_num * self.batch_size
+            if samples_num != statistics_samples_num:
+                nncf_logger.warning(DECREASING_SAMPLES_NUMBER_MESSAGE)
+                statistics_samples_num = samples_num
         empty_statistics = True
-        for input_data in track(
-            islice(self.dataset.get_inference_data(), self.stat_subset_size),
-            total=total,
-            description="Statistics collection",
-        ):
-            outputs = engine.infer(input_data)
-            processed_outputs = self._process_outputs(outputs)
-            self._register_statistics(processed_outputs, merged_statistics)
-            empty_statistics = False
+        with track(total=statistics_samples_num, description="Statistics collection") as pbar:
+            for input_data in islice(self.dataset.get_inference_data(), iterations_num):
+                outputs = engine.infer(input_data)
+                processed_outputs = self._process_outputs(outputs)
+                self._register_statistics(processed_outputs, merged_statistics)
+                pbar.progress.update(pbar.task, advance=self.batch_size)
+                empty_statistics = False
         if empty_statistics:
-            raise nncf.ValidationError(
-                "Calibration dataset must not be empty. Please provide calibration dataset with at least one sample."
-            )
+            raise nncf.ValidationError(EMPTY_DATASET_MESSAGE)

def register_statistic_points(self, statistic_points: StatisticPointsContainer) -> None:
"""
@@ -95,6 +145,26 @@ def register_statistic_points(self, statistic_points: StatisticPointsContainer) -> None:
elif tensor_collector.num_samples is not None:
self.stat_subset_size = max(self.stat_subset_size, tensor_collector.num_samples)

    def is_model_has_no_batch_axis(self, graph: NNCFGraph) -> bool:
        """
        Returns True if the NNCFGraph contains metatypes with no batch axis in their output tensors.

        :param graph: NNCFGraph.
        :return: True if the NNCFGraph contains metatypes with no batch axis in their output tensors.
        """
        unique_graph_metatypes = set(node.metatype for node in graph.get_all_nodes())
        return any(metatype in self.metatypes_no_batch_support for metatype in unique_graph_metatypes)

    @property
    @abstractmethod
    def metatypes_no_batch_support(self) -> List[OperatorMetatype]:
        """
        These metatypes mix outputs for different samples into one axis.
        If reducers and aggregators collect statistics at the outputs of such operations,
        assuming that the 0-axis is the batch axis, they get only one value instead of batch_size values.
        This can lead to inaccurate or incorrect statistics.
        """

@abstractmethod
def _register_statistics(self, outputs: Dict[str, NNCFTensor], statistic_points: StatisticPointsContainer) -> None:
"""
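The sample/iteration arithmetic in this class is easiest to see with concrete numbers; a subset size of 300 and a batch size of 128 are assumed here purely for illustration:

```python
batch_size = 128        # as returned by Dataset.get_batch_size()
stat_subset_size = 300  # requested samples for statistics

iterations_num = stat_subset_size // batch_size  # 2 full batches
samples_num = iterations_num * batch_size        # 256 samples actually used

# 256 != 300, so DECREASING_SAMPLES_NUMBER_MESSAGE is logged and the
# progress bar is sized to 256. If iterations_num were 0 (batch bigger
# than the subset), BATCH_SIZE_IS_BIGGER_THAN_SUBSET_SIZE_MESSAGE is raised.
assert (iterations_num, samples_num) == (2, 256)
```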
11 changes: 11 additions & 0 deletions nncf/data/dataset.py
@@ -81,6 +81,17 @@ def get_length(self) -> Optional[int]:
return self._data_source.__len__()
return None

    def get_batch_size(self) -> Optional[int]:
        """
        Tries to fetch the batch size of the underlying dataset.

        :return: The value of the batch_size or _batch_size attribute of the data_source if it exists, and None otherwise.
        """
        if hasattr(self._data_source, "batch_size"):  # Torch dataloader
            return self._data_source.batch_size
        if hasattr(self._data_source, "_batch_size"):  # TF dataloader
            return self._data_source._batch_size
        return None


class DataProvider(Generic[DataItem, ModelInput]):
def __init__(
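A usage sketch for the new helper, assuming nncf.Dataset wraps a torch DataLoader as in the sample above; the toy tensors and the transform_func are assumptions for illustration:

```python
import torch
import nncf

tensors = torch.randn(32, 3, 224, 224)  # toy calibration inputs
dataset = torch.utils.data.TensorDataset(tensors)
loader = torch.utils.data.DataLoader(dataset, batch_size=8)

calibration_dataset = nncf.Dataset(loader, transform_func=lambda batch: batch[0])

# get_batch_size() reads the loader's batch_size attribute.
assert calibration_dataset.get_batch_size() == 8
```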
@@ -9,22 +9,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

from nncf.experimental.tensor import Tensor
from nncf.experimental.tensor import TensorDataType
from nncf.experimental.tensor.functions import numeric as fns


-def mean_per_channel(x: Tensor, axis: int) -> Tensor:
+def mean_per_channel(x: Tensor, axis: int, dtype: Optional[TensorDataType] = None) -> Tensor:
"""
Computes the mean of elements across given channel dimension of Tensor.

:param x: Tensor to reduce.
:param axis: The channel dimensions to reduce.
:param dtype: Type to use in computing the mean.
:return: Reduced Tensor.
"""
if len(x.shape) < 3:
return fns.mean(x, axis=0)
return fns.mean(x, axis=0, dtype=dtype)

pos_axis = axis + x.ndim if axis < 0 else axis
if pos_axis < 0 or pos_axis >= x.ndim:
raise ValueError(f"axis {axis} is out of bounds for array of dimension {x.ndim}")
axis = tuple(i for i in range(x.ndim) if i != pos_axis)
return fns.mean(x, axis=axis)
return fns.mean(x, axis=axis, dtype=dtype)
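The same per-channel semantics in plain NumPy, for reference; the (N, C, L) shape is assumed, and fns.mean is assumed to behave like np.mean here:

```python
import numpy as np

x = np.arange(2 * 3 * 4, dtype=np.float32).reshape(2, 3, 4)  # (N, C, L)

channel_axis = 1
reduce_axes = tuple(i for i in range(x.ndim) if i != channel_axis)

# One mean value per channel, reduced over the batch and length axes.
per_channel_mean = np.mean(x, axis=reduce_axes)
assert per_channel_mean.shape == (3,)
```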
7 changes: 5 additions & 2 deletions nncf/experimental/tensor/functions/numeric.py
@@ -355,16 +355,19 @@ def moveaxis(a: Tensor, source: Union[int, Tuple[int, ...]], destination: Union[

@functools.singledispatch
@tensor_guard
-def mean(a: Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Tensor:
+def mean(
+    a: Tensor, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, dtype: TensorDataType = None
+) -> Tensor:
     """
     Compute the arithmetic mean along the specified axis.

     :param a: Array containing numbers whose mean is desired.
     :param axis: Axis or axes along which the means are computed.
     :param keepdims: If True, the reduced axes are left in the result as dimensions with size one.
     :param dtype: Type to use in computing the mean.
     :return: Array with the mean values.
     """
-    return Tensor(mean(a.data, axis, keepdims))
+    return Tensor(mean(a.data, axis, keepdims, dtype))


@functools.singledispatch
10 changes: 8 additions & 2 deletions nncf/experimental/tensor/functions/numpy_numeric.py
@@ -170,8 +170,14 @@ def _(a: np.ndarray, source: Union[int, Tuple[int, ...]], destination: Union[int


@register_numpy_types(numeric.mean)
-def _(a: Union[np.ndarray, np.generic], axis: Union[int, Tuple[int, ...]] = None, keepdims: bool = False) -> np.ndarray:
-    return np.array(np.mean(a, axis=axis, keepdims=keepdims))
+def _(
+    a: Union[np.ndarray, np.generic],
+    axis: Union[int, Tuple[int, ...]] = None,
+    keepdims: bool = False,
+    dtype: Optional[TensorDataType] = None,
+) -> np.ndarray:
+    dtype = DTYPE_MAP[dtype] if dtype else None
+    return np.array(np.mean(a, axis=axis, keepdims=keepdims, dtype=dtype))


@register_numpy_types(numeric.round)
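A quick check of the NumPy behavior this dispatch forwards to: the dtype argument controls both the accumulator and the result type (values are illustrative):

```python
import numpy as np

x = np.ones((4, 8), dtype=np.float16)

# Without dtype the result stays in the input precision; passing
# dtype=np.float64 computes and returns the mean in float64.
assert np.mean(x, axis=1).dtype == np.float16
assert np.mean(x, axis=1, dtype=np.float64).dtype == np.float64
```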
10 changes: 8 additions & 2 deletions nncf/experimental/tensor/functions/torch_numeric.py
@@ -183,8 +183,14 @@ def _(a: torch.Tensor, source: Union[int, Tuple[int, ...]], destination: Union[i


@numeric.mean.register(torch.Tensor)
-def _(a: torch.Tensor, axis: Union[int, Tuple[int, ...]] = None, keepdims: bool = False) -> torch.Tensor:
-    return torch.mean(a, dim=axis, keepdim=keepdims)
+def _(
+    a: torch.Tensor,
+    axis: Union[int, Tuple[int, ...]] = None,
+    keepdims: bool = False,
+    dtype: Optional[TensorDataType] = None,
+) -> torch.Tensor:
+    dtype = DTYPE_MAP[dtype] if dtype else None
+    return torch.mean(a, dim=axis, keepdim=keepdims, dtype=dtype)


@numeric.round.register(torch.Tensor)
9 changes: 9 additions & 0 deletions nncf/onnx/graph/metatypes/groups.py
@@ -127,3 +127,12 @@
onnx_metatypes.ONNXQuantizeLinearMetatype,
onnx_metatypes.ONNXDequantizeLinearMetatype,
]

# These metatypes mix outputs for different samples into one axis.
# If reducers and aggregators collect statistics at the outputs of such operations,
# assuming that the 0-axis is the batch axis, they get only one value instead of batch_size values.
# This can lead to inaccurate or incorrect statistics.
OPERATIONS_OUTPUT_HAS_NO_BATCH_AXIS = [
    onnx_metatypes.ONNXROIAlignMetatype,
    onnx_metatypes.ONNXEmbeddingMetatype,
]
4 changes: 2 additions & 2 deletions nncf/onnx/graph/metatypes/onnx_metatypes.py
@@ -446,8 +446,8 @@ class ONNXScatterNDMetatype(ONNXOpMetatype):


@ONNX_OPERATION_METATYPES.register()
-class ONNXRoiAlignMetatype(ONNXOpMetatype):
-    name = "RoiAlignOp"
+class ONNXROIAlignMetatype(ONNXOpMetatype):
+    name = "ROIAlignOp"
     op_names = ["RoiAlign"]

