Updating brevitas_examples
i-colbert committed Jan 19, 2024
1 parent 3dd7e43 commit 85871f0
Showing 2 changed files with 15 additions and 10 deletions.
23 changes: 14 additions & 9 deletions src/brevitas_examples/super_resolution/utils/evaluate.py
@@ -1,31 +1,36 @@
 # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause
 
 import torch
 from torch import Tensor
 import torch.nn as nn
 
 import brevitas.nn as qnn
 from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL
 from brevitas.nn.utils import calculate_min_accumulator_bit_width
 from brevitas_examples.super_resolution.models.espcn import QuantESPCN
 
+EPS = 1e-10
+
 
 def _calc_min_acc_bit_width(module: QuantWBIOL) -> Tensor:
     assert isinstance(module, qnn.QuantConv2d), "Error: function only support QuantConv2d."
 
     # bit-width and sign need to come from the quant tensor of the preceding layer if no io_quant
     input_bit_width = module.quant_input_bit_width()
-    input_is_signed = module.is_quant_input_signed
+    input_is_signed = float(module.is_quant_input_signed)
 
     # the tensor quantizer requires a QuantTensor with specified bit-width and sign
     quant_weight = module.quant_weight()
     quant_weight = quant_weight.int().float()
-    if isinstance(module,
-                  qnn.QuantConv2d):  # shape = (out_channels, in_channels, kernel_size, kernel_size)
-        quant_weight_per_channel_l1_norm = quant_weight.norm(p=1, dim=(1, 2, 3))
+    quant_weight_per_channel_l1_norm = quant_weight.norm(p=1, dim=(1, 2, 3))
 
     # using the closed-form bounds on accumulator bit-width
-    cur_acc_bit_width = calculate_min_accumulator_bit_width(
-        input_bit_width, input_is_signed, quant_weight_per_channel_l1_norm.max())
-    return cur_acc_bit_width
+    weight_max_l1_norm = quant_weight_per_channel_l1_norm.max()
+    weight_max_l1_norm = torch.clamp_min(weight_max_l1_norm, EPS)
+    alpha = torch.log2(weight_max_l1_norm) + input_bit_width - input_is_signed
+    phi = lambda x: torch.log2(1. + pow(2., -x))
+    min_bit_width = alpha + phi(alpha) + 1.
+    min_bit_width = torch.ceil(min_bit_width)
+    return min_bit_width
 
 
 def evaluate_accumulator_bit_widths(model: nn.Module, inp: Tensor):
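The rewritten _calc_min_acc_bit_width inlines the closed-form accumulator bound from accumulator-aware quantization: for inputs of bit-width N (signed flag s) and a worst-case per-channel weight l1 norm, the minimum accumulator bit-width is ceil(alpha + phi(alpha) + 1), where alpha = log2(max_c ||w_c||_1) + N - s and phi(x) = log2(1 + 2^-x). A minimal standalone sketch of the same arithmetic, using hypothetical toy values rather than the Brevitas quant-tensor API:

import torch

def min_accumulator_bit_width(weight_l1_norms: torch.Tensor,
                              input_bit_width: float,
                              input_is_signed: float,
                              eps: float = 1e-10) -> torch.Tensor:
    # worst-case per-channel l1 norm, clamped away from zero as in the updated helper
    max_l1 = torch.clamp_min(weight_l1_norms.max(), eps)
    alpha = torch.log2(max_l1) + input_bit_width - input_is_signed
    phi = torch.log2(1. + 2. ** (-alpha))  # small correction term
    return torch.ceil(alpha + phi + 1.)

# e.g. 8-bit unsigned inputs with a worst-case weight l1 norm of 64 -> 16-bit accumulator
print(min_accumulator_bit_width(torch.tensor([10., 64., 32.]), input_bit_width=8., input_is_signed=0.))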
2 changes: 1 addition & 1 deletion src/brevitas_examples/super_resolution/utils/train.py
@@ -28,7 +28,7 @@ def acc_reg_penalty(module: AccumulatorAwareParameterPreScaling, inp, output):
         (weights, input_bit_width, input_is_signed) = inp
         s = module.scaling_impl(weights)  # s
         g = abs_binary_sign_grad(module.restrict_clamp_scaling(module.value))  # g
-        T = module.get_upper_bound_on_l1_norm(input_bit_width, input_is_signed)  # T / s
+        T = module.calc_max_l1_norm(input_bit_width, input_is_signed)  # T / s
         cur_penalty = torch.relu(g - (T * s)).sum()
         reg_penalty += cur_penalty
         return output
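In train.py, acc_reg_penalty runs as a forward hook on each AccumulatorAwareParameterPreScaling module and accumulates relu(g - T * s) into a regularization term added to the task loss; this change only renames the method that returns the accumulator-constrained l1-norm budget T (get_upper_bound_on_l1_norm -> calc_max_l1_norm). A minimal sketch of that hook pattern with a plain nn.Linear and a hypothetical fixed budget, not the Brevitas classes:

import torch
import torch.nn as nn

reg_penalty = torch.tensor(0.)

def l1_budget_penalty(module: nn.Linear, inp, output):
    # hypothetical stand-in for acc_reg_penalty: penalize output channels whose
    # weight l1 norm exceeds a fixed budget T (Brevitas instead derives T from the
    # target accumulator bit-width and compares against the learned, pre-scale bound g)
    global reg_penalty
    T = 4.0
    g = module.weight.abs().sum(dim=1)  # per-output-channel l1 norm
    reg_penalty = reg_penalty + torch.relu(g - T).sum()
    return output

layer = nn.Linear(8, 4)
handle = layer.register_forward_hook(l1_budget_penalty)
loss = layer(torch.randn(2, 8)).pow(2).mean() + 1e-3 * reg_penalty  # task loss + weighted penalty
loss.backward()
handle.remove()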
