Fix (equalize): align cross layer equalization with channel splitting
fabianandresgrob committed Jan 30, 2024
1 parent 72461de commit c117bd7
Showing 1 changed file with 15 additions and 16 deletions.
src/brevitas/graph/equalize.py: 31 changes (15 additions & 16 deletions)
@@ -100,9 +100,9 @@ class EqualizationIndexes:

# Required for being hashable
@dataclass(eq=True, frozen=True)
-class WeightBiasTuple:
-weight: nn.Module = None
-bias: nn.Module = None
+class WeightBiasWrapper:
+weight: torch.Tensor = None
+bias: torch.Tensor = None


# Required for being hashable
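Illustration (not part of the commit): the eq=True, frozen=True decorator is what makes the wrapper hashable, so it can stand in for an nn.Module wherever the equalization bookkeeping keeps modules in dicts or sets. A minimal sketch:

from dataclasses import dataclass

import torch


@dataclass(eq=True, frozen=True)
class WeightBiasWrapper:
    weight: torch.Tensor = None
    bias: torch.Tensor = None


wrapped = WeightBiasWrapper(weight=torch.randn(8, 4))
# frozen + eq gives the dataclass a __hash__, so it works as a dict key
# just like a module would.
regions = {wrapped: "sink_0"}
print(wrapped in regions)  # True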
@@ -359,16 +359,16 @@ def _combine_weights_bias(
return weight_bias


-def transpose(module: torch.nn.Module, axis: int):
+def transpose(tensor: torch.Tensor, axis: int):
"""
-Given a module and an axis, this function re-arranges the module's weights so that the axis and
+Given a tensor and an axis, this function re-arranges the tensor so that the axis and
the first dimension are swapped.
"""
-shape = list(range(module.weight.ndim))
+shape = list(range(tensor.ndim))
axis = shape[axis]
shape.insert(0, axis)
del shape[axis + 1]
-return module.weight.permute(shape)
+return tensor.permute(shape)
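A quick usage sketch (assumed shapes, not from the diff): after this change the caller passes a bare tensor such as module.weight, and the helper moves the requested axis to the front while keeping the other dimensions in order.

import torch


def transpose(tensor: torch.Tensor, axis: int):
    # Move `axis` to dim 0, preserving the order of the remaining dims.
    shape = list(range(tensor.ndim))
    axis = shape[axis]  # normalizes negative axes
    shape.insert(0, axis)
    del shape[axis + 1]
    return tensor.permute(shape)


weight = torch.randn(64, 32, 3, 3)  # e.g. a Conv2d weight (out, in, kH, kW)
print(transpose(weight, 1).shape)   # torch.Size([32, 64, 3, 3])
# Same effect as torch.movedim(weight, 1, 0).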


def _cross_layer_equalization(
@@ -430,7 +430,7 @@ def _no_equalize():
# For MultiheadAttention, we support only self-attention
if isinstance(module, nn.MultiheadAttention) and module.in_proj_weight is not None:
# For sinks, we only need to modify the weight but not the bias
-module = WeightBiasTuple(module.in_proj_weight)
+module = WeightBiasWrapper(module.in_proj_weight)
elif isinstance(module, nn.MultiheadAttention) and module.in_proj_weight is None:
return _no_equalize()
sink_axes[name] = (module, axis)
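Sketch of the self-attention case handled above (shapes assumed): the packed in_proj_weight is wrapped so that the rest of the pass can read a .weight attribute uniformly, while the bias is deliberately left out because sinks only rescale weights.

import torch.nn as nn

from brevitas.graph.equalize import WeightBiasWrapper  # defined in this file

mha = nn.MultiheadAttention(embed_dim=16, num_heads=4)  # self-attention, so in_proj_weight is set
assert mha.in_proj_weight is not None

sink = WeightBiasWrapper(mha.in_proj_weight)
print(sink.weight.shape)  # torch.Size([48, 16]): Q, K and V projections stacked row-wise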
@@ -452,12 +452,12 @@ def _no_equalize():

# Check if any of the axes is None, which means that the module is not supported.
# In that case, do not perform graph equalization
-axes_to_check = [*src_axes.values(), *sink_axes.values()]
+axes_to_check = [axis for _, axis in list(src_axes.values()) + list(sink_axes.values())]
if None in axes_to_check:
return _no_equalize()
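Aside (hypothetical, stripped-down repro): src_axes and sink_axes now map names to (module, axis) pairs, so the comprehension has to unpack each pair before looking for a None axis.

# Placeholder strings stand in for real nn.Module instances.
src_axes = {"conv1": ("conv1_module", 0)}
sink_axes = {"conv2": ("conv2_module", None)}  # None axis: layer type not supported

axes_to_check = [axis for _, axis in list(src_axes.values()) + list(sink_axes.values())]
if None in axes_to_check:
    print("skipping this region")  # mirrors the early return via _no_equalize()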

scale_fn = _select_scale_computation_fn(scale_computation_type)
-sink_weights = {name: transpose(m, axis) for name, (m, axis) in sink_axes.items()}
+sink_weights = {name: transpose(m.weight, axis) for name, (m, axis) in sink_axes.items()}
srcs_range = -1 * torch.ones(max_shape_srcs, device=device, dtype=dtype)
sinks_range = -1 * torch.ones(max_shape_sinks, device=device, dtype=dtype)
for k, v in sink_weights.items():
@@ -480,17 +480,16 @@ def _no_equalize():
shape_0 = list_of_act_val_shapes[0]
if any(shape_0 != shape for shape in list_of_act_val_shapes):
return _no_equalize()
-list_of_act_val = [
-transpose(WeightBiasTuple(act_val), act_axis) for act_val in list_of_act_val]
+list_of_act_val = [transpose(act_val, act_axis) for act_val in list_of_act_val]
srcs_range = scale_fn(
torch.cat([act_val.reshape(act_val.size(0), -1) for act_val in list_of_act_val], 1))
else:
if merge_bias:
src_weights = {
-name: _combine_weights_bias(transpose(m, axis), bias_shrinkage, m.bias)
+name: _combine_weights_bias(transpose(m.weight, axis), bias_shrinkage, m.bias)
for name, (m, axis) in src_axes.items()}
else:
-src_weights = {name: transpose(m, axis) for name, (m, axis) in src_axes.items()}
+src_weights = {name: transpose(m.weight, axis) for name, (m, axis) in src_axes.items()}
for k, v in src_weights.items():
# Srcs are always fully equalized, thus we simply need to apply the offset to position them
# correctly with respect to the other srcs matrices.
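To make the flatten-and-reduce pattern concrete (a simplified sketch; the real scale_fn comes from _select_scale_computation_fn and its results are written into srcs_range/sinks_range at per-source offsets): each weight is brought to channels-first by transpose, flattened to 2-D, and reduced to one range value per equalized channel.

import torch


def scale_fn(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for the selected scale computation: abs-max per output channel.
    return x.abs().amax(dim=1)


weight = torch.randn(8, 32, 3, 3)  # hypothetical source weight, already channels-first
flat = weight.reshape(weight.size(0), -1)
print(scale_fn(flat).shape)  # torch.Size([8]): one range value per channel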
@@ -562,7 +561,7 @@ def _no_equalize():


def _update_weights(original_module, new_value, attr='weight'):
-if isinstance(original_module, WeightBiasTuple):
+if isinstance(original_module, WeightBiasWrapper):
setattr(getattr(original_module, attr), 'data', new_value)
else:
setattr(original_module, attr, nn.Parameter(new_value))
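Illustration of the two branches (assuming WeightBiasWrapper and _update_weights are importable from brevitas.graph.equalize at this revision): a wrapped tensor such as MultiheadAttention.in_proj_weight is modified in place through .data, whereas an ordinary module gets its attribute rebound to a fresh nn.Parameter.

import torch
import torch.nn as nn

from brevitas.graph.equalize import WeightBiasWrapper, _update_weights

# Ordinary module: the attribute is replaced with a new nn.Parameter.
lin = nn.Linear(4, 4)
_update_weights(lin, torch.zeros(4, 4))
print(lin.weight.abs().sum().item())  # 0.0

# Wrapped tensor: the underlying parameter is updated in place via .data.
mha = nn.MultiheadAttention(embed_dim=4, num_heads=1)
_update_weights(WeightBiasWrapper(mha.in_proj_weight), torch.zeros_like(mha.in_proj_weight))
print(mha.in_proj_weight.abs().sum().item())  # 0.0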
@@ -645,7 +644,7 @@ def get_weight_sink(module):
transpose = lambda weight, axis: weight if axis == 0 else weight.transpose(0, 1)
if isinstance(module, nn.MultiheadAttention) and not hasattr(module, 'in_proj_weight'):
raise RuntimeError("Configuration for Multiheadattention not supported")
-weight = WeightBiasTuple(module.in_proj_weight).weight if isinstance(
+weight = WeightBiasWrapper(module.in_proj_weight).weight if isinstance(
module, nn.MultiheadAttention) else module.weight
axis = _get_input_axis(module)
weight = transpose(weight, axis)
