diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py
index c1c64a317..e6538421e 100644
--- a/src/brevitas/graph/equalize.py
+++ b/src/brevitas/graph/equalize.py
@@ -398,7 +398,7 @@ def _cross_layer_equalization(
     # If equalization criteria are not met, we return a scalar one to indicate that no equalization
     # has been performed
     def _no_equalize():
-        return torch.tensor(1., dtype=dtype, device=device)
+        return torch.tensor(1., dtype=dtype)
 
     # If a module has `allocate_params` attribute, we must load the weights following that method
@@ -410,7 +410,6 @@ def _no_equalize():
     act_sink_axes = {}
     act_sources_axes = {}
     single_module = region.get_module_from_name(next(iter(region.sinks_names)))
-    device = next(single_module.parameters()).device
     dtype = next(single_module.parameters()).dtype
 
     max_shape_srcs = 0
@@ -474,16 +473,15 @@ def _no_equalize():
         return _no_equalize()
 
     scale_fn = _select_scale_computation_fn(scale_computation_type)
-    sink_weights = {name: transpose(m.weight, axis) for name, (m, axis) in sink_axes.items()}
-    srcs_range = -1 * torch.ones(max_shape_srcs, device=device, dtype=dtype)
-    sinks_range = -1 * torch.ones(max_shape_sinks, device=device, dtype=dtype)
+    sink_weights = {name: transpose(m.weight.cpu(), axis) for name, (m, axis) in sink_axes.items()}
+    srcs_range = -1 * torch.ones(max_shape_srcs, device='cpu', dtype=dtype)
+    sinks_range = -1 * torch.ones(max_shape_sinks, device='cpu', dtype=dtype)
     for k, v in sink_weights.items():
         # Sinks can be partially equalized, thus we need to select
         # only the channels we are interested in
         indexes = region.sinks[k]
         # Compute the range of the channels we need to equalize
         weight_range = scale_fn(v.reshape(v.size(0), -1))[indexes.start:indexes.end]
-        weight_range = weight_range.to(device)
         # Compute the numbers of channels we are equalizing
         channel_range = indexes.end - indexes.start
         # Use the offset and the range to update the correct range in the sinks
@@ -494,10 +492,11 @@ def _no_equalize():
     # weight equalization
     if merge_bias:
         src_weights = {
-            name: _combine_weights_bias(transpose(m.weight, axis), bias_shrinkage, m.bias)
+            name: _combine_weights_bias(transpose(m.weight, axis), bias_shrinkage, m.bias).cpu()
            for name, (m, axis) in src_axes.items()}
     else:
-        src_weights = {name: transpose(m.weight, axis) for name, (m, axis) in src_axes.items()}
+        src_weights = {
+            name: transpose(m.weight.cpu(), axis) for name, (m, axis) in src_axes.items()}
     for k, v in src_weights.items():
         # Srcs are always fully equalized, thus we simply need to apply the offset to position them
         # correctly with respect to the other srcs matrices.
@@ -505,7 +504,6 @@ def _no_equalize():
         channel_start = indexes.offset + indexes.start
         channel_end = indexes.offset + indexes.end
         weight_range = scale_fn(v.reshape(v.size(0), -1))
-        weight_range = weight_range.to(device)
         srcs_range[channel_start:channel_end] = torch.max(
             srcs_range[channel_start:channel_end], weight_range)
     if list_of_act_val is not None:
@@ -517,11 +515,11 @@ def _no_equalize():
         list_of_act_val = list_of_act_val = [
             transpose(act_val, act_axis) for act_val in list_of_act_val]
         srcs_range_act = scale_fn(
-            torch.cat([act_val.reshape(act_val.size(0), -1) for act_val in list_of_act_val], 1))
-        srcs_range_act = srcs_range_act.to(device=device)
+            torch.cat([act_val.reshape(act_val.size(0), -1) for act_val in list_of_act_val],
+                      1)).cpu()
 
     if list_of_act_val is not None:
-        if co_optimize_act_weights:
+        if co_optimize_act_weights and len(src_axes) > 0:
             srcs_range = .5 * srcs_range + .5 * srcs_range_act
         else:
             srcs_range = srcs_range_act
@@ -537,9 +535,9 @@ def _no_equalize():
     # which is the no-op equivalent for equalization.
     channelwise_no_equalize = (sinks_range <= EPSILON) | (srcs_range <= EPSILON)
     sinks_range = torch.where(
-        channelwise_no_equalize, torch.tensor(1., dtype=dtype, device=device), sinks_range)
+        channelwise_no_equalize, torch.tensor(1., dtype=dtype, device='cpu'), sinks_range)
     srcs_range = torch.where(
-        channelwise_no_equalize, torch.tensor(1., dtype=dtype, device=device), srcs_range)
+        channelwise_no_equalize, torch.tensor(1., dtype=dtype, device='cpu'), srcs_range)
 
     srcs_range = torch.pow(srcs_range, alpha)
     sinks_range = torch.pow(sinks_range, 1 - alpha)
@@ -547,8 +545,9 @@ def _no_equalize():
     inverse_scaling_factors = torch.reciprocal(scaling_factors)
     if list_of_act_val is not None and list_of_insert_mul_node_fn is not None:
+        device = list_of_act_val[0].device
         for act_val_shape, insert_mul_node_fn in zip(list_of_act_val_shapes,
                                                      list_of_insert_mul_node_fn):
-            insert_mul_node_fn(inverse_scaling_factors, act_val_shape, act_axis)
+            insert_mul_node_fn(inverse_scaling_factors.to(device=device), act_val_shape, act_axis)
     if len(src_axes) > 0:
         for name, (module, axis) in src_axes.items():
             module_device = module.weight.device
@@ -559,15 +558,13 @@ def _no_equalize():
                 device=module_device)
             if hasattr(module, 'bias') and module.bias is not None:
                 _update_weights(
-                    module,
-                    module.bias.clone() * partial_inverse_scale.view_as(module.bias),
-                    attr='bias')
+                    module, module.bias * partial_inverse_scale.view_as(module.bias), attr='bias')
             src_broadcast_size = [1] * module.weight.ndim
             src_broadcast_size[axis] = module.weight.size(axis)
             _update_weights(
                 module,
-                module.weight.clone() * torch.reshape(partial_inverse_scale, src_broadcast_size),
+                module.weight * torch.reshape(partial_inverse_scale, src_broadcast_size),
                 attr='weight')
 
     for name, (module, axis) in sink_axes.items():
         module_device = module.weight.device
@@ -575,7 +572,7 @@ def _no_equalize():
         sink_broadcast_size[axis] = module.weight.size(axis)
         indexes = region.sinks[name]
         channel_range = indexes.end - indexes.start
-        partial_scaling = torch.ones(module.weight.size(axis), device=device, dtype=dtype)
+        partial_scaling = torch.ones(module.weight.size(axis), device='cpu', dtype=dtype)
         # We replace the scaling factors of the channels we need to equalize, leaving the other to
         # one (i.e., no equalization)
         partial_scaling[indexes.start:indexes.end] = scaling_factors[indexes.offset:indexes.offset +
@@ -583,7 +580,7 @@ def _no_equalize():
         partial_scaling = partial_scaling.to(device=module_device)
         _update_weights(
             module,
-            module.weight.clone() * torch.reshape(partial_scaling, sink_broadcast_size),
+            module.weight * torch.reshape(partial_scaling, sink_broadcast_size),
             attr='weight')
 
     # If a module has `offload_params` attribute, we must offload the weights following that method
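The equalize.py hunks above all follow one pattern: per-channel ranges are accumulated on CPU, so sources and sinks that live on different (or offloaded) devices can share the same range tensors, and only the final scaling factors are moved back to each module's own device when the weights are updated. Below is a minimal, self-contained sketch of that pattern on a single Linear-to-Linear pair; `equalize_pair` is a hypothetical helper, not the Brevitas API, and the range function is simplified to a max-abs over each channel.

    import torch
    import torch.nn as nn

    def equalize_pair(src: nn.Linear, sink: nn.Linear, alpha: float = 0.5, eps: float = 1e-9):
        # Per-channel max-abs ranges, computed on CPU regardless of where the weights live
        src_range = src.weight.detach().cpu().abs().amax(dim=1)    # one value per src output channel
        sink_range = sink.weight.detach().cpu().abs().amax(dim=0)  # one value per sink input channel
        # Channels with a (near-)zero range keep a scale of 1., i.e. no equalization
        no_equalize = (src_range <= eps) | (sink_range <= eps)
        one = torch.tensor(1., dtype=src_range.dtype, device='cpu')
        src_range = torch.where(no_equalize, one, src_range)
        sink_range = torch.where(no_equalize, one, sink_range)
        scale = torch.pow(src_range, alpha) / torch.pow(sink_range, 1 - alpha)
        # Only the final scales are moved back to each module's own device
        with torch.no_grad():
            src.weight.mul_((1. / scale).to(src.weight.device).view(-1, 1))
            if src.bias is not None:
                src.bias.mul_((1. / scale).to(src.bias.device))
            sink.weight.mul_(scale.to(sink.weight.device).view(1, -1))

Scaling each source output channel by 1/s and the matching sink input channel by s leaves the composed function unchanged, which is why the range bookkeeping can happen on any device as long as the scales are moved back before the in-place update.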
diff --git a/src/brevitas/graph/gptq.py b/src/brevitas/graph/gptq.py
index e76f6c18f..907633e0f 100644
--- a/src/brevitas/graph/gptq.py
+++ b/src/brevitas/graph/gptq.py
@@ -181,7 +181,7 @@ def update_batch(self, module, input, current_layer):
 
     def single_layer_update(self, percdamp=.01):
         if hasattr(self.layer, 'allocate_params'):
-            self.layer.allocate_params(self.layer, 'cuda')
+            self.layer.allocate_params(self.layer)
         weight = self.layer.weight.data
         dev = weight.device
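The gptq.py change is the same device-neutral cleanup: `allocate_params` is now called without a hard-coded `'cuda'` target, so offloaded parameters are materialized back onto whatever device the layer was assigned, and the following `dev = weight.device` picks that device up. A sketch of the resulting calling pattern, with `materialize_for_update` as an illustrative stand-in for the opening lines of `single_layer_update`:

    import torch.nn as nn

    def materialize_for_update(layer: nn.Module):
        # Hook-guarded allocation: if the layer knows how to materialize
        # offloaded parameters, let it do so on its own device; no 'cuda'
        # assumption, so CPU-only and multi-GPU setups both work.
        if hasattr(layer, 'allocate_params'):
            layer.allocate_params(layer)
        weight = layer.weight.data
        return weight, weight.device  # follow the layer instead of assuming CUDA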