Skip to content

Commit

Permalink
Feat (Channel-Splitting): adds split_input option
Browse files Browse the repository at this point in the history
  • Loading branch information
fabianandresgrob committed Dec 12, 2023
1 parent d31762a commit 0abdccf
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 20 deletions.
35 changes: 19 additions & 16 deletions src/brevitas/ptq_algorithms/channel_splitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,16 @@ def _channels_maxabs(module, splits_per_layer, split_input):
# dim 0 -> input channels max, dim 1 -> output channels max!
if len(module.weight.data.shape) > 1:
if isinstance(module, nn.Conv2d):
# TODO make sure when splitting input channel this will return input channels
max_per_channel = module.weight.abs().flatten(dim).max(dim).values
if not split_input:
# gets the max value for each output channel
max_per_channel = module.weight.data.abs().flatten(1).max(1).values
# check if max_per_channel has the same length as output channels
assert len(max_per_channel) == module.weight.shape[0]
else:
# getting max value for each input channel
max_per_channel = module.weight.data.abs().max(0).values.flatten(1).max(1).values
# check if same length as input channels
assert len(max_per_channel) == module.weight.shape[1]
elif isinstance(module, nn.Linear):
max_per_channel = module.weight.data.abs().max(dim=dim).values.flatten()
channels = torch.argsort(max_per_channel, descending=True)
Expand All @@ -37,15 +45,11 @@ def _channels_to_split(
# the modules are all of the same shape so we can just take the first one
num_channels = modules[0].weight.shape[int(split_input)]
total_splits = int(math.ceil(split_ratio * num_channels))
# each channel in the sources only splits a portion of the total splits
if not split_input:
splits_per_layer = int(math.floor(total_splits / len(modules)))
else:
# if we split input channels, each module has to split the whole budget
splits_per_layer = total_splits
# each module selects its portion of the total channels to split
splits_per_layer = int(math.floor(total_splits / len(modules)))

if splits_per_layer == 0:
warnings.warn(f'No splits for {sources}, increasing split_ratio could help.')
warnings.warn(f'No splits for {modules}, increasing split_ratio could help.')

module_to_channels = {}
if split_criterion == 'maxabs':
Expand Down Expand Up @@ -135,13 +139,12 @@ def _split_channels_region(
_split_channels(
module, channels_to_split, grid_aware=False, split_factor=1, split_input=True)
else:
# what if we split input channels of the sinks, which output channels of the sources have to be duplicated?
for module in sources:
_split_channels(module, channels_to_split, grid_aware=grid_aware)
# TODO duplicating the channels in the output channels of the sources could be tricky
# input channels are split in half, output channels duplicated
for module in sinks:
_split_channels(module, channels_to_split, grid_aware=grid_aware, split_input=True)
for module in sources:
# then duplicate the input_channels for all modules in the sink
_split_channels(module, channels_to_split)
# duplicate out_channels for all modules in the source
_split_channels(module, channels_to_split, grid_aware=False, split_factor=1)


def _is_supported(srcs: List[nn.Module], sinks: List[nn.Module]) -> bool:
Expand All @@ -150,7 +153,7 @@ def _is_supported(srcs: List[nn.Module], sinks: List[nn.Module]) -> bool:
if len(srcs_ocs) > 1:
return False

# check if ICs of sinks are all equal, what if sinks does not have IC?
# check if ICs of sinks are all equal
sinks_ics = set(module.weight.shape[1] for module in sinks)
if len(sinks_ics) > 1:
return False
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,11 @@
help='Apply Channel Splitting before Quantization (default: disabled)')
add_bool_arg(
parser, 'grid-aware', default=False, help='Grid-aware channel splitting (default: disabled)')
add_bool_arg(
parser,
'split-input',
default=False,
help='Input Channels Splitting for channel splitting (default: disabled)')


def main():
Expand Down
10 changes: 6 additions & 4 deletions tests/brevitas/graph/test_channel_splitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@


@pytest.mark.parametrize('split_ratio', [0.05, 0.1, 0.2])
def test_resnet18(split_ratio):
@pytest.mark.parametrize('split_input', [False, True])
def test_resnet18(split_ratio, split_input):
model = models.resnet18(pretrained=True)

torch.manual_seed(SEED)
Expand All @@ -22,13 +23,14 @@ def test_resnet18(split_ratio):
# merge BN before applying channel splitting
model = MergeBatchNorm().apply(model)

model = ChannelSplitting(split_ratio=split_ratio).apply(model)
model = ChannelSplitting(split_ratio=split_ratio, split_input=split_input).apply(model)
out = model(inp)
assert torch.allclose(expected_out, out, atol=ATOL)


@pytest.mark.parametrize('split_ratio', [0.05, 0.1])
def test_alexnet(split_ratio):
@pytest.mark.parametrize('split_input', [False, True])
def test_alexnet(split_ratio, split_input):
model = models.alexnet(pretrained=True)

torch.manual_seed(SEED)
Expand All @@ -39,6 +41,6 @@ def test_alexnet(split_ratio):
model = symbolic_trace(model)

# set split_ratio to 0.2 to definitely have some splits
model = ChannelSplitting(split_ratio=split_ratio).apply(model)
model = ChannelSplitting(split_ratio=split_ratio, split_input=split_input).apply(model)
out = model(inp)
assert torch.allclose(expected_out, out, atol=ATOL)

0 comments on commit 0abdccf

Please sign in to comment.