Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added rich-based progress bar #2132

Merged
merged 11 commits into from
Sep 18, 2023
140 changes: 140 additions & 0 deletions nncf/common/logging/track_progress.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
# Copyright (c) 2023 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Iterable, List, Optional, Sequence, Union

from rich.console import Console
from rich.progress import BarColumn
from rich.progress import Column
from rich.progress import Progress
from rich.progress import ProgressColumn
from rich.progress import ProgressType
from rich.progress import Task
from rich.progress import TaskProgressColumn
from rich.progress import TextColumn
from rich.progress import TimeElapsedColumn
from rich.progress import TimeRemainingColumn
from rich.style import StyleType
from rich.text import Text


class IterationsColumn(ProgressColumn):
def render(self, task: Task) -> Text:
if task.total is None:
return Text("")
text = f"{int(task.completed)}/{int(task.total)}"
if task.finished:
return Text(text, style="progress.elapsed")
return Text(text, style="progress.remaining")


class SeparatorColumn(ProgressColumn):
def __init__(self, table_column: Optional[Column] = None, disable_if_no_total: bool = False) -> None:
super().__init__(table_column)
self.disable_if_no_total = disable_if_no_total

def render(self, task: Task) -> Text:
if self.disable_if_no_total and task.total is None:
return Text("")
return Text("•")


class track:
def __init__(
self,
sequence: Optional[Union[Sequence[ProgressType], Iterable[ProgressType]]] = None,
description: str = "Working...",
total: Optional[float] = None,
auto_refresh: bool = True,
console: Optional[Console] = None,
transient: bool = False,
get_time: Optional[Callable[[], float]] = None,
refresh_per_second: float = 10,
style: StyleType = "bar.back",
complete_style: StyleType = "bar.complete",
finished_style: StyleType = "bar.finished",
pulse_style: StyleType = "bar.pulse",
update_period: float = 0.1,
disable: bool = False,
show_speed: bool = True,
):
"""
Track progress by iterating over a sequence.

This function is very similar to rich.progress.track(), but with some customizations.

:param sequence: An iterable (must support "len") you wish to iterate over.
:param description: Description of the task to show next to the progress bar. Defaults to "Working".
:param total: Total number of steps. Default is len(sequence).
:param auto_refresh: Automatic refresh. Disable to force a refresh after each iteration. Default is True.
:param transient: Clear the progress on exit. Defaults to False.
:param get_time: A callable that gets the current time, or None to use Console.get_time. Defaults to None.
:param console: Console to write to. Default creates an internal Console instance.
:param refresh_per_second: Number of times per second to refresh the progress information. Defaults to 10.
:param style: Style for the bar background. Defaults to "bar.back".
:param complete_style: Style for the completed bar. Defaults to "bar.complete".
:param finished_style: Style for a finished bar. Defaults to "bar.finished".
:param pulse_style: Style for pulsing bars. Defaults to "bar.pulse".
:param update_period: Minimum time (in seconds) between calls to update(). Defaults to 0.1.
:param disable: Disable display of progress.
:param show_speed: Show speed if the total isn't known. Defaults to True.
:return: An iterable of the values in the sequence.
"""

self.sequence = sequence
self.total = total
self.description = description
self.update_period = update_period
self.task = None

self.columns: List[ProgressColumn] = (
[TextColumn("[progress.description]{task.description}")] if description else []
)
self.columns.extend(
(
BarColumn(
style=style,
complete_style=complete_style,
finished_style=finished_style,
pulse_style=pulse_style,
),
TaskProgressColumn(show_speed=show_speed),
IterationsColumn(),
SeparatorColumn(),
TimeElapsedColumn(),
SeparatorColumn(disable_if_no_total=True), # disable because time remaining will be empty
TimeRemainingColumn(elapsed_when_finished=True),
nikita-savelyevv marked this conversation as resolved.
Show resolved Hide resolved
)
)
self.progress = Progress(
*self.columns,
auto_refresh=auto_refresh,
console=console,
transient=transient,
get_time=get_time,
refresh_per_second=refresh_per_second or 10,
disable=disable,
)

def __iter__(self) -> Iterable[ProgressType]:
with self.progress:
yield from self.progress.track(
self.sequence, total=self.total, description=self.description, update_period=self.update_period
)

def __enter__(self):
self.progress.start()
self.task = self.progress.add_task(self.description, total=self.total)
return self

def __exit__(self, *args):
self.task = None
self.progress.stop()
7 changes: 3 additions & 4 deletions nncf/common/tensor_statistics/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,10 @@
from itertools import islice
from typing import Any, Dict, TypeVar

from tqdm.auto import tqdm

