Skip to content

Commit

Permalink
Feat (Channel-Splitting): adds grid_aware option
Browse files Browse the repository at this point in the history
  • Loading branch information
fabianandresgrob committed Dec 12, 2023
1 parent ae3b1e4 commit d31762a
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 61 deletions.
2 changes: 2 additions & 0 deletions src/brevitas/graph/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,7 @@ def preprocess_for_quantize(
equalize_scale_computation: str = 'maxabs',
channel_splitting=False,
channel_splitting_ratio=0.02,
channel_splitting_grid_aware=False,
channel_splitting_criterion: str = 'maxabs'):

training_state = model.training
Expand All @@ -292,6 +293,7 @@ def preprocess_for_quantize(
if channel_splitting:
model = ChannelSplitting(
split_ratio=channel_splitting_ratio,
grid_aware=channel_splitting_grid_aware,
split_criterion=channel_splitting_criterion).apply(model)
model.train(training_state)
return model
Expand Down
101 changes: 45 additions & 56 deletions src/brevitas/ptq_algorithms/channel_splitting.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import math
from typing import Dict, List, Set, Tuple, Union
import warnings

import torch
import torch.nn as nn
Expand All @@ -8,8 +9,6 @@
from brevitas.graph.base import GraphTransform
from brevitas.graph.equalize import _extract_regions

_batch_norm = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)


