Skip to content

Commit

Permalink
Fix (graph/equalize): cleanup and device management
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Feb 12, 2024
1 parent e0d78a6 commit 007c0dc
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 22 deletions.
39 changes: 18 additions & 21 deletions src/brevitas/graph/equalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ def _cross_layer_equalization(
# If equalization criteria are not met, we return a scalar one to indicate that no equalization
# has been performed
def _no_equalize():
return torch.tensor(1., dtype=dtype, device=device)
return torch.tensor(1., dtype=dtype)

# If a module has `allocate_params` attribute, we must load the weights following that method

Expand All @@ -410,7 +410,6 @@ def _no_equalize():
act_sink_axes = {}
act_sources_axes = {}
single_module = region.get_module_from_name(next(iter(region.sinks_names)))
device = next(single_module.parameters()).device
dtype = next(single_module.parameters()).dtype

max_shape_srcs = 0
Expand Down Expand Up @@ -474,16 +473,15 @@ def _no_equalize():
return _no_equalize()

scale_fn = _select_scale_computation_fn(scale_computation_type)
sink_weights = {name: transpose(m.weight, axis) for name, (m, axis) in sink_axes.items()}
srcs_range = -1 * torch.ones(max_shape_srcs, device=device, dtype=dtype)
sinks_range = -1 * torch.ones(max_shape_sinks, device=device, dtype=dtype)
sink_weights = {name: transpose(m.weight.cpu(), axis) for name, (m, axis) in sink_axes.items()}
srcs_range = -1 * torch.ones(max_shape_srcs, device='cpu', dtype=dtype)
sinks_range = -1 * torch.ones(max_shape_sinks, device='cpu', dtype=dtype)
for k, v in sink_weights.items():
# Sinks can be partially equalized, thus we need to select
# only the channels we are interested in
indexes = region.sinks[k]
# Compute the range of the channels we need to equalize
weight_range = scale_fn(v.reshape(v.size(0), -1))[indexes.start:indexes.end]
weight_range = weight_range.to(device)
# Compute the numbers of channels we are equalizing
channel_range = indexes.end - indexes.start
# Use the offset and the range to update the correct range in the sinks
Expand All @@ -494,18 +492,18 @@ def _no_equalize():
# weight equalization
if merge_bias:
src_weights = {
name: _combine_weights_bias(transpose(m.weight, axis), bias_shrinkage, m.bias)
name: _combine_weights_bias(transpose(m.weight, axis), bias_shrinkage, m.bias).cpu()
for name, (m, axis) in src_axes.items()}
else:
src_weights = {name: transpose(m.weight, axis) for name, (m, axis) in src_axes.items()}
src_weights = {
name: transpose(m.weight.cpu(), axis) for name, (m, axis) in src_axes.items()}
for k, v in src_weights.items():
# Srcs are always fully equalized, thus we simply need to apply the offset to position them
# correctly with respect to the other srcs matrices.
indexes = region.srcs[k]
channel_start = indexes.offset + indexes.start
channel_end = indexes.offset + indexes.end
weight_range = scale_fn(v.reshape(v.size(0), -1))
weight_range = weight_range.to(device)
srcs_range[channel_start:channel_end] = torch.max(
srcs_range[channel_start:channel_end], weight_range)
if list_of_act_val is not None:
Expand All @@ -517,11 +515,11 @@ def _no_equalize():
list_of_act_val = list_of_act_val = [
transpose(act_val, act_axis) for act_val in list_of_act_val]
srcs_range_act = scale_fn(
torch.cat([act_val.reshape(act_val.size(0), -1) for act_val in list_of_act_val], 1))
srcs_range_act = srcs_range_act.to(device=device)
torch.cat([act_val.reshape(act_val.size(0), -1) for act_val in list_of_act_val],
1)).cpu()

if list_of_act_val is not None:
if co_optimize_act_weights:
if co_optimize_act_weights and len(src_axes) > 0:
srcs_range = .5 * srcs_range + .5 * srcs_range_act
else:
srcs_range = srcs_range_act
Expand All @@ -537,18 +535,19 @@ def _no_equalize():
# which is the no-op equivalent for equalization.
channelwise_no_equalize = (sinks_range <= EPSILON) | (srcs_range <= EPSILON)
sinks_range = torch.where(
channelwise_no_equalize, torch.tensor(1., dtype=dtype, device=device), sinks_range)
channelwise_no_equalize, torch.tensor(1., dtype=dtype, device='cpu'), sinks_range)
srcs_range = torch.where(
channelwise_no_equalize, torch.tensor(1., dtype=dtype, device=device), srcs_range)
channelwise_no_equalize, torch.tensor(1., dtype=dtype, device='cpu'), srcs_range)

srcs_range = torch.pow(srcs_range, alpha)
sinks_range = torch.pow(sinks_range, 1 - alpha)
scaling_factors = srcs_range / sinks_range
inverse_scaling_factors = torch.reciprocal(scaling_factors)

if list_of_act_val is not None and list_of_insert_mul_node_fn is not None:
device = list_of_act_val[0].device
for act_val_shape, insert_mul_node_fn in zip(list_of_act_val_shapes, list_of_insert_mul_node_fn):
insert_mul_node_fn(inverse_scaling_factors, act_val_shape, act_axis)
insert_mul_node_fn(inverse_scaling_factors.to(device=device), act_val_shape, act_axis)
if len(src_axes) > 0:
for name, (module, axis) in src_axes.items():
module_device = module.weight.device
Expand All @@ -559,31 +558,29 @@ def _no_equalize():
device=module_device)
if hasattr(module, 'bias') and module.bias is not None:
_update_weights(
module,
module.bias.clone() * partial_inverse_scale.view_as(module.bias),
attr='bias')
module, module.bias * partial_inverse_scale.view_as(module.bias), attr='bias')
src_broadcast_size = [1] * module.weight.ndim
src_broadcast_size[axis] = module.weight.size(axis)

_update_weights(
module,
module.weight.clone() * torch.reshape(partial_inverse_scale, src_broadcast_size),
module.weight * torch.reshape(partial_inverse_scale, src_broadcast_size),
attr='weight')
for name, (module, axis) in sink_axes.items():
module_device = module.weight.device
sink_broadcast_size = [1] * module.weight.ndim
sink_broadcast_size[axis] = module.weight.size(axis)
indexes = region.sinks[name]
channel_range = indexes.end - indexes.start
partial_scaling = torch.ones(module.weight.size(axis), device=device, dtype=dtype)
partial_scaling = torch.ones(module.weight.size(axis), device='cpu', dtype=dtype)
# We replace the scaling factors of the channels we need to equalize, leaving the other to
# one (i.e., no equalization)
partial_scaling[indexes.start:indexes.end] = scaling_factors[indexes.offset:indexes.offset +
channel_range]
partial_scaling = partial_scaling.to(device=module_device)
_update_weights(
module,
module.weight.clone() * torch.reshape(partial_scaling, sink_broadcast_size),
module.weight * torch.reshape(partial_scaling, sink_broadcast_size),
attr='weight')

# If a module has `offload_params` attribute, we must offload the weights following that method
Expand Down
2 changes: 1 addition & 1 deletion src/brevitas/graph/gptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def update_batch(self, module, input, current_layer):

def single_layer_update(self, percdamp=.01):
if hasattr(self.layer, 'allocate_params'):
self.layer.allocate_params(self.layer, 'cuda')
self.layer.allocate_params(self.layer)
weight = self.layer.weight.data
dev = weight.device

Expand Down

0 comments on commit 007c0dc

Please sign in to comment.