From 445120c08f2bebe87cf50d0ffdf815a73382063c Mon Sep 17 00:00:00 2001 From: Fabian Grob Date: Wed, 6 Dec 2023 15:24:54 +0000 Subject: [PATCH] Feat (Channel-Splitting): adds grid_aware option --- src/brevitas/graph/quantize.py | 2 + .../ptq_algorithms/channel_splitting.py | 101 ++++++++---------- .../ptq/ptq_evaluate.py | 7 +- .../brevitas/graph/test_channel_splitting.py | 10 +- 4 files changed, 59 insertions(+), 61 deletions(-) diff --git a/src/brevitas/graph/quantize.py b/src/brevitas/graph/quantize.py index 5acac7609..9545e466f 100644 --- a/src/brevitas/graph/quantize.py +++ b/src/brevitas/graph/quantize.py @@ -267,6 +267,7 @@ def preprocess_for_quantize( equalize_scale_computation: str = 'maxabs', channel_splitting=False, channel_splitting_ratio=0.02, + channel_splitting_grid_aware=False, channel_splitting_criterion: str = 'maxabs'): training_state = model.training @@ -292,6 +293,7 @@ def preprocess_for_quantize( if channel_splitting: model = ChannelSplitting( split_ratio=channel_splitting_ratio, + grid_aware=channel_splitting_grid_aware, split_criterion=channel_splitting_criterion).apply(model) model.train(training_state) return model diff --git a/src/brevitas/ptq_algorithms/channel_splitting.py b/src/brevitas/ptq_algorithms/channel_splitting.py index 70c193393..49ec4ea17 100644 --- a/src/brevitas/ptq_algorithms/channel_splitting.py +++ b/src/brevitas/ptq_algorithms/channel_splitting.py @@ -1,5 +1,6 @@ import math from typing import Dict, List, Set, Tuple, Union +import warnings import torch import torch.nn as nn @@ -8,8 +9,6 @@ from brevitas.graph.base import GraphTransform from brevitas.graph.equalize import _extract_regions -_batch_norm = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d) - def _channels_maxabs(module, splits_per_layer, split_input): # works for Conv2d and Linear @@ -44,15 +43,31 @@ def _channels_to_split( else: # if we split input channels, each module has to split the whole budget splits_per_layer = total_splits - assert splits_per_layer > 0, f"No channels to split in {modules} with split_rati {split_ratio}!" + + if splits_per_layer == 0: + warnings.warn(f'No splits for {sources}, increasing split_ratio could help.') module_to_channels = {} if split_criterion == 'maxabs': for module in modules: module_to_channels[module] = _channels_maxabs(module, splits_per_layer, split_input) - # return dict with modules as key and channels to split as value - return module_to_channels + # return tensor with the indices to split + channels_to_split = torch.cat(list(module_to_channels.values())) + return torch.unique(channels_to_split) + + +def _split_single_channel(channel, grid_aware: bool, split_factor: float): + if split_factor == 1: + # duplicates the channel + return channel, channel + + if grid_aware: + slice1 = channel - 0.5 + slice2 = channel + 0.5 + return slice1 * split_factor, slice2 * split_factor + else: + return channel * split_factor, channel * split_factor def _split_channels( @@ -63,96 +78,70 @@ def _split_channels( Can also be used to duplicate a channel, just set split_factor to 1. Returns: None """ - # change it to .data attribute weight = module.weight.data bias = module.bias.data if module.bias is not None else None - if isinstance(module, _batch_norm): - running_mean = module.running_mean.data - running_var = module.running_var.data for id in channels_to_split: if isinstance(module, torch.nn.Conv2d): # there are four dimensions: [OC, IC, k, k] if split_input: - channel = weight[:, id:id + 1, :, :] * split_factor - weight = torch.cat( - (weight[:, :id, :, :], channel, channel, weight[:, id + 1:, :, :]), dim=1) + channel = weight[:, id:id + 1, :, :] + slice1, slice2 = _split_single_channel(channel, grid_aware, split_factor) + weight = torch.cat((weight[:, :id, :, :], slice1, slice2, weight[:, id + 1:, :, :]), + dim=1) module.in_channels += 1 else: - # split output - channel = weight[id:id + 1, :, :, :] * split_factor - # duplicate channel - weight = torch.cat( - (weight[:id, :, :, :], channel, channel, weight[id + 1:, :, :, :]), dim=0) + channel = weight[id:id + 1, :, :, :] + slice1, slice2 = _split_single_channel(channel, grid_aware, split_factor) + weight = torch.cat((weight[:id, :, :, :], slice1, slice2, weight[id + 1:, :, :, :]), + dim=0) module.out_channels += 1 elif isinstance(module, torch.nn.Linear): # there are two dimensions: [OC, IC] if split_input: - # simply duplicate channel - channel = weight[:, id:id + 1] * split_factor - weight = torch.cat((weight[:, :id], channel, channel, weight[:, id + 1:]), dim=1) + channel = weight[:, id:id + 1] + slice1, slice2 = _split_single_channel(channel, grid_aware, split_factor) + weight = torch.cat((weight[:, :id], slice1, slice2, weight[:, id + 1:]), dim=1) module.in_features += 1 else: - # split output - channel = weight[id:id + 1, :] * split_factor - weight = torch.cat((weight[:id, :], channel, channel, weight[id + 1:, :]), dim=0) + channel = weight[id:id + 1, :] + slice1, slice2 = _split_single_channel(channel, grid_aware, split_factor) + weight = torch.cat((weight[:id, :], slice1, slice2, weight[id + 1:, :]), dim=0) module.out_features += 1 - elif isinstance(module, _batch_norm): - # bach norm is 1d - channel = weight[id:id + 1] * split_factor - weight = torch.cat((weight[:id], channel, channel, weight[id + 1:])) - # also split running_mean and running_var - mean = running_mean[id:id + 1] * split_factor - running_mean = torch.cat((running_mean[:id], mean, mean, running_mean[id + 1:])) - - var = running_var[id:id + 1] * split_factor - running_var = torch.cat((running_var[:id], var, var, running_var[id + 1:])) - - module.num_features += 1 - if bias is not None and not split_input: channel = bias[id:id + 1] * split_factor bias = torch.cat((bias[:id], channel, channel, bias[id + 1:])) - # setting the weights as the new data module.weight.data = weight if bias is not None: module.bias.data = bias - if isinstance(module, _batch_norm): - module.running_mean.data = running_mean - module.running_var.data = running_var def _split_channels_region( sources: List[nn.Module], sinks: List[nn.Module], - modules_to_split: Dict[nn.Module, torch.tensor], + channels_to_split: torch.tensor, split_input: bool, grid_aware: bool = False): - # we are getting a dict[Module, channels to split] # splitting output channels # concat all channels that are split so we can duplicate those in the input channels later if not split_input: - channels = torch.cat(list(modules_to_split.values())) - for module in modules_to_split.keys(): - _split_channels(module, channels, grid_aware=grid_aware) - # get all the channels that we have to duplicate - channels = torch.cat(list(modules_to_split.values())) + for module in sources: + _split_channels(module, channels_to_split, grid_aware=grid_aware) for module in sinks: # then duplicate the input_channels for all modules in the sink - _split_channels(module, channels, grid_aware=False, split_factor=1, split_input=True) + _split_channels( + module, channels_to_split, grid_aware=False, split_factor=1, split_input=True) else: # what if we split input channels of the sinks, which channels of the OC srcs have to duplicated? - for module, channels in modules_to_split.items(): - _split_channels(module, channels, grid_aware=grid_aware) + for module in sources: + _split_channels(module, channels_to_split, grid_aware=grid_aware) # TODO duplicating the channels in the output channels of the sources could be tricky - channels_to_duplicate = torch.cat(list(modules_to_split.values())) for module in sources: # then duplicate the input_channels for all modules in the sink - _split_channels( - module, channels_to_duplicate, grid_aware=False, split_factor=1, split_input=False) + _split_channels(module, channels_to_split) def _is_supported(srcs: List[nn.Module], sinks: List[nn.Module]) -> bool: @@ -195,9 +184,8 @@ def _split( sinks = [name_to_module[n] for n in region.sinks] if _is_supported(sources, sinks): - # problem: if region[0] has bn modules as sources, we split them but the input from prev layer is not the correct shape anymore! So we need to skip the first region in case that happens # get channels to split - modules_to_split = _channels_to_split( + channels_to_split = _channels_to_split( sources=sources, sinks=sinks, split_criterion=split_criterion, @@ -207,7 +195,8 @@ def _split( _split_channels_region( sources=sources, sinks=sinks, - modules_to_split=modules_to_split, + channels_to_split=channels_to_split, + grid_aware=grid_aware, split_input=split_input) return model diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py index 7a382917b..b1fc3cd40 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py @@ -225,6 +225,8 @@ 'channel-splitting', default=False, help='Apply Channel Splitting before Quantization (default: disabled)') +add_bool_arg( + parser, 'grid-aware', default=False, help='Grid-aware channel splitting (default: disabled)') def main(): @@ -290,7 +292,9 @@ def main(): f"Weight quant calibration type: {args.weight_quant_calibration_type} - " f"Calibrate BN: {args.calibrate_bn} - " f"Channel Splitting: {args.channel_splitting} - " - f"Split Ratio: {args.split_ratio} - ") + f"Split Ratio: {args.split_ratio} - " + f"Grid Aware: {args.grid_aware} - " + f"Merge BN: {not args.calibrate_bn}") # Get model-specific configurations about input shapes and normalization model_config = get_model_config(args.model_name) @@ -335,6 +339,7 @@ def main(): equalize_merge_bias=args.graph_eq_merge_bias, merge_bn=not args.calibrate_bn, channel_splitting=args.channel_splitting, + channel_splitting_grid_aware=args.grid_aware, channel_splitting_ratio=args.split_ratio) else: raise RuntimeError(f"{args.target_backend} backend not supported.") diff --git a/tests/brevitas/graph/test_channel_splitting.py b/tests/brevitas/graph/test_channel_splitting.py index 106618c1a..f2571dcfe 100644 --- a/tests/brevitas/graph/test_channel_splitting.py +++ b/tests/brevitas/graph/test_channel_splitting.py @@ -8,7 +8,8 @@ from .equalization_fixtures import * -def test_resnet18(): +@pytest.mark.parametrize('split_ratio', [0.05, 0.1, 0.2]) +def test_resnet18(split_ratio): model = models.resnet18(pretrained=True) torch.manual_seed(SEED) @@ -21,12 +22,13 @@ def test_resnet18(): # merge BN before applying channel splitting model = MergeBatchNorm().apply(model) - model = ChannelSplitting(split_ratio=0.1).apply(model) + model = ChannelSplitting(split_ratio=split_ratio).apply(model) out = model(inp) assert torch.allclose(expected_out, out, atol=ATOL) -def test_alexnet(): +@pytest.mark.parametrize('split_ratio', [0.05, 0.1]) +def test_alexnet(split_ratio): model = models.alexnet(pretrained=True) torch.manual_seed(SEED) @@ -37,6 +39,6 @@ def test_alexnet(): model = symbolic_trace(model) # set split_ratio to 0.2 to def have some splits - model = ChannelSplitting(split_ratio=0.2).apply(model) + model = ChannelSplitting(split_ratio=split_ratio).apply(model) out = model(inp) assert torch.allclose(expected_out, out, atol=ATOL)