diff --git a/nncf/common/logging/track_progress.py b/nncf/common/logging/track_progress.py index 90455fab2ef..ce6fdf183ab 100644 --- a/nncf/common/logging/track_progress.py +++ b/nncf/common/logging/track_progress.py @@ -18,6 +18,7 @@ from rich.progress import ProgressColumn from rich.progress import ProgressType from rich.progress import Task +from rich.progress import TaskID from rich.progress import TaskProgressColumn from rich.progress import TextColumn from rich.progress import TimeElapsedColumn @@ -59,6 +60,65 @@ def render(self, task: "Task") -> Text: return Text(text._text[0], style=INTEL_BLUE_COLOR) +class WeightedProgress(Progress): + """ + A class to perform a weighted progress tracking. + """ + + def update(self, task_id: TaskID, **kwargs) -> None: + task = self._tasks[task_id] + + advance = kwargs.get("advance", None) + if advance is not None: + kwargs["advance"] = self.weighted_advance(task, advance) + + completed = kwargs.get("completed", None) + if completed is not None: + kwargs["completed"] = self.get_weighted_completed(task, completed) + + super().update(task_id, **kwargs) + + def advance(self, task_id: TaskID, advance: float = 1) -> None: + if advance is not None: + task = self._tasks[task_id] + advance = self.weighted_advance(task, advance) + super().advance(task_id, advance) + + def reset(self, task_id: TaskID, **kwargs) -> None: + task = self._tasks[task_id] + + completed = kwargs.get("completed", None) + if completed is not None: + kwargs["completed"] = self.get_weighted_completed(task, completed) + + super().reset(task_id, **kwargs) + + if completed == 0: + task.fields["completed_steps"] = 0 + + @staticmethod + def weighted_advance(task: Task, advance: float) -> float: + """ + Perform weighted advancement based on an integer step value. + """ + if advance % 1 != 0: + raise Exception(f"Unexpected `advance` value: {advance}.") + advance = int(advance) + current_step = task.fields["completed_steps"] + weighted_advance = sum(task.fields["weights"][current_step : current_step + advance]) + task.fields["completed_steps"] = current_step + advance + return weighted_advance + + @staticmethod + def get_weighted_completed(task: Task, completed: float) -> float: + """ + Get weighted `completed` corresponding to an integer `completed` field. + """ + if completed % 1 != 0: + raise Exception(f"Unexpected `completed` value: {completed}.") + return sum(task.fields["weights"][: int(completed)]) + + class track: def __init__( self, @@ -77,6 +137,7 @@ def __init__( update_period: float = 0.1, disable: bool = False, show_speed: bool = True, + weights: Optional[List[float]] = None, ): """ Track progress by iterating over a sequence. @@ -98,11 +159,14 @@ def __init__( :param update_period: Minimum time (in seconds) between calls to update(). Defaults to 0.1. :param disable: Disable display of progress. :param show_speed: Show speed if the total isn't known. Defaults to True. + :param weights: List of progress weights for each sequence element. Weights should be proportional to the time + it takes to process sequence elements. Useful when processing time is strongly non-uniform. :return: An iterable of the values in the sequence. """ self.sequence = sequence - self.total = total + self.weights = weights + self.total = sum(self.weights) if self.weights is not None else total self.description = description self.update_period = update_period self.task = None @@ -120,7 +184,13 @@ def __init__( bar_width=None, ), TaskProgressColumn(show_speed=show_speed), - IterationsColumn(), + ) + ) + # Do not add iterations column for weighted tracking because steps will be in weighted coordinates + if self.weights is None: + self.columns.append(IterationsColumn()) + self.columns.extend( + ( SeparatorColumn(), TimeElapsedColumnWithStyle(), SeparatorColumn(disable_if_no_total=True), # disable because time remaining will be empty @@ -130,7 +200,8 @@ def __init__( disable = disable or (hasattr(sequence, "__len__") and len(sequence) == 0) - self.progress = Progress( + progress_cls = Progress if weights is None else WeightedProgress + self.progress = progress_cls( *self.columns, auto_refresh=auto_refresh, console=console, @@ -141,16 +212,24 @@ def __init__( ) def __iter__(self) -> Iterable[ProgressType]: - with self.progress: + with self: yield from self.progress.track( - self.sequence, total=self.total, description=self.description, update_period=self.update_period + self.sequence, + total=self.total, + task_id=self.task, + description=self.description, + update_period=self.update_period, ) def __enter__(self): - self.progress.start() - self.task = self.progress.add_task(self.description, total=self.total) - return self + kwargs = {} + if self.weights is not None: + kwargs["weights"] = self.weights + kwargs["completed_steps"] = 0 + self.task = self.progress.add_task(self.description, total=self.total, **kwargs) + return self.progress.__enter__() def __exit__(self, *args): + self.progress.__exit__(*args) + self.progress.remove_task(self.task) self.task = None - self.progress.stop() diff --git a/nncf/quantization/algorithms/weight_compression/algorithm.py b/nncf/quantization/algorithms/weight_compression/algorithm.py index e509ce7d11c..3f789f2c39c 100644 --- a/nncf/quantization/algorithms/weight_compression/algorithm.py +++ b/nncf/quantization/algorithms/weight_compression/algorithm.py @@ -405,12 +405,13 @@ def apply( # Sort weight params to start compression with the bigger constants. This lowers peak memory footprint. all_weight_params = sorted(all_weight_params, key=lambda wp: wp.num_weights, reverse=True) + all_weight_sizes = [wp.num_weights for wp in all_weight_params] # Compress model using weight compression parameters transformed_model = self._backend_entity.transform_model( model, graph, - track(all_weight_params, description="Applying Weight Compression"), + track(all_weight_params, description="Applying Weight Compression", weights=all_weight_sizes), scales, zero_points, )