tqdm progress bar improvements (openvinotoolkit#2114)
### Changes

1. Fixed an issue where the `tqdm` progress bar reported the wrong total when the
calibration dataset is shorter than `subset_size` (see the sketch after this list).
Reproducer:
nikita-savelyevv@f0951c1
**Before:**
`Statistics collection: 34%|██████ | 101/300 [00:03<00:06, 28.66it/s]`
**After:**
When dataset has `__len__`:
`Statistics collection: 100%|██████████████████| 101/101 [00:03<00:00, 28.20it/s]`
When dataset doesn't have `__len__`:
`Statistics collection: 34%|██████ | 101/300 [00:03<00:06, 29.45it/s]`

2. Improved the progress bar rendering when run from notebooks by switching to `tqdm.auto`.
**Before:**
<img width="704" alt="Screenshot 2023-09-06 091857"
src="https://github.com/openvinotoolkit/nncf/assets/23343961/9851cb8d-00f1-4297-af50-14697e86e961">

or (in some browsers the progress bar takes up multiple lines):


![image](https://github.com/openvinotoolkit/nncf/assets/23343961/99fa9629-2869-4d8f-872e-97ef59bc092e)
**After:**
<img width="706" alt="Screenshot 2023-09-06 105453"
src="https://github.com/openvinotoolkit/nncf/assets/23343961/58e75cc9-2507-4c5b-8c3c-cac44eefcb79">

In the console, the progress bar looks the same as before.
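
For reference, a minimal sketch of how the progress bar total is now derived. It mirrors the logic added to `aggregator.py`; the standalone function and its name are illustrative only:

```python
from typing import Optional

def progress_total(dataset_length: Optional[int], stat_subset_size: Optional[int]) -> Optional[int]:
    """Pick the tqdm total: never claim more items than the dataset actually has."""
    if stat_subset_size is None:
        return None  # unknown total: tqdm falls back to a plain item counter
    return min(dataset_length or stat_subset_size, stat_subset_size)

assert progress_total(101, 300) == 101   # sized dataset shorter than subset_size -> bar ends at 101/101
assert progress_total(None, 300) == 300  # dataset without __len__ -> keep subset_size as the total
assert progress_total(500, 300) == 300   # sized dataset longer than subset_size -> total stays 300
```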

### Reason for changes

User experience improvement.

### Related tickets

112627

### Tests

Added `tests/common/test_dataset.py` covering `Dataset.get_length()` and data retrieval with and without a transform function and indices.
nikita-savelyevv authored Sep 12, 2023
1 parent 8400793 commit 0a08966
Showing 7 changed files with 84 additions and 6 deletions.
10 changes: 8 additions & 2 deletions nncf/common/tensor_statistics/aggregator.py
@@ -13,7 +13,7 @@
from itertools import islice
from typing import Any, Dict, TypeVar

-from tqdm import tqdm
+from tqdm.auto import tqdm

from nncf.common import factory
from nncf.common.graph.graph import NNCFGraph
@@ -54,9 +54,15 @@ def collect_statistics(self, model: TModel, graph: NNCFGraph) -> None:
        model_with_outputs = model_transformer.transform(transformation_layout)
        engine = factory.EngineFactory.create(model_with_outputs)

+        dataset_length = self.dataset.get_length()
+        total = (
+            min(dataset_length or self.stat_subset_size, self.stat_subset_size)
+            if self.stat_subset_size is not None
+            else None
+        )
        for input_data in tqdm(
            islice(self.dataset.get_inference_data(), self.stat_subset_size),
-            total=self.stat_subset_size,
+            total=total,
            desc="Statistics collection",
        ):
            outputs = engine.infer(input_data)
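
For context, a self-contained sketch of the two bar modes the change produces. The `fake_dataset` generator is purely illustrative; `tqdm.auto` renders a notebook widget under Jupyter and plain text in a terminal:

```python
import time
from itertools import islice
from tqdm.auto import tqdm

def fake_dataset(num_items):
    # Hypothetical stand-in for a calibration dataset; as a generator it has no __len__.
    for _ in range(num_items):
        yield 0

subset_size = 300

# Dataset length known (101 items): the total is clipped to 101, so the bar finishes at 100%.
for _ in tqdm(islice(fake_dataset(101), subset_size), total=101, desc="Statistics collection"):
    time.sleep(0.01)

# Dataset length unknown: subset_size stays as the total, so the bar stops at 101/300 (~34%).
for _ in tqdm(islice(fake_dataset(101), subset_size), total=subset_size, desc="Statistics collection"):
    time.sleep(0.01)
```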
9 changes: 9 additions & 0 deletions nncf/data/dataset.py
@@ -72,6 +72,15 @@ def get_inference_data(self, indices: Optional[List[int]] = None) -> Iterable[ModelInput]:
"""
return DataProvider(self._data_source, self._transform_func, indices)

def get_length(self) -> Optional[int]:
"""
Tries to fetch length of the underlying dataset.
:return: The length of the data_source if __len__() is implemented for it, and None otherwise.
"""
if hasattr(self._data_source, "__len__"):
return self._data_source.__len__()
return None


class DataProvider(Generic[DataItem, ModelInput]):
    def __init__(
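
A quick usage sketch of the new method, mirroring the behavior exercised by the added tests:

```python
from nncf import Dataset

sized = Dataset(list(range(50)))    # the underlying list implements __len__
unsized = Dataset(iter(range(50)))  # a bare iterator does not

assert sized.get_length() == 50
assert unsized.get_length() is None  # the statistics aggregator then falls back to subset_size
```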
2 changes: 1 addition & 1 deletion nncf/quantization/algorithms/bias_correction/algorithm.py
@@ -13,7 +13,7 @@
from typing import Any, Dict, List, Optional, Tuple, TypeVar

import numpy as np
-from tqdm import tqdm
+from tqdm.auto import tqdm

from nncf import Dataset
from nncf import nncf_logger
@@ -12,7 +12,7 @@
from typing import Any, Dict, List, Optional, Tuple, TypeVar

import numpy as np
-from tqdm import tqdm
+from tqdm.auto import tqdm

from nncf import Dataset
from nncf.common.factory import CommandCreatorFactory
@@ -11,7 +11,7 @@

from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union

-from tqdm import tqdm
+from tqdm.auto import tqdm

from nncf import Dataset
from nncf.common.factory import EngineFactory
2 changes: 1 addition & 1 deletion nncf/quantization/algorithms/smooth_quant/algorithm.py
@@ -24,7 +24,7 @@
from copy import deepcopy
from typing import Dict, List, Optional, Tuple, TypeVar

-from tqdm import tqdm
+from tqdm.auto import tqdm

from nncf import Dataset
from nncf.common.factory import ModelTransformerFactory
63 changes: 63 additions & 0 deletions tests/common/test_dataset.py
@@ -0,0 +1,63 @@
# Copyright (c) 2023 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from nncf import Dataset


def test_dataset():
    raw_data = list(range(50))
    dataset = Dataset(raw_data)

    data_provider = dataset.get_data()
    retrieved_data_items = list(data_provider)
    assert all(raw_data[i] == retrieved_data_items[i] for i in range(len(raw_data)))


def test_dataset_with_transform_func():
    raw_data = list(range(50))
    dataset = Dataset(raw_data, transform_func=lambda it: 2 * it)

    data_provider = dataset.get_inference_data()
    retrieved_data_items = list(data_provider)
    assert all(2 * raw_data[i] == retrieved_data_items[i] for i in range(len(raw_data)))


def test_dataset_with_indices():
    raw_data = list(range(50))
    dataset = Dataset(raw_data)

    data_provider = dataset.get_data(indices=list(range(0, 50, 2)))
    retrieved_data_items = list(data_provider)
    assert all(raw_data[2 * i] == retrieved_data_items[i] for i in range(len(raw_data) // 2))


def test_dataset_with_transform_func_with_indices():
    raw_data = list(range(50))
    dataset = Dataset(raw_data, transform_func=lambda it: 2 * it)

    data_provider = dataset.get_inference_data(indices=list(range(0, 50, 2)))
    retrieved_data_items = list(data_provider)
    assert all(2 * raw_data[2 * i] == retrieved_data_items[i] for i in range(len(raw_data) // 2))


def test_dataset_without_length():
    raw_data = list(range(50))
    dataset_with_length = Dataset(raw_data)
    dataset_without_length = Dataset(iter(raw_data))
    assert dataset_with_length.get_length() == 50
    assert dataset_without_length.get_length() is None

    data_provider = dataset_with_length.get_data()
    retrieved_data_items = list(data_provider)
    assert all(raw_data[i] == retrieved_data_items[i] for i in range(len(raw_data)))

    data_provider = dataset_without_length.get_data()
    retrieved_data_items = list(data_provider)
    assert all(raw_data[i] == retrieved_data_items[i] for i in range(len(raw_data)))
