Skip to content

Commit

Permalink
More correct hessian trace calculation (#2159)
Browse files Browse the repository at this point in the history
### Changes

Use absolute value of avg_total_trace in denominator, as done in [PyHessian:](https://github.com/amirgholami/PyHessian/blob/master/pyhessian/hessian.py#L186C22-L186C22)

### Reason for changes

more robust mixed precision algo

### Related tickets

resolves #2155 

### Tests

hawq-related tests
  • Loading branch information
ljaljushkin authored Sep 26, 2023
1 parent def8009 commit 36f45a2
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion nncf/torch/quantization/hessian_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def get_average_traces(self, max_iter=500, tolerance=1e-5) -> Tensor:
mean_avg_traces_per_param = self._get_mean(avg_traces_per_iter)
mean_avg_total_trace = torch.sum(mean_avg_traces_per_param)

diff_avg = abs(mean_avg_total_trace - avg_total_trace) / (avg_total_trace + self._diff_eps)
diff_avg = abs(mean_avg_total_trace - avg_total_trace) / (abs(avg_total_trace) + self._diff_eps)
if diff_avg < tolerance:
return mean_avg_traces_per_param
avg_total_trace = mean_avg_total_trace
Expand Down

0 comments on commit 36f45a2

Please sign in to comment.