Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderDokuchaev committed Aug 18, 2023
1 parent cf47cbd commit 9cc7582
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 7 deletions.
4 changes: 2 additions & 2 deletions nncf/experimental/tensor/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,10 +432,10 @@ def mean_per_channel(x: Tensor, axis: int) -> Tensor:
:return: Reduced Tensor.
"""
if len(x.shape) < 3:
return mean(x.data, axis=0)
return Tensor(mean(x.data, axis=0))
x = moveaxis(x.data, axis, 1)
t = x.reshape([x.shape[0], x.shape[1], -1])
return mean(t, axis=(0, 2))
return Tensor(mean(t, axis=(0, 2)))


__all__ = [
Expand Down
2 changes: 1 addition & 1 deletion nncf/experimental/tensor/numpy_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def _(a: Union[np.ndarray, np.number], axis: Optional[Union[int, Tuple[int]]] =

@registry_numpy_types(functions.count_nonzero)
def _(a: Union[np.ndarray, np.number], axis: Optional[Union[int, Tuple[int]]] = None) -> np.ndarray:
return np.count_nonzero(a, axis=axis)
return np.array(np.count_nonzero(a, axis=axis))


@registry_numpy_types(functions.isempty)
Expand Down
13 changes: 9 additions & 4 deletions tests/shared/test_templates/template_test_nncf_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,11 +343,11 @@ def test_iter(self):
"axis, ref",
(
(None, 3),
(0, [2, 1]),
(0, [2.0, 1.0]),
),
)
def test_fn_count_nonzero(self, axis, ref):
tensor = self.to_tensor([[1, 2], [1, 0]])
tensor = self.to_tensor([[1.0, 2.0], [1.0, 0.0]])
nncf_tensor = Tensor(tensor)
ref_tensor = self.to_tensor(ref)
res = functions.count_nonzero(nncf_tensor, axis=axis)
Expand Down Expand Up @@ -661,7 +661,11 @@ def test_fn_round(self, val, decimals, ref):
@pytest.mark.parametrize(
"val, axis, ref",
(
(1.1, 0, 1.1),
(
[[9.0, 9.0], [7.0, 1.0]],
0,
[8.0, 5.0],
),
(
[[[9.0, 9.0], [0.0, 3.0]], [[5.0, 1.0], [7.0, 1.0]]],
0,
Expand Down Expand Up @@ -702,4 +706,5 @@ def test_fn_mean_per_channel(self, val, axis, ref):
tensor = Tensor(self.to_tensor(val))
ref_tensor = self.to_tensor(ref)
res = functions.mean_per_channel(tensor, axis)
assert functions.allclose(res.data, ref_tensor), f"{res.data}"
assert isinstance(res, Tensor)
assert functions.allclose(res, ref_tensor), f"{res.data}"

0 comments on commit 9cc7582

Please sign in to comment.