
Commit

fixed a bug in index mapping
VainF committed Jul 18, 2023
1 parent 44b625a commit 5370954
Showing 7 changed files with 70 additions and 26 deletions.
23 changes: 22 additions & 1 deletion examples/torchvision_models/torchvision_pruning.py
@@ -203,6 +203,16 @@ def my_prune(model, example_inputs, output_transform, model_name):
print("==============Before pruning=================")
print("Model Name: {}".format(model_name))
print(model)

layer_channel_cfg = {}
for module in model.modules():
if module not in pruner.ignored_layers:
#print(module)
if isinstance(module, nn.Conv2d):
layer_channel_cfg[module] = module.out_channels
elif isinstance(module, nn.Linear):
layer_channel_cfg[module] = module.out_features

pruner.step()
if isinstance(
model, VisionTransformer
@@ -223,7 +233,18 @@ def my_prune(model, example_inputs, output_transform, model_name):
if output_transform:
out = output_transform(out)
print("{} Pruning: ".format(model_name))
print(" Params: %s => %s" % (ori_size, tp.utils.count_params(model)))
params_after_prune = tp.utils.count_params(model)
print(" Params: %s => %s" % (ori_size, params_after_prune))

if 'rcnn' not in model_name and model_name!='ssdlite320_mobilenet_v3_large': # RCNN may return 0 proposals, making some layers unreachable during tracing.
for module, ch in layer_channel_cfg.items():
if isinstance(module, nn.Conv2d):
#print(module.out_channels, layer_channel_cfg[module])
assert int(0.5*layer_channel_cfg[module]) == module.out_channels
elif isinstance(module, nn.Linear):
#print(module.out_features, layer_channel_cfg[module])
assert int(0.5*layer_channel_cfg[module]) == module.out_features

if isinstance(out, (dict,list,tuple)):
print(" Output:")
for o in tp.utils.flatten_as_list(out):
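The new block above snapshots every prunable layer's width before pruning, then asserts that the 0.5 pruning ratio was applied exactly. A condensed, self-contained sketch of the same check, assuming a magnitude pruner configured with ch_sparsity=0.5 as in the Torch-Pruning 1.x examples (the model and ignored_layers choices are illustrative):

import torch
import torch.nn as nn
import torch_pruning as tp
from torchvision.models import resnet18

model = resnet18()
example_inputs = torch.randn(1, 3, 224, 224)
imp = tp.importance.MagnitudeImportance(p=2)
pruner = tp.pruner.MagnitudePruner(
    model, example_inputs, importance=imp,
    ch_sparsity=0.5, ignored_layers=[model.fc],  # keep the classifier intact
)

# Snapshot the original width of every layer the pruner may touch.
layer_channel_cfg = {}
for m in model.modules():
    if m in pruner.ignored_layers:
        continue
    if isinstance(m, nn.Conv2d):
        layer_channel_cfg[m] = m.out_channels
    elif isinstance(m, nn.Linear):
        layer_channel_cfg[m] = m.out_features

pruner.step()

# Every recorded layer should now be exactly half as wide.
for m, ch in layer_channel_cfg.items():
    new_ch = m.out_channels if isinstance(m, nn.Conv2d) else m.out_features
    assert int(0.5 * ch) == new_ch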
4 changes: 2 additions & 2 deletions tests/test_customized_layer.py
@@ -91,14 +91,14 @@ def test_customization():

my_linear_pruner = MyLinearPruner()
DG.register_customized_layer(
model.fc2, my_linear_pruner
nn.Linear, my_linear_pruner
)

# 2. Build dependency graph
DG.build_dependency(model, example_inputs=torch.randn(1,128))

# 3. get a pruning group according to the dependency graph. idxs is the indices of pruned filters.
pruning_group = DG.get_pruning_group( model.fc1, tp.prune_linear_out_channels, idxs=[0, 1, 6] )
pruning_group = DG.get_pruning_group( model.fc1, my_linear_pruner.prune_out_channels, idxs=[0, 1, 6] )
print(pruning_group)

# 4. execute this group (prune the model)
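This fix switches register_customized_layer from per-instance to per-type registration, and the pruning callback now comes from the registered pruner itself. A minimal usage sketch under the new contract (MyLinearPruner is the custom pruner defined earlier in this test file; the toy model is illustrative):

import torch
import torch.nn as nn
import torch_pruning as tp

model = nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, 10))

DG = tp.DependencyGraph()
my_linear_pruner = MyLinearPruner()
# Keyed by layer type: every nn.Linear in the model is now handled by
# my_linear_pruner, not just a single registered instance.
DG.register_customized_layer(nn.Linear, my_linear_pruner)
DG.build_dependency(model, example_inputs=torch.randn(1, 128))

group = DG.get_pruning_group(model[0], my_linear_pruner.prune_out_channels, idxs=[0, 1, 6])
group.prune()  # removes output channels 0, 1, 6 and the matching input channels downstream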
32 changes: 23 additions & 9 deletions tests/test_importance.py → tests/test_importance_reduction.py
@@ -18,38 +18,52 @@ def test_imp():
rand_imp = random_importance(pruning_group)
print("Random: ", rand_imp)

magnitude_importance = tp.importance.MagnitudeImportance(p=1, group_reduction=None)
magnitude_importance = tp.importance.MagnitudeImportance(p=1, group_reduction=None, normalizer=None)
mag_imp_raw = magnitude_importance(pruning_group)
print("L-1 Norm, No Reduction: ", mag_imp_raw)

magnitude_importance = tp.importance.MagnitudeImportance(p=1)
magnitude_importance = tp.importance.MagnitudeImportance(p=1, normalizer=None)
mag_imp = magnitude_importance(pruning_group)
print("L-1 Norm, Group Mean: ", mag_imp)
assert torch.allclose(mag_imp, mag_imp_raw.mean(0))

magnitude_importance = tp.importance.MagnitudeImportance(p=2, group_reduction=None)
magnitude_importance = tp.importance.MagnitudeImportance(p=2, group_reduction=None, normalizer=None)
mag_imp_raw = magnitude_importance(pruning_group)
print("L-2 Norm, No Reduction: ", mag_imp_raw)

magnitude_importance = tp.importance.MagnitudeImportance(p=2)
magnitude_importance = tp.importance.MagnitudeImportance(p=2, normalizer=None)
mag_imp = magnitude_importance(pruning_group)
print("L-2 Norm, Group Mean: ", mag_imp)
assert torch.allclose(mag_imp, mag_imp_raw.mean(0))

magnitude_importance = tp.importance.MagnitudeImportance(p=2, group_reduction='sum')
magnitude_importance = tp.importance.MagnitudeImportance(p=2, group_reduction='sum', normalizer=None)
mag_imp = magnitude_importance(pruning_group)
print("L-2 Norm, Group Sum: ", mag_imp)
assert torch.allclose(mag_imp, mag_imp_raw.sum(0))

bn_scale_importance = tp.importance.BNScaleImportance()
magnitude_importance = tp.importance.MagnitudeImportance(p=2, group_reduction='max', normalizer=None)
mag_imp = magnitude_importance(pruning_group)
print("L-2 Norm, Group Max: ", mag_imp)
assert torch.allclose(mag_imp, mag_imp_raw.max(0)[0])

magnitude_importance = tp.importance.MagnitudeImportance(p=2, group_reduction='gate', normalizer=None)
mag_imp = magnitude_importance(pruning_group)
print("L-2 Norm, Group Gate: ", mag_imp)
assert torch.allclose(mag_imp, mag_imp_raw[-1])

magnitude_importance = tp.importance.MagnitudeImportance(p=2, group_reduction='prod', normalizer=None)
mag_imp = magnitude_importance(pruning_group)
print("L-2 Norm, Group Prod: ", mag_imp)
print(mag_imp, torch.prod(mag_imp_raw, dim=0))
assert torch.allclose(mag_imp, torch.prod(mag_imp_raw, dim=0))

bn_scale_importance = tp.importance.BNScaleImportance(normalizer=None)
bn_imp = bn_scale_importance(pruning_group)
print("BN Scaling, Group mean: ", bn_imp)

lamp_importance = tp.importance.LAMPImportance()
lamp_importance = tp.importance.LAMPImportance(normalizer=None)
lamp_imp = lamp_importance(pruning_group)
print("LAMP: ", lamp_imp)
assert torch.allclose(torch.argsort(mag_imp), mag_imp_raw.mean(0))


if __name__=='__main__':
test_imp()
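The added assertions pin down what each group_reduction mode computes over the stacked per-layer scores. A toy example with two made-up importance vectors makes the expected outputs concrete:

import torch

group_imp = torch.stack([
    torch.tensor([1.0, 2.0, 3.0]),  # scores from layer 1
    torch.tensor([4.0, 0.5, 6.0]),  # scores from layer 2
])

print(group_imp.mean(0))    # 'mean' -> tensor([2.5000, 1.2500, 4.5000])
print(group_imp.sum(0))     # 'sum'  -> tensor([5.0000, 2.5000, 9.0000])
print(group_imp.max(0)[0])  # 'max'  -> tensor([4.0000, 2.0000, 6.0000])
print(group_imp.prod(0))    # 'prod' -> tensor([ 4.0000,  1.0000, 18.0000])
print(group_imp[-1])        # 'gate' -> the last layer's scores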
1 change: 1 addition & 0 deletions torch_pruning/_helpers.py
@@ -50,6 +50,7 @@ def __init__(self, stride=1, reverse=False):

def __call__(self, idxs: _HybridIndex):
new_idxs = []

if self.reverse == True:
for i in idxs:
new_idxs.append( _HybridIndex( idx = (i.idx // self._stride), root_idx=i.root_idx ) )
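For context, _FlattenIndexMapping translates pruning indices across a flatten op: a conv channel owns a contiguous block of flattened features, and the reverse direction divides by the stride, as in the hunk above. A rough illustration of the stride arithmetic with plain ints instead of _HybridIndex (the forward direction shown here is an assumption about the non-reverse branch):

stride = 4  # e.g. H*W = 2*2 spatial positions per channel after the last pooling

def forward_map(channel_idxs):
    # conv channel c -> flattened feature slots [c*stride, (c+1)*stride)
    return [c * stride + k for c in channel_idxs for k in range(stride)]

def reverse_map(feature_idxs):
    # flattened feature i -> conv channel i // stride
    return sorted({i // stride for i in feature_idxs})

print(forward_map([0, 2]))                      # [0, 1, 2, 3, 8, 9, 10, 11]
print(reverse_map([0, 1, 2, 3, 8, 9, 10, 11]))  # [0, 2]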
20 changes: 10 additions & 10 deletions torch_pruning/dependency.py
@@ -424,7 +424,6 @@ def get_pruning_group(
module: nn.Module,
pruning_fn: typing.Callable,
idxs: typing.Sequence[int],
return_root_idxs: bool = False,
) -> Group:
"""Get the pruning group of pruning_fn.
Args:
@@ -470,7 +469,7 @@ def _fix_dependency_graph_non_recursive(dep, idxs, *args):
for mapping in new_dep.index_mapping:
if mapping is not None:
new_indices = mapping(new_indices)
#print(new_dep, new_dep.index_mapping)

#print(len(new_indices))
#print()
if len(new_indices) == 0:
@@ -493,14 +492,17 @@ def _fix_dependency_graph_non_recursive(dep, idxs, *args):
merged_group.add_and_merge(dep, idxs)
merged_group._DG = self
for i in range(len(merged_group)):
idxs = _helpers.to_plain_idxs(merged_group[i].idxs)
hybrid_idxs = merged_group[i].idxs
idxs = _helpers.to_plain_idxs(hybrid_idxs)
root_idxs = _helpers.to_root_idxs(hybrid_idxs)
merged_group[i] = GroupItem(merged_group[i].dep, idxs) # transform _HybridIndex to plain index
merged_group[i].root_idxs = _helpers.to_root_idxs(merged_group[i].idxs) # add root_idxs
merged_group[i].root_idxs = root_idxs
return merged_group

def get_all_groups(self, ignored_layers=[], root_module_types=(ops.TORCH_CONV, ops.TORCH_LINEAR)):
visited_layers = []
ignored_layers = ignored_layers+self.IGNORED_LAYERS

for m in list(self.module2node.keys()):
if m in ignored_layers:
continue
@@ -518,6 +520,7 @@ def get_all_groups(self, ignored_layers=[], root_module_types=(ops.TORCH_CONV, ops.TORCH_LINEAR)):
layer_channels = pruner.get_out_channels(m)
group = self.get_pruning_group(
m, pruner.prune_out_channels, list(range(layer_channels)))

prunable_group = True
for dep, _ in group:
module = dep.target.module
@@ -531,8 +534,6 @@ def get_all_groups(self, ignored_layers=[], root_module_types=(ops.TORCH_CONV, ops.TORCH_LINEAR)):

def get_pruner_of_module(self, module: nn.Module):
p = self.CUSTOMIZED_PRUNERS.get(module.__class__, None) # customized pruners for a specific layer type
if p is None:
p = self.CUSTOMIZED_PRUNERS.get(module, None) # customized pruners for a specific layer instance
if p is None:
p = self.REGISTERED_PRUNERS.get(ops.module2type(module), None) # standard pruners
return p
@@ -696,12 +697,11 @@ def _record_grad_fn(module, inputs, outputs):

# Register hooks for prunable modules
registered_types = tuple(ops.type2class(
t) for t in self.REGISTERED_PRUNERS.keys()) + tuple(t for t in self.CUSTOMIZED_PRUNERS.keys() if not isinstance(t, torch.nn.Module)) # standard pruners + customized pruners for a specific layer type
registered_instances = tuple(instance for instance in self.CUSTOMIZED_PRUNERS.keys() if isinstance(instance, torch.nn.Module)) # customized pruners for a specific layer instance
t) for t in self.REGISTERED_PRUNERS.keys()) + tuple(self.CUSTOMIZED_PRUNERS.keys())
hooks = [
m.register_forward_hook(_record_grad_fn)
for m in model.modules()
if ( (m not in self.IGNORED_LAYERS) and (isinstance(m, registered_types) or (m in registered_instances) ) )
if (isinstance(m, registered_types) and m not in self.IGNORED_LAYERS)
]

# Feed forward to record gradient functions of prunable modules
@@ -794,7 +794,7 @@ def create_node_if_not_exists(grad_fn):
name=self._module2name.get(module, None),
)
if (
type(module) in self.CUSTOMIZED_PRUNERS or module in self.CUSTOMIZED_PRUNERS
type(module) in self.CUSTOMIZED_PRUNERS
): # mark it as a customized layer
node.type = ops.OPTYPE.CUSTOMIZED
module2node[module] = node
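After this change each GroupItem carries plain integer idxs plus the root_idxs that point back to the group's root layer, which is what the importance reduction consumes. A short inspection sketch (resnet18 and the index list are arbitrary; the API calls follow the library's README):

import torch
import torch_pruning as tp
from torchvision.models import resnet18

model = resnet18()
DG = tp.DependencyGraph().build_dependency(model, example_inputs=torch.randn(1, 3, 224, 224))
group = DG.get_pruning_group(model.conv1, tp.prune_conv_out_channels, idxs=[2, 6, 9])
for item in group:
    # item.idxs are plain ints; item.root_idxs map each back to conv1's channels
    print(item.dep, item.idxs, item.root_idxs)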
14 changes: 10 additions & 4 deletions torch_pruning/importance.py
@@ -59,17 +59,22 @@ def _normalize(self, group_importance, normalizer):

def _reduce(self, group_imp: typing.List[torch.Tensor], group_idxs: typing.List[typing.List[int]]):
if len(group_imp) == 0: return group_imp
reduced_imp = torch.zeros_like(group_imp[0])
if self.group_reduction == 'prod':
reduced_imp = torch.ones_like(group_imp[0])
elif self.group_reduction == 'max':
reduced_imp = torch.ones_like(group_imp[0]) * -99999
else:
reduced_imp = torch.zeros_like(group_imp[0])

for i, (imp, root_idxs) in enumerate(zip(group_imp, group_idxs)):
if self.group_reduction == "sum" or self.group_reduction == "mean":
reduced_imp.scatter_add_(0, torch.tensor(root_idxs, device=imp.device), imp) # accumulated importance
elif self.group_reduction == "max": # keep the max importance
selected_imp = torch.select(reduced_imp, 0, torch.tensor(root_idxs, device=imp.device))
torch.max(selected_imp, imp, out=selected_imp)
selected_imp = torch.index_select(reduced_imp, 0, torch.tensor(root_idxs, device=imp.device))
selected_imp = torch.maximum(input=selected_imp, other=imp)
reduced_imp.scatter_(0, torch.tensor(root_idxs, device=imp.device), selected_imp)
elif self.group_reduction == "prod": # product of importance
selected_imp = torch.select(reduced_imp, 0, torch.tensor(root_idxs, device=imp.device))
selected_imp = torch.index_select(reduced_imp, 0, torch.tensor(root_idxs, device=imp.device))
torch.mul(selected_imp, imp, out=selected_imp)
reduced_imp.scatter_(0, torch.tensor(root_idxs, device=imp.device), selected_imp)
elif self.group_reduction == 'first':
@@ -82,6 +87,7 @@ def _reduce(self, group_imp: typing.List[torch.Tensor], group_idxs: typing.List[typing.List[int]]):
reduced_imp = torch.stack(group_imp, dim=0) # no reduction
else:
raise NotImplementedError

if self.group_reduction == "mean":
reduced_imp /= len(group_imp)
return reduced_imp
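Two things changed in _reduce: the accumulator now starts from the reduction's neutral element (ones for 'prod', a large negative constant for 'max') instead of zeros, and torch.select (which slices a single index) was replaced by torch.index_select (which gathers a list of indices, as the root_idxs lists require). The initialization issue in toy form:

import torch

imp = torch.tensor([0.5, 2.0, -1.0])

acc = torch.zeros_like(imp)            # old init
print(acc * imp)                       # 'prod' collapses to tensor([0., 0., -0.])

acc = torch.ones_like(imp)             # new init for 'prod'
print(acc * imp)                       # tensor([ 0.5000,  2.0000, -1.0000])

acc = torch.ones_like(imp) * -99999    # new init for 'max'
print(torch.maximum(acc, imp))         # negative scores survive: tensor([ 0.5000,  2.0000, -1.0000])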
2 changes: 2 additions & 0 deletions torch_pruning/pruner/algorithms/metapruner.py
@@ -204,6 +204,7 @@ def prune_local(self) -> typing.Generator:
for group in self.DG.get_all_groups(ignored_layers=self.ignored_layers, root_module_types=self.root_module_types):
# check pruning rate
if self._check_sparsity(group):

module = group[0][0].target.module
pruning_fn = group[0][0].handler

@@ -232,6 +233,7 @@ def prune_local(self) -> typing.Generator:
[pruning_idxs+group_size*i for i in range(ch_groups)], 0)
group = self.DG.get_pruning_group(
module, pruning_fn, pruning_idxs.tolist())

if self.DG.check_pruning_group(group):
yield group

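For context on the surrounding hunk: when a layer is split into ch_groups channel groups, the indices selected within one group are replicated across all groups before the pruning group is built. A toy illustration with arbitrary numbers:

import torch

ch_groups, group_size = 4, 8           # e.g. a grouped conv: 4 groups of 8 channels
pruning_idxs = torch.tensor([1, 3])    # channels chosen within the first group
pruning_idxs = torch.cat([pruning_idxs + group_size * i for i in range(ch_groups)], 0)
print(pruning_idxs)                    # tensor([ 1,  3,  9, 11, 17, 19, 25, 27])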