def _channels_maxabs(module, splits_per_layer, split_input):
# works for Conv2d and Linear
Expand Down Expand Up @@ -44,15 +43,31 @@ def _channels_to_split(
else:
# if we split input channels, each module has to split the whole budget
splits_per_layer = total_splits
assert splits_per_layer > 0, f"No channels to split in {modules} with split_rati {split_ratio}!"

if splits_per_layer == 0:
warnings.warn(f'No splits for {sources}, increasing split_ratio could help.')

module_to_channels = {}
if split_criterion == 'maxabs':
for module in modules:
module_to_channels[module] = _channels_maxabs(module, splits_per_layer, split_input)

# return dict with modules as key and channels to split as value
return module_to_channels
# return tensor with the indices to split
channels_to_split = torch.cat(list(module_to_channels.values()))
return torch.unique(channels_to_split)


def _split_single_channel(channel, grid_aware: bool, split_factor: float):
if split_factor == 1:
# duplicates the channel
return channel, channel

if grid_aware:
slice1 = channel - 0.5
slice2 = channel + 0.5
return slice1 * split_factor, slice2 * split_factor
else:
return channel * split_factor, channel * split_factor


def _split_channels(
Expand All @@ -63,96 +78,70 @@ def _split_channels(
Can also be used to duplicate a channel, just set split_factor to 1.
Returns: None
"""
# change it to .data attribute
weight = module.weight.data
bias = module.bias.data if module.bias is not None else None
if isinstance(module, _batch_norm):
running_mean = module.running_mean.data
running_var = module.running_var.data

for id in channels_to_split:
if isinstance(module, torch.nn.Conv2d):
# there are four dimensions: [OC, IC, k, k]
if split_input:
channel = weight[:, id:id + 1, :, :] * split_factor
weight = torch.cat(
(weight[:, :id, :, :], channel, channel, weight[:, id + 1:, :, :]), dim=1)
channel = weight[:, id:id + 1, :, :]
slice1, slice2 = _split_single_channel(channel, grid_aware, split_factor)
weight = torch.cat((weight[:, :id, :, :], slice1, slice2, weight[:, id + 1:, :, :]),
dim=1)
module.in_channels += 1
else:
# split output
channel = weight[id:id + 1, :, :, :] * split_factor
# duplicate channel
weight = torch.cat(
(weight[:id, :, :, :], channel, channel, weight[id + 1:, :, :, :]), dim=0)
channel = weight[id:id + 1, :, :, :]
slice1, slice2 = _split_single_channel(channel, grid_aware, split_factor)
weight = torch.cat((weight[:id, :, :, :], slice1, slice2, weight[id + 1:, :, :, :]),
dim=0)
module.out_channels += 1

elif isinstance(module, torch.nn.Linear):
# there are two dimensions: [OC, IC]
if split_input:
# simply duplicate channel
channel = weight[:, id:id + 1] * split_factor
weight = torch.cat((weight[:, :id], channel, channel, weight[:, id + 1:]), dim=1)
channel = weight[:, id:id + 1]
slice1, slice2 = _split_single_channel(channel, grid_aware, split_factor)
weight = torch.cat((weight[:, :id], slice1, slice2, weight[:, id + 1:]), dim=1)
module.in_features += 1
else:
# split output
channel = weight[id:id + 1, :] * split_factor
weight = torch.cat((weight[:id, :], channel, channel, weight[id + 1:, :]), dim=0)
channel = weight[id:id + 1, :]
slice1, slice2 = _split_single_channel(channel, grid_aware, split_factor)
weight = torch.cat((weight[:id, :], slice1, slice2, weight[id + 1:, :]), dim=0)
module.out_features += 1

elif isinstance(module, _batch_norm):
# bach norm is 1d
channel = weight[id:id + 1] * split_factor
weight = torch.cat((weight[:id], channel, channel, weight[id + 1:]))
# also split running_mean and running_var
mean = running_mean[id:id + 1] * split_factor
running_mean = torch.cat((running_mean[:id], mean, mean, running_mean[id + 1:]))

var = running_var[id:id + 1] * split_factor
running_var = torch.cat((running_var[:id], var, var, running_var[id + 1:]))

module.num_features += 1

if bias is not None and not split_input:
channel = bias[id:id + 1] * split_factor
bias = torch.cat((bias[:id], channel, channel, bias[id + 1:]))

# setting the weights as the new data
module.weight.data = weight
if bias is not None:
module.bias.data = bias
if isinstance(module, _batch_norm):
module.running_mean.data = running_mean
module.running_var.data = running_var


def _split_channels_region(
sources: List[nn.Module],
sinks: List[nn.Module],
modules_to_split: Dict[nn.Module, torch.tensor],
channels_to_split: torch.tensor,
split_input: bool,
grid_aware: bool = False):
# we are getting a dict[Module, channels to split]
# splitting output channels
# concat all channels that are split so we can duplicate those in the input channels later
if not split_input:
channels = torch.cat(list(modules_to_split.values()))
for module in modules_to_split.keys():
_split_channels(module, channels, grid_aware=grid_aware)
# get all the channels that we have to duplicate
channels = torch.cat(list(modules_to_split.values()))
for module in sources:
_split_channels(module, channels_to_split, grid_aware=grid_aware)
for module in sinks:
# then duplicate the input_channels for all modules in the sink
_split_channels(module, channels, grid_aware=False, split_factor=1, split_input=True)
_split_channels(
module, channels_to_split, grid_aware=False, split_factor=1, split_input=True)
else:
# what if we split input channels of the sinks, which channels of the OC srcs have to duplicated?
for module, channels in modules_to_split.items():
_split_channels(module, channels, grid_aware=grid_aware)
for module in sources:
_split_channels(module, channels_to_split, grid_aware=grid_aware)
# TODO duplicating the channels in the output channels of the sources could be tricky
channels_to_duplicate = torch.cat(list(modules_to_split.values()))
for module in sources:
# then duplicate the input_channels for all modules in the sink
_split_channels(
module, channels_to_duplicate, grid_aware=False, split_factor=1, split_input=False)
_split_channels(module, channels_to_split)


def _is_supported(srcs: List[nn.Module], sinks: List[nn.Module]) -> bool:
Expand Down Expand Up @@ -195,9 +184,8 @@ def _split(
sinks = [name_to_module[n] for n in region.sinks]

if _is_supported(sources, sinks):
# problem: if region[0] has bn modules as sources, we split them but the input from prev layer is not the correct shape anymore! So we need to skip the first region in case that happens
# get channels to split
modules_to_split = _channels_to_split(
channels_to_split = _channels_to_split(
sources=sources,
sinks=sinks,
split_criterion=split_criterion,
Expand All @@ -207,7 +195,8 @@ def _split(
_split_channels_region(
sources=sources,
sinks=sinks,
modules_to_split=modules_to_split,
channels_to_split=channels_to_split,
grid_aware=grid_aware,
split_input=split_input)

return model
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,8 @@
'channel-splitting',
default=False,
help='Apply Channel Splitting before Quantization (default: disabled)')
add_bool_arg(
parser, 'grid-aware', default=False, help='Grid-aware channel splitting (default: disabled)')


def main():
Expand Down Expand Up @@ -294,7 +296,9 @@ def main():
f"Weight quant calibration type: {args.weight_quant_calibration_type} - "
f"Calibrate BN: {args.calibrate_bn} - "
f"Channel Splitting: {args.channel_splitting} - "
f"Split Ratio: {args.split_ratio} - ")
f"Split Ratio: {args.split_ratio} - "
f"Grid Aware: {args.grid_aware} - "
f"Merge BN: {not args.calibrate_bn}")

# Get model-specific configurations about input shapes and normalization
model_config = get_model_config(args.model_name)
Expand Down Expand Up @@ -339,6 +343,7 @@ def main():
equalize_merge_bias=args.graph_eq_merge_bias,
merge_bn=not args.calibrate_bn,
channel_splitting=args.channel_splitting,
channel_splitting_grid_aware=args.grid_aware,
channel_splitting_ratio=args.split_ratio)
else:
raise RuntimeError(f"{args.target_backend} backend not supported.")
Expand Down
10 changes: 6 additions & 4 deletions tests/brevitas/graph/test_channel_splitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from .equalization_fixtures import *


def test_resnet18():
@pytest.mark.parametrize('split_ratio', [0.05, 0.1, 0.2])
def test_resnet18(split_ratio):
model = models.resnet18(pretrained=True)

torch.manual_seed(SEED)
Expand All @@ -21,12 +22,13 @@ def test_resnet18():
# merge BN before applying channel splitting
model = MergeBatchNorm().apply(model)

model = ChannelSplitting(split_ratio=0.1).apply(model)
model = ChannelSplitting(split_ratio=split_ratio).apply(model)
out = model(inp)
assert torch.allclose(expected_out, out, atol=ATOL)


def test_alexnet():
@pytest.mark.parametrize('split_ratio', [0.05, 0.1])
def test_alexnet(split_ratio):
model = models.alexnet(pretrained=True)

torch.manual_seed(SEED)
Expand All @@ -37,6 +39,6 @@ def test_alexnet():
model = symbolic_trace(model)

# set split_ratio to 0.2 to def have some splits
model = ChannelSplitting(split_ratio=0.2).apply(model)
model = ChannelSplitting(split_ratio=split_ratio).apply(model)
out = model(inp)
assert torch.allclose(expected_out, out, atol=ATOL)

0 comments on commit d31762a

Please sign in to comment.