from nncf.common import factory
from nncf.common.graph.graph import NNCFGraph
from nncf.common.graph.transformations.layout import TransformationLayout
from nncf.common.logging.track_progress import track
from nncf.common.tensor import NNCFTensor
from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer
from nncf.data.dataset import Dataset
Expand Down Expand Up @@ -60,10 +59,10 @@ def collect_statistics(self, model: TModel, graph: NNCFGraph) -> None:
if self.stat_subset_size is not None
else None
)
for input_data in tqdm(
for input_data in track(
islice(self.dataset.get_inference_data(), self.stat_subset_size),
total=total,
desc="Statistics collection",
description="Statistics collection",
):
outputs = engine.infer(input_data)
processed_outputs = self._process_outputs(outputs)
Expand Down
6 changes: 3 additions & 3 deletions nncf/quantization/algorithms/bias_correction/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from typing import Any, Dict, List, Optional, Tuple, TypeVar

import numpy as np
from tqdm.auto import tqdm

from nncf import Dataset
from nncf import nncf_logger
Expand All @@ -26,6 +25,7 @@
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.graph.transformations.commands import TransformationCommand
from nncf.common.graph.transformations.layout import TransformationLayout
from nncf.common.logging.track_progress import track
from nncf.common.tensor_statistics.statistic_point import StatisticPoint
from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer
from nncf.common.utils.backend import BackendType
Expand Down Expand Up @@ -159,8 +159,8 @@ def apply(
# for which we will create a subgraph for inference and collection of statistics.
subgraphs_data = [self._get_subgraph_data_for_node(node, nncf_graph) for node in nodes_with_bias]

for position, (node, subgraph_data) in tqdm(
list(enumerate(zip(nodes_with_bias, subgraphs_data))), desc="Applying Bias correction"
for position, (node, subgraph_data) in track(
list(enumerate(zip(nodes_with_bias, subgraphs_data))), description="Applying Bias correction"
):
node_name = node.node_name

Expand Down
6 changes: 3 additions & 3 deletions nncf/quantization/algorithms/channel_alignment/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, List, Optional, Tuple, TypeVar
from typing import Dict, List, Optional, Tuple, TypeVar

import numpy as np
from tqdm.auto import tqdm

from nncf import Dataset
from nncf.common.factory import CommandCreatorFactory
Expand All @@ -23,6 +22,7 @@
from nncf.common.graph.transformations.commands import TargetPoint
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.graph.transformations.layout import TransformationLayout
from nncf.common.logging.track_progress import track
from nncf.common.tensor_statistics.statistic_point import StatisticPoint
from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer
from nncf.common.utils.backend import BackendType
Expand Down Expand Up @@ -104,7 +104,7 @@ def apply(
def filter_func(point: StatisticPoint) -> bool:
return self._algorithm_key in point.algorithm_to_tensor_collectors and point.target_point == target_point

for conv_in, add_in, conv_out in tqdm(self._get_node_pairs(graph), desc="Channel alignment"):
for conv_in, add_in, conv_out in track(self._get_node_pairs(graph), description="Channel alignment"):
target_point, node_in = self._get_target_point_and_node_in(conv_in, add_in)
tensor_collectors = list(
statistic_points.get_algo_statistics_for_node(node_in.node_name, filter_func, self._algorithm_key)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@

from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union

from tqdm.auto import tqdm

from nncf import Dataset
from nncf.common.factory import EngineFactory
from nncf.common.factory import ModelTransformerFactory
Expand All @@ -22,6 +20,7 @@
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.graph.transformations.layout import TransformationLayout
from nncf.common.logging import nncf_logger
from nncf.common.logging.track_progress import track
from nncf.common.tensor_statistics.statistic_point import StatisticPoint
from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer
from nncf.common.utils.backend import BackendType
Expand Down Expand Up @@ -139,7 +138,7 @@ def apply(
# for which we should update bias and new bias values.
node_and_new_bias_value = []

for node, bias_value in tqdm(node_and_bias_value, desc="Applying Fast Bias correction"):
for node, bias_value in track(node_and_bias_value, description="Applying Fast Bias correction"):
node_name = node.node_name

if not self._backend_entity.is_quantized_weights(node, graph):
Expand Down
5 changes: 2 additions & 3 deletions nncf/quantization/algorithms/smooth_quant/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,14 @@
from copy import deepcopy
from typing import Dict, List, Optional, Tuple, TypeVar

from tqdm.auto import tqdm

from nncf import Dataset
from nncf.common.factory import ModelTransformerFactory
from nncf.common.graph.graph import NNCFGraph
from nncf.common.graph.graph import NNCFNode
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.graph.transformations.layout import TransformationLayout
from nncf.common.logging import nncf_logger
from nncf.common.logging.track_progress import track
from nncf.common.tensor_statistics.statistic_point import StatisticPoint
from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer
from nncf.common.utils.backend import BackendType
Expand Down Expand Up @@ -111,7 +110,7 @@ def apply(
node_groups = self._group_nodes_by_source(nodes_to_smooth_data, graph)

best_scale = None
for group_id, nodes in tqdm(node_groups.items(), desc="Applying Smooth Quant"):
for group_id, nodes in track(node_groups.items(), description="Applying Smooth Quant"):
best_ratio = 0.0
empty_statistic = False
for node_to_smooth in nodes:
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def find_version(*file_paths):
# Using 2.x versions of pyparsing seems to fix the issue.
# Ticket: 69520
"pyparsing<3.0",
"rich>=13.5.2",
nikita-savelyevv marked this conversation as resolved.
Show resolved Hide resolved
"scikit-learn>=0.24.0",
"scipy>=1.3.2, <1.11",
"texttable>=1.6.3",
Expand Down