Fix (graph/equalize): refactor for act equalization #787

Merged: 1 commit, Dec 18, 2023
212 changes: 94 additions & 118 deletions src/brevitas/graph/equalize.py
@@ -1,6 +1,8 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from abc import ABC
from abc import abstractmethod
from dataclasses import dataclass
from dataclasses import field
from functools import partial
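
The first hunk pulls in `ABC` and `abstractmethod`, which the refactor uses to hoist logic shared by the two equalization classes into an abstract base class. As a minimal, self-contained sketch of that template-method pattern (hypothetical class names, not brevitas code):

```python
from abc import ABC, abstractmethod


class EqualizationBase(ABC):
    """Shared state and logic live here; subclasses fill in the variant steps."""

    def __init__(self, model):
        self.model = model

    @abstractmethod
    def setup(self):
        # Each variant decides how to attach its statistics hooks.
        ...


class LayerwiseVariant(EqualizationBase):

    def setup(self):
        print(f'attaching hooks to {type(self.model).__name__}')
```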
@@ -147,7 +149,7 @@ def __exit__(self, type, value, traceback):

def dict_name_to_module(model, regions):
name_to_module: Dict[str, torch.nn.Module] = {}
# name_set = {name for region in regions for module_set in region for name in module_set}

name_set = set()
for region in regions:
for name in region.srcs:
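
For context, `dict_name_to_module` maps the module names referenced by each region to live module objects. A hedged standalone sketch consistent with the snippet above (`Region` is a stand-in for the brevitas dataclass):

```python
import torch
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class Region:  # stand-in; the real dataclass also carries acts
    srcs: List[str] = field(default_factory=list)
    sinks: List[str] = field(default_factory=list)


def dict_name_to_module(model: torch.nn.Module, regions) -> Dict[str, torch.nn.Module]:
    # Collect every name a region refers to, then resolve it on the model.
    name_set = set()
    for region in regions:
        name_set.update(region.srcs)
        name_set.update(region.sinks)
    return {name: mod for name, mod in model.named_modules() if name in name_set}
```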
@@ -689,11 +691,67 @@ def apply(self,
return graph_model


class LayerwiseActivationEqualization(GraphTransform):
class ActivationEqualization(GraphTransform, ABC):

def __init__(self, model, scale_computation_type: str = 'maxabs'):
super(LayerwiseActivationEqualization, self).__init__()
def __init__(
self, model: Union[nn.Module, GraphModule], scale_computation_type: str = 'maxabs'):
self.model = model
self.scale_computation_type = scale_computation_type

@abstractmethod
def setup(self):
pass

@abstractmethod
def insert_mul_node(self):
pass

def create_mul_node(self, scale, shape, axis, batch_dim=0):
broadcastable_shape = [1] * len(shape)
broadcastable_shape[axis] = shape[axis]
# Add Batch Dim
broadcastable_shape.insert(batch_dim, 1)
mul_factor = ScaleBias(
num_features=shape[axis], bias=False, runtime_shape=broadcastable_shape)
mul_factor.weight.data = scale
return mul_factor

def forward_stats_hook(self, module, *args, name, batch_dim=0, use_inp=True, **kwargs):
# Check for MHA Cross attention, and if found, skip it
kwargs.update(zip(module.forward.__code__.co_varnames[1:], args[:-1]))
if 'query' in kwargs and 'key' in kwargs and 'value' in kwargs:
if kwargs['query'].data_ptr() != kwargs['key'].data_ptr() != kwargs['value'].data_ptr():
self.float_act_map[name] = None
return

possible_input_kwargs = ['input', 'inp', 'query']
input_kwarg = [x for x in kwargs.keys() if x in possible_input_kwargs][0]
if use_inp:
x = kwargs[input_kwarg]
elif not use_inp:
x = args[-1]

# Extra check for batch_dim
if hasattr(x, 'names') and 'N' in x.names:
batch_dim = x.names.index('N')

self.batch_dim_act_map[name] = batch_dim

input_scales = self.scale_fn(x, dim=batch_dim)
if name not in self.float_act_map:
self.float_act_map[name] = input_scales
else:
self.float_act_map[name] = torch.max(self.float_act_map[name], input_scales)

def remove_hooks(self):
for hook in self.hooks:
ModuleInstanceToModuleInstance(hook, hook.module).apply(self.model)


class LayerwiseActivationEqualization(ActivationEqualization):

def __init__(self, model, scale_computation_type: str = 'maxabs'):
super(LayerwiseActivationEqualization, self).__init__(model, scale_computation_type)
self.float_act_map = {}
self.batch_dim_act_map = {}
self.hooks = []
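
The new base class centralizes `create_mul_node`, whose broadcastable-shape trick is easy to verify in isolation: the per-channel scale is viewed so that it multiplies along `axis` and broadcasts over every other dimension, including an inserted batch dimension.

```python
import torch


def broadcastable(shape, axis, batch_dim=0):
    out = [1] * len(shape)
    out[axis] = shape[axis]
    out.insert(batch_dim, 1)  # add room for the batch dimension
    return out


x = torch.randn(8, 16, 32)                           # (batch, channels, features)
scale = torch.rand(16)                               # one factor per channel
runtime_shape = broadcastable(x.shape[1:], axis=0)   # -> [1, 16, 1]
assert (x * scale.view(runtime_shape)).shape == x.shape
```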
@@ -703,7 +761,6 @@ def __init__(self, model, scale_computation_type: str = 'maxabs'):
self.find_module(model, regions)
self.regions = regions

self.scale_computation_type = scale_computation_type
if self.scale_computation_type == 'maxabs':
self.scale_fn = _channel_maxabs
elif self.scale_computation_type == 'range':
@@ -751,79 +808,34 @@ def apply(self, alpha):
alpha=alpha))
return scale_factors

def remove_hooks(self):
for hook in self.hooks:
ModuleInstanceToModuleInstance(hook, hook.module).apply(self.model)

def forward_stats_hook(self, module, *args, name, batch_dim=0, use_inp=True, **kwargs):
# Check for MHA Cross attention, and if found, skip it
kwargs.update(zip(module.forward.__code__.co_varnames[1:], args[:-1]))
if 'query' in kwargs and 'key' in kwargs and 'value' in kwargs:
if kwargs['query'].data_ptr() != kwargs['key'].data_ptr() != kwargs['value'].data_ptr():
self.float_act_map[name] = None
return

possible_input_kwargs = ['input', 'inp', 'query']
input_kwarg = [x for x in kwargs.keys() if x in possible_input_kwargs][0]
if use_inp:
x = kwargs[input_kwarg]
elif not use_inp:
x = args[-1]

# Extra check for batch_dim
if hasattr(x, 'names') and 'N' in x.names:
batch_dim = x.names.index('N')

self.batch_dim_act_map[name] = batch_dim

input_scales = self.scale_fn(x, dim=batch_dim)
if name not in self.float_act_map:
self.float_act_map[name] = input_scales
else:
self.float_act_map[name] = torch.max(self.float_act_map[name], input_scales)

def insert_mul_node(self, scale, shape, axis, region, batch_dim=0):
broadcastable_shape = [1] * len(shape)
broadcastable_shape[axis] = shape[axis]
# Add Batch Dim
broadcastable_shape.insert(batch_dim, 1)
mul_factor = ScaleBias(
num_features=shape[axis], bias=False, runtime_shape=broadcastable_shape)
mul_factor.weight.data = scale
mul_factor = self.create_mul_node(scale, shape, axis, batch_dim)
rewriter = ModuleInstanceToModuleInstance(
region, EqualizedModule(scale_module=mul_factor, layer=region))
rewriter.apply(self.model)


class GraphActivationEqualization(GraphTransform):
class GraphActivationEqualization(ActivationEqualization):

def __init__(
self, model, add_mul_node, layerwise=False, scale_computation_type: str = 'maxabs'):
super(GraphActivationEqualization, self).__init__()
self.graph_model = model
self,
model: GraphModule,
add_mul_node: bool = False,
scale_computation_type: str = 'maxabs'):
super(GraphActivationEqualization, self).__init__(model, scale_computation_type)
self.float_act_map = {}
self.batch_dim_act_map = {}
self.hooks = []
self.layerwise = layerwise
if self.layerwise:
self.add_mul_node = True
else:
self.add_mul_node = add_mul_node
if self.layerwise:
regions = []
self.find_module(model, regions)
self.regions = regions
else:
self.regions = _extract_regions(model, add_mul_node=add_mul_node, return_acts=True)
self.add_mul_node = add_mul_node
self.regions = _extract_regions(model, add_mul_node=add_mul_node, return_acts=True)

self.scale_computation_type = scale_computation_type
if self.scale_computation_type == 'maxabs':
self.scale_fn = _channel_maxabs
elif self.scale_computation_type == 'range':
self.scale_fn = _channel_range

def setup(self):
name_to_module = dict_name_to_module(self.graph_model, self.regions)
name_to_module = dict_name_to_module(self.model, self.regions)
# Select only regions with activation to equalize through.
# If a region has multiple scale varying activation, must also be dropped
# because we can't propagate scaling factors
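
The comment above reflects that a scale factor can only be pushed through an activation that commutes with positive rescaling. A hedged illustration of the drop rule; the membership set below is an assumption, not the library's actual list:

```python
import torch.nn as nn

# Assumed examples: these activations do NOT satisfy f(s*x) == s*f(x) for s > 0,
# so a scale applied before them cannot be propagated past them.
_scale_varying_activations = (nn.Sigmoid, nn.Tanh, nn.GELU, nn.Softmax)


def region_droppable(act_modules) -> bool:
    return any(isinstance(m, _scale_varying_activations) for m in act_modules)


assert not region_droppable([nn.ReLU()])         # ReLU is positively homogeneous
assert region_droppable([nn.ReLU(), nn.Tanh()])  # Tanh varies the scale
```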
@@ -835,29 +847,30 @@ def setup(self):
_scale_varying_activations)
for act_name in region.acts]):
regions_to_drop.append(region)
else:
# We assume that the entire region has a unique batch_dim
batch_dim = 0
region_to_search = region.sinks if len(region.acts) == 0 else region.acts
for name in region.srcs + region.sinks:
module = name_to_module[name]
if hasattr(module, 'batch_first'):
batch_dim = 0 if module.batch_first else 1
for name in region_to_search:
act_module = name_to_module[name]
use_inp = True if region_to_search == region.sinks else False
hook_fn = partial(
self.forward_stats_hook, name=name, batch_dim=batch_dim, use_inp=use_inp)
new_instance = KwargsForwardHook(act_module, hook_fn)
ModuleInstanceToModuleInstance(act_module, new_instance).apply(self.graph_model)
self.hooks.append(new_instance)
continue

# We assume that the entire region has a unique batch_dim
batch_dim = 0
region_to_search = region.sinks if len(region.acts) == 0 else region.acts
for name in region.srcs + region.sinks:
module = name_to_module[name]
if hasattr(module, 'batch_first'):
batch_dim = 0 if module.batch_first else 1
for name in region_to_search:
module = name_to_module[name]
use_inp = True if region_to_search == region.sinks else False
hook_fn = partial(
self.forward_stats_hook, name=name, batch_dim=batch_dim, use_inp=use_inp)
new_instance = KwargsForwardHook(module, hook_fn)
ModuleInstanceToModuleInstance(module, new_instance).apply(self.model)
self.hooks.append(new_instance)

self.regions = [x for x in self.regions if x not in regions_to_drop]

def apply(self, alpha):
scale_factors = []
self.remove_hooks()
name_to_module = dict_name_to_module(self.graph_model, self.regions)
name_to_module = dict_name_to_module(self.model, self.regions)
for region in self.regions:
region_to_search = region.sinks if len(region.acts) == 0 else region.acts
if any([self.float_act_map[name] is None for name in region_to_search]):
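
The `None` entries skipped here come from the cross-attention check in `forward_stats_hook`. Its `data_ptr()` test can be reproduced standalone; in self-attention the same tensor is passed as query, key, and value, so the storage pointers match:

```python
import torch


def is_cross_attention(query, key, value) -> bool:
    # Mirrors the chained comparison in the hook:
    # query != key AND key != value (by storage pointer).
    return query.data_ptr() != key.data_ptr() != value.data_ptr()


q = torch.randn(2, 4, 8)
assert not is_cross_attention(q, q, q)                         # self-attention
assert is_cross_attention(q, q.clone(), torch.randn(2, 4, 8))  # cross-attention
```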
@@ -877,7 +890,7 @@ def apply(self, alpha):
# Even though we iterate, this list will always have a single element by definition
list_of_insert_mul_node_fn = []
for act_name in region.acts:
act_node = get_node(self.graph_model, act_name)
act_node = get_node(self.model, act_name)
list_of_insert_mul_node_fn.append(
partial(
self.insert_mul_node,
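
The `partial(...)` above defers the graph mutation until the region's scale is known: everything except the scale is bound at hook-collection time. In isolation, with a hypothetical stand-in function:

```python
from functools import partial


def insert_mul_node(scale, shape, axis, act_node, batch_dim=0):
    print(f'inserting scale {scale} of shape {shape} after {act_node}')


deferred = partial(insert_mul_node, shape=(16,), axis=0, act_node='act0')
deferred(scale=0.5)  # -> inserting scale 0.5 of shape (16,) after act0
```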
@@ -895,46 +908,9 @@

return scale_factors

def remove_hooks(self):
for hook in self.hooks:
ModuleInstanceToModuleInstance(hook, hook.module).apply(self.graph_model)

def forward_stats_hook(self, module, *args, name, batch_dim=0, use_inp=True, **kwargs):
# Check for MHA Cross attention, and if found, skip it
kwargs.update(zip(module.forward.__code__.co_varnames[1:], args[:-1]))
if 'query' in kwargs and 'key' in kwargs and 'value' in kwargs:
if kwargs['query'].data_ptr() != kwargs['key'].data_ptr() != kwargs['value'].data_ptr():
self.float_act_map[name] = None
return

possible_input_kwargs = ['input', 'inp', 'query']
input_kwarg = [x for x in kwargs.keys() if x in possible_input_kwargs][0]
if use_inp:
x = kwargs[input_kwarg]
elif not use_inp:
x = args[-1]

# Extra check for batch_dim
if hasattr(x, 'names') and 'N' in x.names:
batch_dim = x.names.index('N')

self.batch_dim_act_map[name] = batch_dim

input_scales = self.scale_fn(x, dim=batch_dim)
if name not in self.float_act_map:
self.float_act_map[name] = input_scales
else:
self.float_act_map[name] = torch.max(self.float_act_map[name], input_scales)

def insert_mul_node(self, scale, shape, axis, act_node, batch_dim=0):
broadcastable_shape = [1] * len(shape)
broadcastable_shape[axis] = shape[axis]
# Add Batch Dim
broadcastable_shape.insert(batch_dim, 1)
mul_factor = ScaleBias(
num_features=shape[axis], bias=False, runtime_shape=broadcastable_shape)
mul_factor.weight.data = scale
mul_factor = self.create_mul_node(scale, shape, axis, batch_dim)
mul_factor_name = act_node.name + 'act_eq_mul'
self.graph_model.add_module(mul_factor_name, mul_factor)
self.model.add_module(mul_factor_name, mul_factor)
rewriter = InsertModuleCallAfter(mul_factor_name, act_node)
rewriter.apply(self.graph_model)
rewriter.apply(self.model)
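
Putting the refactored API together, a hedged end-to-end usage sketch: the symbolic-trace import path and the toy model are assumptions, while the constructor/`setup`/`apply` flow follows this diff.

```python
import torch
import torch.nn as nn
from brevitas.fx import brevitas_symbolic_trace  # assumed import path
from brevitas.graph.equalize import GraphActivationEqualization

model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))
graph_model = brevitas_symbolic_trace(model)

eq = GraphActivationEqualization(
    graph_model, add_mul_node=True, scale_computation_type='maxabs')
eq.setup()                           # attach statistics hooks to the regions
for _ in range(8):                   # calibration passes populate float_act_map
    graph_model(torch.randn(4, 16))
scale_factors = eq.apply(alpha=0.5)  # removes hooks and rewrites the graph
```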