Skip to content

Commit

Permalink
Add weighted progress tracking for weight compression (#2892)
Browse files Browse the repository at this point in the history
### Changes

During nncf weight compression, `rich` progress bar is used to display
the progress. In this PR, progress bar is changed to be weighted
according to model weights. With these changes, each weight contributes
proportional amount of percent to the progress bar.

Iteration number was removed from weight compression progress bar to
avoid confusion between different speeds in percent and iteration
coordinates. For example now a single weight might contribute 5-10% to
the whole progress.

### Reason for changes

The time it takes to compress a weight is roughly proportional to its
size, so incrementing the progress by 1 for each weight is not ideal.

Especially after #2803 when weight sorting was added. Now, the largest
weights come first and the smallest ones are at the end of the
compression. This leads to misleading time estimation when progress
contribution from every weight is equal.

Weights sizes for tinyllama-1.1b for reference:

![weight_size_hist](https://github.com/user-attachments/assets/30ba1e1b-0fc5-4d6b-84db-948362672bf2)


![weight_size_cumsum_hist](https://github.com/user-attachments/assets/b00e79e8-5000-44a4-97a5-4102c9aed0ae)
  • Loading branch information
nikita-savelyevv authored Aug 26, 2024
1 parent 1104f1b commit c500822
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 10 deletions.
97 changes: 88 additions & 9 deletions nncf/common/logging/track_progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from rich.progress import ProgressColumn
from rich.progress import ProgressType
from rich.progress import Task
from rich.progress import TaskID
from rich.progress import TaskProgressColumn
from rich.progress import TextColumn
from rich.progress import TimeElapsedColumn
Expand Down Expand Up @@ -59,6 +60,65 @@ def render(self, task: "Task") -> Text:
return Text(text._text[0], style=INTEL_BLUE_COLOR)


class WeightedProgress(Progress):
"""
A class to perform a weighted progress tracking.
"""

def update(self, task_id: TaskID, **kwargs) -> None:
task = self._tasks[task_id]

advance = kwargs.get("advance", None)
if advance is not None:
kwargs["advance"] = self.weighted_advance(task, advance)

completed = kwargs.get("completed", None)
if completed is not None:
kwargs["completed"] = self.get_weighted_completed(task, completed)

super().update(task_id, **kwargs)

def advance(self, task_id: TaskID, advance: float = 1) -> None:
if advance is not None:
task = self._tasks[task_id]
advance = self.weighted_advance(task, advance)
super().advance(task_id, advance)

def reset(self, task_id: TaskID, **kwargs) -> None:
task = self._tasks[task_id]

completed = kwargs.get("completed", None)
if completed is not None:
kwargs["completed"] = self.get_weighted_completed(task, completed)

super().reset(task_id, **kwargs)

if completed == 0:
task.fields["completed_steps"] = 0

@staticmethod
def weighted_advance(task: Task, advance: float) -> float:
"""
Perform weighted advancement based on an integer step value.
"""
if advance % 1 != 0:
raise Exception(f"Unexpected `advance` value: {advance}.")
advance = int(advance)
current_step = task.fields["completed_steps"]
weighted_advance = sum(task.fields["weights"][current_step : current_step + advance])
task.fields["completed_steps"] = current_step + advance
return weighted_advance

@staticmethod
def get_weighted_completed(task: Task, completed: float) -> float:
"""
Get weighted `completed` corresponding to an integer `completed` field.
"""
if completed % 1 != 0:
raise Exception(f"Unexpected `completed` value: {completed}.")
return sum(task.fields["weights"][: int(completed)])


class track:
def __init__(
self,
Expand All @@ -77,6 +137,7 @@ def __init__(
update_period: float = 0.1,
disable: bool = False,
show_speed: bool = True,
weights: Optional[List[float]] = None,
):
"""
Track progress by iterating over a sequence.
Expand All @@ -98,11 +159,14 @@ def __init__(
:param update_period: Minimum time (in seconds) between calls to update(). Defaults to 0.1.
:param disable: Disable display of progress.
:param show_speed: Show speed if the total isn't known. Defaults to True.
:param weights: List of progress weights for each sequence element. Weights should be proportional to the time
it takes to process sequence elements. Useful when processing time is strongly non-uniform.
:return: An iterable of the values in the sequence.
"""

self.sequence = sequence
self.total = total
self.weights = weights
self.total = sum(self.weights) if self.weights is not None else total
self.description = description
self.update_period = update_period
self.task = None
Expand All @@ -120,7 +184,13 @@ def __init__(
bar_width=None,
),
TaskProgressColumn(show_speed=show_speed),
IterationsColumn(),
)
)
# Do not add iterations column for weighted tracking because steps will be in weighted coordinates
if self.weights is None:
self.columns.append(IterationsColumn())
self.columns.extend(
(
SeparatorColumn(),
TimeElapsedColumnWithStyle(),
SeparatorColumn(disable_if_no_total=True), # disable because time remaining will be empty
Expand All @@ -130,7 +200,8 @@ def __init__(

disable = disable or (hasattr(sequence, "__len__") and len(sequence) == 0)

self.progress = Progress(
progress_cls = Progress if weights is None else WeightedProgress
self.progress = progress_cls(
*self.columns,
auto_refresh=auto_refresh,
console=console,
Expand All @@ -141,16 +212,24 @@ def __init__(
)

def __iter__(self) -> Iterable[ProgressType]:
with self.progress:
with self:
yield from self.progress.track(
self.sequence, total=self.total, description=self.description, update_period=self.update_period
self.sequence,
total=self.total,
task_id=self.task,
description=self.description,
update_period=self.update_period,
)

def __enter__(self):
self.progress.start()
self.task = self.progress.add_task(self.description, total=self.total)
return self
kwargs = {}
if self.weights is not None:
kwargs["weights"] = self.weights
kwargs["completed_steps"] = 0
self.task = self.progress.add_task(self.description, total=self.total, **kwargs)
return self.progress.__enter__()

def __exit__(self, *args):
self.progress.__exit__(*args)
self.progress.remove_task(self.task)
self.task = None
self.progress.stop()
3 changes: 2 additions & 1 deletion nncf/quantization/algorithms/weight_compression/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,12 +405,13 @@ def apply(

# Sort weight params to start compression with the bigger constants. This lowers peak memory footprint.
all_weight_params = sorted(all_weight_params, key=lambda wp: wp.num_weights, reverse=True)
all_weight_sizes = [wp.num_weights for wp in all_weight_params]

# Compress model using weight compression parameters
transformed_model = self._backend_entity.transform_model(
model,
graph,
track(all_weight_params, description="Applying Weight Compression"),
track(all_weight_params, description="Applying Weight Compression", weights=all_weight_sizes),
scales,
zero_points,
)
Expand Down

0 comments on commit c500822

Please sign in to comment.