Skip to content

Commit

Permalink
Fix (graph/bias_correction): Fix when layer parameters are offloaded …
Browse files Browse the repository at this point in the history
…to `accelerate` (#962)

* Fix (graph/bias_correction): Fix when layer parameters are offloaded to `accelerate`

* Fix (bias_correction): Typo fix

* Fix (bias_correction): Apply accelerate fix to entire `if/elif` block.

* fix (bias_corr/accelerate): Added comment
  • Loading branch information
nickfraser authored Jul 8, 2024
1 parent 1394889 commit f7d634d
Showing 1 changed file with 7 additions and 0 deletions.
7 changes: 7 additions & 0 deletions src/brevitas/graph/calibrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,11 +286,18 @@ def apply_correction(self, model):
for name, module in model.named_modules():
if name in self.correction_map.keys():
correction = self.correction_map[name] / self.iterations[name]
# When accelerate is enabled, bring tensors onto the device to avoid allocating a meta parameter.
if hasattr(module, 'allocate_params'):
module.allocate_params(module)
if module.bias is not None:
module.bias.data += correction
elif self.skip_if_no_bias is False:
# If accelerate is enabled, bias will be on the same execution device as the weights, but won't be managed properly by accelerate
module.register_parameter(
'bias', nn.Parameter(correction).to(module.weight.device))
# Offload params again
if hasattr(module, 'offload_params'):
module.offload_params(module)

def compute_correct_bias(self, module, inp, name):
inp = self.unpack_input(inp)
Expand Down

0 comments on commit f7d634d

Please sign in to comment.