Feat (graph/equalize): upcast during equalization computation (#970)
Giuseppe5 authored Jun 21, 2024
1 parent df1a137 commit b4e9287
Showing 1 changed file with 21 additions and 13 deletions.
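
The change is one recipe applied in several places: remember the original dtype, run the equalization statistics in float32 on CPU, and cast back only where results touch modules or graph nodes again. As a quick illustration of why this matters (illustrative only, not code from the commit): a ratio of per-channel ranges that is perfectly finite in float32 can overflow to inf in float16, whose largest value is about 65504.

import torch

srcs_range = torch.tensor([2.0], dtype=torch.float16)
sinks_range = torch.tensor([1e-5], dtype=torch.float16)

# Dividing directly in float16 overflows to inf...
print(srcs_range / sinks_range)  # tensor([inf], dtype=torch.float16)

# ...while the same ratio computed in float32 stays finite (around 2e5)
# and can be converted back only at the very end.
print(srcs_range.to(torch.float32) / sinks_range.to(torch.float32))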
34 changes: 21 additions & 13 deletions src/brevitas/graph/equalize.py
@@ -474,9 +474,11 @@ def _no_equalize():
         return _no_equalize()

     scale_fn = _select_scale_computation_fn(scale_computation_type)
-    sink_weights = {name: transpose(m.weight.cpu(), axis) for name, (m, axis) in sink_axes.items()}
-    srcs_range = -1 * torch.ones(max_shape_srcs, device='cpu', dtype=dtype)
-    sinks_range = -1 * torch.ones(max_shape_sinks, device='cpu', dtype=dtype)
+    sink_weights = {
+        name: transpose(m.weight.cpu().to(torch.float32), axis)
+        for name, (m, axis) in sink_axes.items()}
+    srcs_range = -1 * torch.ones(max_shape_srcs, device='cpu', dtype=torch.float32)
+    sinks_range = -1 * torch.ones(max_shape_sinks, device='cpu', dtype=torch.float32)
     for k, v in sink_weights.items():
         # Sinks can be partially equalized, thus we need to select
         # only the channels we are interested in
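A minimal sketch of the pattern this first hunk introduces (absmax_per_channel is a stand-in for whatever _select_scale_computation_fn returns; it is not Brevitas API): the weight is moved to CPU and upcast once, so every statistic derived from it runs in float32 regardless of the module's dtype.

import torch

def absmax_per_channel(t):
    # Stand-in scale function: per-output-channel absolute maximum.
    return t.reshape(t.size(0), -1).abs().amax(dim=1)

linear = torch.nn.Linear(8, 4).half()
w32 = linear.weight.detach().cpu().to(torch.float32)
sinks_range = absmax_per_channel(w32)  # float32, whatever the module dtype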
@@ -493,11 +495,13 @@ def _no_equalize():
     # weight equalization
     if merge_bias:
         src_weights = {
-            name: _combine_weights_bias(transpose(m.weight, axis), bias_shrinkage, m.bias).cpu()
+            name: _combine_weights_bias(transpose(m.weight, axis), bias_shrinkage,
+                                        m.bias).cpu().to(torch.float32)
             for name, (m, axis) in src_axes.items()}
     else:
         src_weights = {
-            name: transpose(m.weight.cpu(), axis) for name, (m, axis) in src_axes.items()}
+            name: transpose(m.weight.cpu().to(torch.float32), axis)
+            for name, (m, axis) in src_axes.items()}
     for k, v in src_weights.items():
         # Srcs are always fully equalized, thus we simply need to apply the offset to position them
         # correctly with respect to the other srcs matrices.
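For the merge_bias branch, the idea, sketched loosely here (Brevitas' actual _combine_weights_bias applies a configurable bias_shrinkage policy that this sketch reduces to a plain scalar), is to fold a shrunk copy of the bias into the weight matrix as an extra column, so the bias magnitude also contributes to the per-channel range before the float32 upcast.

import torch

def combine_weights_bias_sketch(weight, bias, shrink=0.5):
    # Conceptual only: append a shrunk bias column so its magnitude
    # participates in the per-channel statistics.
    return torch.cat([weight, (shrink * bias).unsqueeze(1)], dim=1)

w = torch.randn(4, 8, dtype=torch.float16)
b = torch.randn(4, dtype=torch.float16)
# As in the hunk above, the upcast happens right after the combine.
combined = combine_weights_bias_sketch(w, b).cpu().to(torch.float32)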
@@ -516,8 +520,10 @@ def _no_equalize():
         list_of_act_val = list_of_act_val = [
             transpose(act_val, act_axis) for act_val in list_of_act_val]
         srcs_range_act = scale_fn(
-            torch.cat([act_val.reshape(act_val.size(0), -1) for act_val in list_of_act_val],
-                      1)).cpu()
+            torch.cat([
+                act_val.reshape(act_val.size(0), -1).cpu().to(torch.float32)
+                for act_val in list_of_act_val],
+                      1))

     if list_of_act_val is not None:
         if co_optimize_act_weights and len(src_axes) > 0:
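Activation statistics get the same treatment: each cached activation is flattened to (channels, -1) and upcast per tensor before the concatenation, so the reduction never runs in half precision. A rough stand-alone version (absmax_rows plays the role of scale_fn here and is not Brevitas API):

import torch

def absmax_rows(t):
    return t.abs().amax(dim=1)

acts = [
    torch.randn(4, 16, dtype=torch.float16),
    torch.randn(4, 3, 5, dtype=torch.float16)]
# Flatten, upcast per tensor, then concatenate and reduce in float32.
flat = [a.reshape(a.size(0), -1).cpu().to(torch.float32) for a in acts]
srcs_range_act = absmax_rows(torch.cat(flat, 1))  # float32 result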
@@ -536,9 +542,9 @@ def _no_equalize():
     # which is the no-op equivalent for equalization.
     channelwise_no_equalize = (sinks_range <= EPSILON) | (srcs_range <= EPSILON)
     sinks_range = torch.where(
-        channelwise_no_equalize, torch.tensor(1., dtype=dtype, device='cpu'), sinks_range)
+        channelwise_no_equalize, torch.tensor(1., dtype=torch.float32, device='cpu'), sinks_range)
     srcs_range = torch.where(
-        channelwise_no_equalize, torch.tensor(1., dtype=dtype, device='cpu'), srcs_range)
+        channelwise_no_equalize, torch.tensor(1., dtype=torch.float32, device='cpu'), srcs_range)

     srcs_range = torch.pow(srcs_range, alpha)
     sinks_range = torch.pow(sinks_range, 1 - alpha)
@@ -548,15 +554,16 @@ def _no_equalize():
     if list_of_act_val is not None and list_of_insert_mul_node_fn is not None:
         device = list_of_act_val[0].device
         for act_val_shape, insert_mul_node_fn in zip(list_of_act_val_shapes, list_of_insert_mul_node_fn):
-            insert_mul_node_fn(inverse_scaling_factors.to(device=device), act_val_shape, act_axis)
+            insert_mul_node_fn(
+                inverse_scaling_factors.to(device=device, dtype=dtype), act_val_shape, act_axis)
     if len(src_axes) > 0:
         for name, (module, axis) in src_axes.items():
             module_device = module.weight.device
             indexes = region.srcs[name]
             channel_start = indexes.offset + indexes.start
             channel_end = indexes.offset + indexes.end
             partial_inverse_scale = inverse_scaling_factors[channel_start:channel_end].to(
-                device=module_device)
+                device=module_device, dtype=dtype)
             if hasattr(module, 'bias') and module.bias is not None:
                 _update_weights(
                     module, module.bias * partial_inverse_scale.view_as(module.bias), attr='bias')
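On the way out, the float32 results are cast back to the dtype captured at the top of the function and to each module's device, so half-precision models receive half-precision scales. A minimal sketch of that round trip (the in-place mul_ is a stand-in for _update_weights, not Brevitas' implementation):

import torch

module = torch.nn.Linear(4, 4).half()
dtype = module.weight.dtype  # captured before any upcasting

# Scales were computed in float32; convert only at the point where
# they touch parameters again.
inverse_scaling_factors = torch.rand(4, dtype=torch.float32) + 0.5
partial = inverse_scaling_factors.to(device=module.weight.device, dtype=dtype)
with torch.no_grad():
    module.weight.mul_(partial.view(-1, 1))  # stand-in for _update_weights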
@@ -578,7 +585,7 @@ def _no_equalize():
             # one (i.e., no equalization)
             partial_scaling[indexes.start:indexes.end] = scaling_factors[indexes.offset:indexes.offset +
                                                                          channel_range]
-            partial_scaling = partial_scaling.to(device=module_device)
+            partial_scaling = partial_scaling.to(device=module_device, dtype=dtype)
             _update_weights(
                 module,
                 module.weight * torch.reshape(partial_scaling, sink_broadcast_size),
@@ -983,7 +990,8 @@ def forward_stats_hook(self, module, *args, name, batch_dim=0, use_inp=True, **k

         self.batch_dim_act_map[name] = batch_dim

-        input_scales = self.scale_fn(x, dim=batch_dim)
+        dtype = x.dtype
+        input_scales = self.scale_fn(x.to(torch.float32), dim=batch_dim).to(dtype)
         if name not in self.float_act_map:
             self.float_act_map[name] = input_scales
         else:
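
The activation hook applies the same recipe at runtime: remember the incoming dtype, run the scale function in float32, and return the statistic in the original dtype. A stand-alone version of just that dance (the absmax lambda is illustrative, not Brevitas' scale_fn; merging with previously recorded stats is elided):

import torch

def stats_in_fp32(x, scale_fn, batch_dim=0):
    # Mirror of the hook change: upcast for the computation,
    # hand back the original dtype.
    dtype = x.dtype
    return scale_fn(x.to(torch.float32), dim=batch_dim).to(dtype)

x = torch.randn(32, 16, dtype=torch.float16)
scales = stats_in_fp32(x, lambda t, dim: t.abs().amax(dim=dim))
assert scales.dtype == torch.float16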
