diff --git a/src/brevitas_examples/imagenet_classification/a2q/README.md b/src/brevitas_examples/imagenet_classification/a2q/README.md
new file mode 100644
index 000000000..dc8b70e4c
--- /dev/null
+++ b/src/brevitas_examples/imagenet_classification/a2q/README.md
@@ -0,0 +1,28 @@
+# Integer-Quantized Image Classification Models Trained on CIFAR10 with Brevitas
+
+This directory contains scripts demonstrating how to train integer-quantized image classification models using accumulator-aware quantization (A2Q) as proposed in our ICCV 2023 paper "[A2Q: Accumulator-Aware Quantization with Guaranteed Overflow Avoidance](https://arxiv.org/abs/2308.13504)".
+Code is also provided to demonstrate A2Q+ as proposed in our arXiv paper "[A2Q+: Improving Accumulator-Aware Weight Quantization](https://arxiv.org/abs/2401.10432)", where we introduce the zero-centered weight quantizer (i.e., `AccumulatorAwareZeroCenterWeightQuant`) as well as the Euclidean projection-based weight initialization (EP-init).
+
+## Experiments
+
+All models are trained on the CIFAR10 dataset.
+Input images are normalized to have zero mean and unit variance.
+During training, random cropping and random horizontal flips are applied as augmentations.
+All residual connections are quantized to the specified activation bit width.
+
+
+| Model Name | Weight Quantization | Activation Quantization | Target Accumulator | Top-1 Accuracy (%) |
+|-----------------------------|----------------|---------------------|-------------------------|----------------------------|
+| [float_resnet18](https://github.com/Xilinx/brevitas/releases/download/a2q_cifar10_r1/float_resnet18-1d98d23a.pth) | float32 | float32 | float32 | 95.0 |
+||
+| [quant_resnet18_w4a4_a2q_16b](https://github.com/Xilinx/brevitas/releases/download/a2q_cifar10_r1/quant_resnet18_w4a4_a2q_16b-d0af41f1.pth) | int4 | uint4 | int16 | 94.2 |
+| [quant_resnet18_w4a4_a2q_15b](https://github.com/Xilinx/brevitas/releases/download/a2q_cifar10_r1/quant_resnet18_w4a4_a2q_15b-0d5bf266.pth) | int4 | uint4 | int15 | 94.2 |
+| [quant_resnet18_w4a4_a2q_14b](https://github.com/Xilinx/brevitas/releases/download/a2q_cifar10_r1/quant_resnet18_w4a4_a2q_14b-267f237b.pth) | int4 | uint4 | int14 | 92.6 |
+| [quant_resnet18_w4a4_a2q_13b](https://github.com/Xilinx/brevitas/releases/download/a2q_cifar10_r1/quant_resnet18_w4a4_a2q_13b-8c31a2b1.pth) | int4 | uint4 | int13 | 89.8 |
+| [quant_resnet18_w4a4_a2q_12b](https://github.com/Xilinx/brevitas/releases/download/a2q_cifar10_r1/quant_resnet18_w4a4_a2q_12b-8a440436.pth) | int4 | uint4 | int12 | 83.9 |
+||
+| [quant_resnet18_w4a4_a2q_plus_16b](https://github.com/Xilinx/brevitas/releases/download/a2q_cifar10_r1/quant_resnet18_w4a4_a2q_plus_16b-19973380.pth) | int4 | uint4 | int16 | 94.2 |
+| [quant_resnet18_w4a4_a2q_plus_15b](https://github.com/Xilinx/brevitas/releases/download/a2q_cifar10_r1/quant_resnet18_w4a4_a2q_plus_15b-3c89551a.pth) | int4 | uint4 | int15 | 94.1 |
+| [quant_resnet18_w4a4_a2q_plus_14b](https://github.com/Xilinx/brevitas/releases/download/a2q_cifar10_r1/quant_resnet18_w4a4_a2q_plus_14b-5a2d11aa.pth) | int4 | uint4 | int14 | 94.1 |
+| [quant_resnet18_w4a4_a2q_plus_13b](https://github.com/Xilinx/brevitas/releases/download/a2q_cifar10_r1/quant_resnet18_w4a4_a2q_plus_13b-332aaf81.pth) | int4 | uint4 | int13 | 92.8 |
+| [quant_resnet18_w4a4_a2q_plus_12b](https://github.com/Xilinx/brevitas/releases/download/a2q_cifar10_r1/quant_resnet18_w4a4_a2q_plus_12b-d69f003b.pth) | int4 | uint4 | int12 | 90.6 |
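+
+## Usage
+
+The pre-trained checkpoints above can be downloaded and evaluated with the utilities in this directory.
+The snippet below is a minimal sketch, assuming the package is installed and CIFAR10 can be downloaded to `./data` (an illustrative path):
+
+```python
+import torch.nn as nn
+import brevitas_examples.imagenet_classification.a2q.utils as utils
+
+# build the W4A4 model targeting a 16-bit accumulator and load its checkpoint
+model = utils.get_model_by_name("quant_resnet18_w4a4_a2q_16b", pretrained=True)
+
+_, testloader = utils.get_cifar10_dataloaders(data_root="./data", download=True)
+top_1, top_5, loss = utils.evaluate_topk_accuracies(testloader, model, nn.CrossEntropyLoss())
+print(f"top_1={top_1:.1%}")
+```
+
+To build an untrained model with a custom accumulator target instead, use `resnet.quant_resnet18`, which forwards `acc_bit_width` to the A2Q (or A2Q+) weight quantizers defined in `quant.py` via `weight_quant.let(accumulator_bit_width=...)`.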
diff --git a/src/brevitas_examples/imagenet_classification/a2q/__init__.py b/src/brevitas_examples/imagenet_classification/a2q/__init__.py
new file mode 100644
index 000000000..d5713fd41
--- /dev/null
+++ b/src/brevitas_examples/imagenet_classification/a2q/__init__.py
@@ -0,0 +1,2 @@
+# Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
diff --git a/src/brevitas_examples/imagenet_classification/a2q/a2q_evaluate_models.py b/src/brevitas_examples/imagenet_classification/a2q/a2q_evaluate_models.py
new file mode 100644
index 000000000..963eff418
--- /dev/null
+++ b/src/brevitas_examples/imagenet_classification/a2q/a2q_evaluate_models.py
@@ -0,0 +1,107 @@
+# Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+
+import argparse
+from hashlib import sha256
+import os
+import random
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+import brevitas.config as config
+from brevitas.export import export_qonnx
+import brevitas_examples.imagenet_classification.a2q.utils as utils
+
+parser = argparse.ArgumentParser()
+parser.add_argument(
+    "--data-root", type=str, required=True, help="Directory where the dataset is stored.")
+parser.add_argument(
+    "--model-name",
+    type=str,
+    default="quant_resnet18_w4a4_a2q_16b",
+    help="Name of the model to evaluate. Default: 'quant_resnet18_w4a4_a2q_16b'",
+    choices=utils.model_impl.keys())
+parser.add_argument(
+    "--save-path",
+    type=str,
+    default="outputs/",
+    help="Directory where to save checkpoints. Default: 'outputs/'")
+parser.add_argument(
+    "--load-from-path",
+    type=str,
+    default=None,
+    help="Optional local path to load torch checkpoint from. Default: None")
+parser.add_argument(
+    "--num-workers",
+    type=int,
+    default=0,
+    help="Number of workers for the dataloader to use. Default: 0")
+parser.add_argument(
+    "--pin-memory",
+    action="store_true",
+    default=False,
+    help="If true, pin memory for the dataloader.")
+parser.add_argument(
+    "--batch-size", type=int, default=512, help="Batch size for the dataloader. Default: 512")
+parser.add_argument(
+    "--save-torch-model",
+    action="store_true",
+    default=False,
+    help="If true, save torch model to specified save path.")
+parser.add_argument(
+    "--export-to-qonnx", action="store_true", default=False, help="If true, export model to QONNX.")
+
+SEED = 0
+random.seed(SEED)
+np.random.seed(SEED)
+torch.manual_seed(SEED)
+
+# create a random input for graph tracing
+random_inp = torch.randn(1, 3, 32, 32)
+
+if __name__ == "__main__":
+
+    args = parser.parse_args()
+
+    config.JIT_ENABLED = not args.export_to_qonnx
+
+    # Initialize dataloaders
+    print(f"Loading CIFAR10 dataset from {args.data_root}...")
+    trainloader, testloader = utils.get_cifar10_dataloaders(
+        data_root=args.data_root,
+        batch_size_train=args.batch_size,  # the training set is unused during evaluation
+        batch_size_test=args.batch_size,
+        num_workers=args.num_workers,
+        pin_memory=args.pin_memory)
+
+    # if load-from-path is not specified, then use the pre-trained checkpoint
+    model = utils.get_model_by_name(args.model_name, pretrained=args.load_from_path is None)
+    if args.load_from_path is not None:
+        # note that if you used bias correction, you may need to prepare the model for the
+        # new biases that were introduced. See `utils.get_model_by_name` for more details.
+        state_dict = torch.load(args.load_from_path, map_location="cpu")
+        model.load_state_dict(state_dict)
+    criterion = nn.CrossEntropyLoss()
+
+    top_1, top_5, loss = utils.evaluate_topk_accuracies(testloader, model, criterion)
+    print(f"Final top_1={top_1:.1%}, top_5={top_5:.1%}, loss={loss:.3f}")
+
+    # save checkpoint, tagged with a short hash of its contents
+    os.makedirs(args.save_path, exist_ok=True)
+    model_tag = None
+    if args.save_torch_model:
+        ckpt_path = f"{args.save_path}/{args.model_name}.pth"
+        torch.save(model.state_dict(), ckpt_path)
+        with open(ckpt_path, "rb") as _file:
+            model_bytes = _file.read()
+        model_tag = sha256(model_bytes).hexdigest()[:8]
+        new_ckpt_path = f"{args.save_path}/{args.model_name}-{model_tag}.pth"
+        os.rename(ckpt_path, new_ckpt_path)
+        print(f"Saved model checkpoint to: {new_ckpt_path}")
+
+    if args.export_to_qonnx:
+        # fall back to an untagged filename when no checkpoint was saved above
+        tag = f"-{model_tag}" if model_tag is not None else ""
+        export_qonnx(
+            model.cpu(),
+            input_t=random_inp.cpu(),
+            export_path=f"{args.save_path}/{args.model_name}{tag}.onnx")
diff --git a/src/brevitas_examples/imagenet_classification/a2q/a2q_train_models.py b/src/brevitas_examples/imagenet_classification/a2q/a2q_train_models.py
new file mode 100644
index 000000000..97a214c9a
--- /dev/null
+++ b/src/brevitas_examples/imagenet_classification/a2q/a2q_train_models.py
@@ -0,0 +1,204 @@
+# Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+
+import argparse
+import copy
+from hashlib import sha256
+import os
+import random
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import torch.optim.lr_scheduler as lrs
+
+import brevitas.config as config
+from brevitas.export import export_qonnx
+import brevitas_examples.imagenet_classification.a2q.utils as utils
+
+parser = argparse.ArgumentParser()
+parser.add_argument(
+    "--data-root", type=str, required=True, help="Directory where the dataset is stored.")
+parser.add_argument(
+    "--model-name",
+    type=str,
+    default="quant_resnet18_w4a4_a2q_16b",
+    help="Name of model to train. Default: 'quant_resnet18_w4a4_a2q_16b'",
+    choices=utils.model_impl.keys())
+parser.add_argument(
+    "--save-path",
+    type=str,
+    default="outputs/",
+    help="Directory where to save checkpoints. Default: 'outputs/'")
+parser.add_argument(
+    "--num-workers",
+    type=int,
+    default=0,
+    help="Number of workers for the dataloader to use. Default: 0")
+parser.add_argument(
+    "--pin-memory",
+    action="store_true",
+    default=False,
+    help="If true, pin memory for the dataloader.")
+parser.add_argument(
+    "--batch-size-train",
+    type=int,
+    default=256,
+    help="Batch size for the training dataloader. Default: 256")
+parser.add_argument(
+    "--batch-size-test",
+    type=int,
+    default=512,
+    help="Batch size for the testing dataloader. Default: 512")
+parser.add_argument(
+    "--batch-size-calibration",
+    type=int,
+    default=256,
+    help="Batch size for the calibration dataloader. Default: 256")
+parser.add_argument(
+    "--calibration-samples",
+    type=int,
+    default=1000,
+    help="Number of samples to use for calibration. Default: 1000")
+parser.add_argument(
+    "--weight-decay",
+    type=float,
+    default=1e-5,
+    help="Weight decay for the SGD optimizer. Default: 0.00001")
+parser.add_argument(
+    "--lr-init", type=float, default=1e-3, help="Initial learning rate. Default: 0.001")
+parser.add_argument(
+    "--lr-step-size",
+    type=int,
+    default=30,
+    help="Step size for the learning rate scheduler. Default: 30")
+parser.add_argument(
+    "--lr-gamma",
+    type=float,
+    default=0.1,
+    help="Multiplicative decay factor for the learning rate scheduler. Default: 0.1")
Default: 0.1") +parser.add_argument( + "--total-epochs", type=int, default=90, help="Total epoch to train the model for. Default: 90") +parser.add_argument( + "--from-float-checkpoint", + action="store_true", + default=False, + help="If true, use a pre-trained floating-point checkpoint.") +parser.add_argument( + "--save-torch-model", + action="store_true", + default=False, + help="If true, save torch model to specified save path.") +parser.add_argument( + "--apply-act-calibration", + action="store_true", + default=False, + help="If true, apply activation calibration to the quantized model.") +parser.add_argument( + "--apply-bias-correction", + action="store_true", + default=False, + help="If true, apply bias correction to the quantized model.") +parser.add_argument( + "--apply-ep-init", + action="store_true", + default=False, + help="If true, apply EP-init to the quantized model.") +parser.add_argument( + "--export-to-qonnx", action="store_true", default=False, help="If true, export model to QONNX.") + +# ignore missing keys when loading pre-trained checkpoint +config.IGNORE_MISSING_KEYS = True + +SEED = 0 +random.seed(SEED) +np.random.seed(SEED) +torch.manual_seed(SEED) + +# create a random input for graph tracing +random_inp = torch.randn(1, 3, 32, 32) + +if __name__ == "__main__": + + args = parser.parse_args() + + config.JIT_ENABLED = not args.export_to_qonnx + + # Initialize dataloaders + print(f"Loading CIFAR10 dataset from {args.data_root}...") + trainloader, testloader = utils.get_cifar10_dataloaders( + data_root=args.data_root, + batch_size_train=args.batch_size_train, + batch_size_test=args.batch_size_test, + num_workers=args.num_workers, + pin_memory=args.pin_memory) + calibloader = utils.create_calibration_dataloader( + dataset=trainloader.dataset, + batch_size=args.batch_size_calibration, + num_workers=args.num_workers, + subset_size=args.calibration_samples) + + model = utils.get_model_by_name( + args.model_name, init_from_float_checkpoint=args.from_float_checkpoint) + criterion = nn.CrossEntropyLoss() + optimizer = optim.SGD( + utils.filter_params(model.named_parameters(), args.weight_decay), + lr=args.lr_init, + weight_decay=args.weight_decay) + scheduler = lrs.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma) + + # Calibrate the quant model on the calibration dataset + if args.apply_ep_init: + print("Applying EP-init:") + model = utils.apply_ep_init(model, random_inp) + + # Calibrate the quant model on the calibration dataset + if args.apply_act_calibration: + print("Applying activation calibration:") + utils.apply_act_calibrate(calibloader, model) + + if args.apply_bias_correction: + print("Applying bias correction:") + utils.apply_bias_correction(calibloader, model) + + best_top_1, best_weights = 0., copy.deepcopy(model.state_dict()) + for epoch in range(args.total_epochs): + + train_loss = utils.train_for_epoch(trainloader, model, criterion, optimizer) + test_top_1, test_top_5, test_loss = utils.evaluate_topk_accuracies(testloader, model, criterion) + scheduler.step() + + print( + f"[Epoch {epoch:03d}]", + f"train_loss={train_loss:.3f},", + f"test_loss={test_loss:.3f},", + f"test_top_1={test_top_1:.1%},", + f"test_top_5={test_top_5:.1%}", + sep=" ") + + if test_top_1 >= best_top_1: + best_weights = copy.deepcopy(model.state_dict()) + best_top_1 = test_top_1 + + model.load_state_dict(best_weights) + top_1, top_5, loss = utils.evaluate_topk_accuracies(testloader, model, criterion) + print(f"Final: top_1={top_1:.1%}, top_5={top_5:.1%}, loss={loss:.3f}") + 
+    # save checkpoint, tagged with a short hash of its contents
+    os.makedirs(args.save_path, exist_ok=True)
+    model_tag = None
+    if args.save_torch_model:
+        ckpt_path = f"{args.save_path}/{args.model_name}.pth"
+        torch.save(best_weights, ckpt_path)
+        with open(ckpt_path, "rb") as _file:
+            model_bytes = _file.read()
+        model_tag = sha256(model_bytes).hexdigest()[:8]
+        new_ckpt_path = f"{args.save_path}/{args.model_name}-{model_tag}.pth"
+        os.rename(ckpt_path, new_ckpt_path)
+        print(f"Saved model checkpoint to {new_ckpt_path}")
+
+    if args.export_to_qonnx:
+        # fall back to an untagged filename when no checkpoint was saved above
+        tag = f"-{model_tag}" if model_tag is not None else ""
+        export_qonnx(
+            model.cpu(),
+            input_t=random_inp.cpu(),
+            export_path=f"{args.save_path}/{args.model_name}{tag}.onnx")
diff --git a/src/brevitas_examples/imagenet_classification/a2q/ep_init.py b/src/brevitas_examples/imagenet_classification/a2q/ep_init.py
new file mode 100644
index 000000000..026f78af8
--- /dev/null
+++ b/src/brevitas_examples/imagenet_classification/a2q/ep_init.py
@@ -0,0 +1,131 @@
+# Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+
+from functools import partial
+
+import numpy as np
+import torch
+from torch import Tensor
+import torch.nn as nn
+
+from brevitas.core.scaling import AccumulatorAwareParameterPreScaling
+from brevitas.function.shape import over_output_channels
+from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL
+
+__all__ = ["apply_ep_init"]
+
+
+def get_a2q_module(module: nn.Module):
+    for submod in module.modules():
+        if isinstance(submod, AccumulatorAwareParameterPreScaling):
+            return submod
+    return None
+
+
+def _euclidean_projection_onto_positive_simplex(vec: Tensor, radius: float = 1.):
+    assert radius > 0, "Error: radius needs to be strictly positive."
+    assert vec.ndim == 1, "Error: projection assumes a vector, not a matrix."
+    assert vec.min() >= 0, "Error: assuming a vector of non-negative numbers."
+    n_elems = vec.shape[0]
+    # if we are already within the simplex, then the best projection is itself
+    if vec.sum() <= radius:
+        return vec
+    # using the algorithm derived in "Efficient Projections onto the L1-Ball
+    # for Learning in High Dimensions" (Duchi et al., 2008)
+    v = vec.cpu().detach().numpy()
+    u = np.sort(v)[::-1]
+    cumsum_u = np.cumsum(u)
+    rho = np.nonzero(u * np.arange(1, n_elems + 1) > (cumsum_u - radius))[0][-1]
+    theta = float(cumsum_u[rho] - radius) / (rho + 1)
+    w = np.clip(v - theta, 0, np.inf)
+    vec.data = torch.tensor(w, dtype=vec.dtype, device=vec.device)
+    return vec
+
+
+def euclidean_projection_onto_l1_ball(vec: Tensor, radius: float):
+    assert radius > 0, "Error: radius needs to be strictly positive."
+    assert vec.ndim == 1, "Error: projection assumes a vector, not a matrix."
+    vec_dir = vec.sign()
+    vec_mag = _euclidean_projection_onto_positive_simplex(vec.abs(), radius)
+    new_vec = vec_dir * vec_mag
+    assert vec.shape == new_vec.shape, "Error: shape changed."
+    return new_vec
+
+
+def l1_proj_matrix_per_channel(weights: Tensor, radius: Tensor):
+    assert isinstance(weights, Tensor), "Error: weights is assumed to be a Tensor."
+    assert isinstance(radius, Tensor), "Error: radius is assumed to be a Tensor."
+    assert weights.ndim == 2, "Error: assuming a matrix with ndim=2."
+    # if defined per-tensor
+    if radius.ndim == 0:
+        radius = torch.ones(weights.shape[0]) * radius
+    # if defined per-channel
+    else:
+        radius = radius.flatten()
+    assert radius.nelement() == weights.shape[0], "Error: shape mismatch."
+    # project each channel independently
+    for i in range(weights.shape[0]):
+        w = weights[i]
+        z = radius[i].item()
+        v = euclidean_projection_onto_l1_ball(w, z)
+        weights[i] = v
+    return weights
+
+
+def apply_ep_init(model: nn.Module, inp: Tensor):
+    """Euclidean projection-based weight initialization (EP-init) for accumulator-aware
+    quantization as proposed in `A2Q+: Improving Accumulator-Aware Weight Quantization`"""
+    model.eval()
+    dtype = next(model.parameters()).dtype
+    device = next(model.parameters()).device
+
+    module_stats = {}
+    hook_list = list()
+
+    def register_upper_bound(module: AccumulatorAwareParameterPreScaling, inp, output, name):
+        """Record each layer's weights and scaled l1-norm budget (radius), then clamp
+        the learned pre-scaling parameter to that budget"""
+        nonlocal module_stats
+
+        (weights, input_bit_width, input_is_signed) = inp
+        scales: Tensor = module.scaling_impl(weights)
+        max_norm: Tensor = module.calc_max_l1_norm(input_bit_width, input_is_signed)
+
+        shape = over_output_channels(weights)
+        s = scales.reshape(shape)
+        w = weights.reshape(shape)
+
+        z: Tensor = s * max_norm  # radius
+        module_stats[name] = (w.detach(), z.detach())  # no gradients
+
+        restrict_value_impl = module.restrict_clamp_scaling.restrict_value_impl
+        pre_scaling_init: Tensor = restrict_value_impl.restrict_init_tensor(scales * max_norm)
+        assert pre_scaling_init.shape == module.value.shape, "Error: shape mismatch."
+        module.value.data = torch.where(
+            module.value.data <= pre_scaling_init, module.value.data, pre_scaling_init)
+
+        return output
+
+    # add hooks to each of the A2Q pre-scaling modules
+    for name, mod in model.named_modules():
+        if isinstance(mod, QuantWBIOL):
+            submod = get_a2q_module(mod)
+            if submod is not None:
+                hook_fn = partial(register_upper_bound, name=name)
+                hook = submod.register_forward_hook(hook_fn)
+                hook_list.append(hook)
+
+    inp = inp.to(device=device, dtype=dtype)
+    model(inp)  # collect the scaled upper bounds via the hooks
+
+    # project weights onto the l1-ball
+    for name, mod in model.named_modules():
+        if name in module_stats and isinstance(mod, (nn.Conv2d, nn.Linear)):
+            (weights, radius) = module_stats[name]
+            weights = l1_proj_matrix_per_channel(weights, radius)
+            weights = weights.reshape(mod.weight.shape)
+            mod.weight.data = weights
+
+    for hook in hook_list:
+        hook.remove()
+
+    return model
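+
+
+if __name__ == "__main__":
+    # Minimal self-check of the projection utilities (illustrative only; the
+    # tensors below are made up). After projection, each row of the matrix
+    # should satisfy its per-channel l1-norm budget.
+    _weights = torch.randn(4, 8)
+    _radius = torch.tensor([1.0, 2.0, 3.0, 4.0])
+    _projected = l1_proj_matrix_per_channel(_weights.clone(), _radius)
+    assert torch.all(_projected.abs().sum(dim=1) <= _radius + 1e-6)
+    print("Per-channel l1-ball projection OK.")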
diff --git a/src/brevitas_examples/imagenet_classification/a2q/quant.py b/src/brevitas_examples/imagenet_classification/a2q/quant.py
new file mode 100644
index 000000000..feff9e6da
--- /dev/null
+++ b/src/brevitas_examples/imagenet_classification/a2q/quant.py
@@ -0,0 +1,23 @@
+# Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+
+from brevitas.quant import Int8AccumulatorAwareWeightQuant
+from brevitas.quant import Int8AccumulatorAwareZeroCenterWeightQuant
+
+__all__ = ["CommonIntAccumulatorAwareWeightQuant", "CommonIntAccumulatorAwareZeroCenterWeightQuant"]
+
+SCALING_MIN_VAL = 1e-8
+
+
+class CommonIntAccumulatorAwareWeightQuant(Int8AccumulatorAwareWeightQuant):
+    """A2Q: Accumulator-Aware Quantization with Guaranteed Overflow Avoidance"""
+    bit_width = None
+    scaling_min_val = SCALING_MIN_VAL
+    pre_scaling_min_val = SCALING_MIN_VAL
+
+
+class CommonIntAccumulatorAwareZeroCenterWeightQuant(Int8AccumulatorAwareZeroCenterWeightQuant):
+    """A2Q+: Improving Accumulator-Aware Weight Quantization"""
+    bit_width = None
+    scaling_min_val = SCALING_MIN_VAL
+    pre_scaling_min_val = SCALING_MIN_VAL
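+
+
+if __name__ == "__main__":
+    # Minimal usage sketch (illustrative; not used by the training scripts):
+    # attach the A2Q weight quantizer to a single conv layer with a 16-bit
+    # accumulator target, mirroring how `resnet.py` sets
+    # `accumulator_bit_width` through `.let(...)`.
+    from brevitas.nn import QuantConv2d
+    from brevitas.quant import Int8ActPerTensorFloat
+
+    conv = QuantConv2d(
+        3,
+        64,
+        kernel_size=3,
+        padding=1,
+        bias=False,
+        input_quant=Int8ActPerTensorFloat,
+        weight_quant=CommonIntAccumulatorAwareWeightQuant.let(accumulator_bit_width=16),
+        weight_bit_width=4)
+    print(conv)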
diff --git a/src/brevitas_examples/imagenet_classification/a2q/resnet.py b/src/brevitas_examples/imagenet_classification/a2q/resnet.py
new file mode 100644
index 000000000..1ef7d40fc
--- /dev/null
+++ b/src/brevitas_examples/imagenet_classification/a2q/resnet.py
@@ -0,0 +1,123 @@
+# Copyright (c) 2024, Advanced Micro Devices, Inc.
+# Copyright (c) 2017, liukuang
+# All rights reserved.
+# SPDX-License-Identifier: MIT
+
+import torch.nn as nn
+import torch.nn.functional as F
+
+from brevitas.nn.quant_layer import WeightQuantType
+from brevitas.quant import Int8WeightPerChannelFloat
+from brevitas_examples.bnn_pynq.models.resnet import QuantBasicBlock
+from brevitas_examples.bnn_pynq.models.resnet import QuantResNet
+
+
+def weight_init(layer):
+    if isinstance(layer, nn.Conv2d):
+        nn.init.kaiming_normal_(layer.weight, nonlinearity='relu')
+        if layer.bias is not None:
+            layer.bias.data.zero_()
+    elif isinstance(layer, nn.BatchNorm2d):
+        nn.init.constant_(layer.weight, 1)
+        nn.init.constant_(layer.bias, 0)
+
+
+class BasicBlock(nn.Module):
+    """Basic block architecture modified for CIFAR10.
+    Adapted from https://github.com/kuangliu/pytorch-cifar"""
+    expansion = 1
+
+    def __init__(self, in_planes: int, planes: int, stride: int = 1):
+        super(BasicBlock, self).__init__()
+        self.conv1 = nn.Conv2d(
+            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
+        self.bn2 = nn.BatchNorm2d(planes)
+
+        self.downsample = nn.Sequential()
+        if stride != 1 or in_planes != self.expansion * planes:
+            # using a convolution shortcut rather than identity
+            self.downsample = nn.Sequential(
+                nn.Conv2d(
+                    in_planes,
+                    self.expansion * planes,
+                    kernel_size=1,
+                    stride=stride,
+                    bias=False,
+                ),
+                nn.BatchNorm2d(self.expansion * planes),
+            )
+
+    def forward(self, x):
+        out = F.relu(self.bn1(self.conv1(x)))
+        out = self.bn2(self.conv2(out))
+        out += self.downsample(x)
+        out = F.relu(out)
+        return out
+
+
+class ResNet(nn.Module):
+    """ResNet architecture modified for CIFAR10.
+    Adapted from https://github.com/kuangliu/pytorch-cifar"""
+
+    def __init__(self, block_impl, num_blocks, num_classes: int = 10):
+        super(ResNet, self).__init__()
+
+        # 3x3 stem with stride and padding of 1, rather than the 7x7 stride-2
+        # stem used by the ImageNet model
+        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(64)
+
+        self.in_planes = 64
+        self.layer1 = self.create_block(block_impl, 64, num_blocks[0], stride=1)
+        self.layer2 = self.create_block(block_impl, 128, num_blocks[1], stride=2)
+        self.layer3 = self.create_block(block_impl, 256, num_blocks[2], stride=2)
+        self.layer4 = self.create_block(block_impl, 512, num_blocks[3], stride=2)
+        self.linear = nn.Linear(512 * block_impl.expansion, num_classes)
+
+        self.apply(weight_init)
+
+    def create_block(self, block_impl, planes, num_blocks, stride):
+        strides = [stride] + [1] * (num_blocks - 1)
+        layers = []
+        for stride in strides:
+            layers.append(block_impl(self.in_planes, planes, stride))
+            self.in_planes = planes * block_impl.expansion  # expand input planes
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        out = F.relu(self.bn1(self.conv1(x)))
+        out = self.layer1(out)
+        out = self.layer2(out)
+        out = self.layer3(out)
+        out = self.layer4(out)
+        out = F.avg_pool2d(out, 4)
+        out = out.view(out.size(0), -1)
+        out = self.linear(out)
+        return out
+
+
+def float_resnet18(num_classes: int = 10) -> ResNet:
+    model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes)
+    return model
+
+
+def quant_resnet18(
+        num_classes: int = 10,
+        act_bit_width: int = 8,
+        acc_bit_width: int = 32,
+        weight_bit_width: int = 8,
+        weight_quant: WeightQuantType = Int8WeightPerChannelFloat) -> QuantResNet:
+    weight_quant = weight_quant.let(accumulator_bit_width=acc_bit_width)
+    model = QuantResNet(
+        block_impl=QuantBasicBlock,
+        num_blocks=[2, 2, 2, 2],
+        num_classes=num_classes,
+        act_bit_width=act_bit_width,
+        weight_bit_width=weight_bit_width,
+        weight_quant=weight_quant,
+        last_layer_weight_quant=Int8WeightPerChannelFloat,
+        first_maxpool=False,
+        zero_init_residual=False,
+        round_average_pool=False)
+    return model
diff --git a/src/brevitas_examples/imagenet_classification/a2q/utils.py b/src/brevitas_examples/imagenet_classification/a2q/utils.py
new file mode 100644
index 000000000..5546fa459
--- /dev/null
+++ b/src/brevitas_examples/imagenet_classification/a2q/utils.py
@@ -0,0 +1,350 @@
+# Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+
+from functools import partial
+from typing import Tuple
+
+import numpy as np
+import torch
+from torch import hub
+from torch import Tensor
+from torch.nn import Module
+import torch.nn as nn
+from torch.utils.data import DataLoader
+from torch.utils.data import Dataset
+from torch.utils.data import Subset
+import torchvision
+import torchvision.transforms as transforms
+from tqdm import tqdm
+
+from brevitas.core.scaling.pre_scaling import AccumulatorAwareParameterPreScaling
+from brevitas.function import abs_binary_sign_grad
+from brevitas.graph.calibrate import bias_correction_mode
+from brevitas.graph.calibrate import calibration_mode
+
+from .ep_init import apply_ep_init
+from .quant import *
+from .resnet import float_resnet18
+from .resnet import quant_resnet18
+
+__all__ = [
+    "apply_ep_init",
+    "apply_act_calibrate",
+    "apply_bias_correction",
+    "get_model_by_name",
+    "filter_params",
+    "create_calibration_dataloader",
+    "get_cifar10_dataloaders",
+    "train_for_epoch",
+    "evaluate_topk_accuracies"]
+
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+model_impl = {
+    "float_resnet18":
+        float_resnet18,
+    "quant_resnet18_w4a4_a2q_16b":
+        partial(
+            quant_resnet18,
+            act_bit_width=4,
+            acc_bit_width=16,
+            weight_bit_width=4,
+            weight_quant=CommonIntAccumulatorAwareWeightQuant),
+    "quant_resnet18_w4a4_a2q_15b":
+        partial(
+            quant_resnet18,
+            act_bit_width=4,
+            acc_bit_width=15,
+            weight_bit_width=4,
+            weight_quant=CommonIntAccumulatorAwareWeightQuant),
+    "quant_resnet18_w4a4_a2q_14b":
+        partial(
+            quant_resnet18,
+            act_bit_width=4,
+            acc_bit_width=14,
+            weight_bit_width=4,
+            weight_quant=CommonIntAccumulatorAwareWeightQuant),
+    "quant_resnet18_w4a4_a2q_13b":
+        partial(
+            quant_resnet18,
+            act_bit_width=4,
+            acc_bit_width=13,
+            weight_bit_width=4,
+            weight_quant=CommonIntAccumulatorAwareWeightQuant),
+    "quant_resnet18_w4a4_a2q_12b":
+        partial(
+            quant_resnet18,
+            act_bit_width=4,
+            acc_bit_width=12,
+            weight_bit_width=4,
+            weight_quant=CommonIntAccumulatorAwareWeightQuant),
+    "quant_resnet18_w4a4_a2q_plus_16b":
+        partial(
+            quant_resnet18,
+            act_bit_width=4,
+            acc_bit_width=16,
+            weight_bit_width=4,
+            weight_quant=CommonIntAccumulatorAwareZeroCenterWeightQuant),
+    "quant_resnet18_w4a4_a2q_plus_15b":
+        partial(
+            quant_resnet18,
+            act_bit_width=4,
+            acc_bit_width=15,
+            weight_bit_width=4,
+            weight_quant=CommonIntAccumulatorAwareZeroCenterWeightQuant),
+    "quant_resnet18_w4a4_a2q_plus_14b":
+        partial(
+            quant_resnet18,
+            act_bit_width=4,
+            acc_bit_width=14,
+            weight_bit_width=4,
+            weight_quant=CommonIntAccumulatorAwareZeroCenterWeightQuant),
+    "quant_resnet18_w4a4_a2q_plus_13b":
+        partial(
+            quant_resnet18,
+            act_bit_width=4,
+            acc_bit_width=13,
+            weight_bit_width=4,
+            weight_quant=CommonIntAccumulatorAwareZeroCenterWeightQuant),
+    "quant_resnet18_w4a4_a2q_plus_12b":
+        partial(
+            quant_resnet18,
+            act_bit_width=4,
+            acc_bit_width=12,
+            weight_bit_width=4,
+            weight_quant=CommonIntAccumulatorAwareZeroCenterWeightQuant)}
+
+# no trailing slash here; the URL templates below add their own separator
+root_url = 'https://github.com/Xilinx/brevitas/releases/download'
+
+model_url = {
+    "float_resnet18":
+        f"{root_url}/a2q_cifar10_r1/float_resnet18-1d98d23a.pth",
+    "quant_resnet18_w4a4_a2q_12b":
+        f"{root_url}/a2q_cifar10_r1/quant_resnet18_w4a4_a2q_12b-8a440436.pth",
+    "quant_resnet18_w4a4_a2q_13b":
+        f"{root_url}/a2q_cifar10_r1/quant_resnet18_w4a4_a2q_13b-8c31a2b1.pth",
+    "quant_resnet18_w4a4_a2q_14b":
+        f"{root_url}/a2q_cifar10_r1/quant_resnet18_w4a4_a2q_14b-267f237b.pth",
+    "quant_resnet18_w4a4_a2q_15b":
f"{root_url}/a2q_cifar10_r1/quant_resnet18_w4a4_a2q_15b-0d5bf266.pth", + "quant_resnet18_w4a4_a2q_16b": + f"{root_url}/a2q_cifar10_r1/quant_resnet18_w4a4_a2q_16b-d0af41f1.pth", + "quant_resnet18_w4a4_a2q_plus_12b": + f"{root_url}/a2q_cifar10_r1/quant_resnet18_w4a4_a2q_plus_12b-d69f003b.pth", + "quant_resnet18_w4a4_a2q_plus_13b": + f"{root_url}/a2q_cifar10_r1/quant_resnet18_w4a4_a2q_plus_13b-332aaf81.pth", + "quant_resnet18_w4a4_a2q_plus_14b": + f"{root_url}/a2q_cifar10_r1/quant_resnet18_w4a4_a2q_plus_14b-5a2d11aa.pth", + "quant_resnet18_w4a4_a2q_plus_15b": + f"{root_url}/a2q_cifar10_r1/quant_resnet18_w4a4_a2q_plus_15b-3c89551a.pth", + "quant_resnet18_w4a4_a2q_plus_16b": + f"{root_url}/a2q_cifar10_r1/quant_resnet18_w4a4_a2q_plus_16b-19973380.pth"} + + +def get_model_by_name( + model_name: str, + pretrained: bool = False, + init_from_float_checkpoint: bool = False) -> nn.Module: + + assert model_name in model_impl, f"Error: {model_name} not implemented." + assert not (pretrained and init_from_float_checkpoint), "Error: pretrained and init_from_float_checkpoint cannot both be true." + model: Module = model_impl[model_name]() + + if init_from_float_checkpoint: + checkpoint = model_url["float_resnet18"] + state_dict = hub.load_state_dict_from_url(checkpoint, progress=True, map_location='cpu') + model.load_state_dict(state_dict, strict=True) + + elif pretrained: + checkpoint = model_url[model_name] + state_dict = hub.load_state_dict_from_url(checkpoint, progress=True, map_location='cpu') + if model_name.startswith("quant"): + # fixes issue when bias keys are missing in the pre-trained state_dict when loading from checkpoint + _prepare_bias_corrected_quant_model(model) + model.load_state_dict(state_dict, strict=True) + + return model + + +def filter_params(named_params, decay): + decay_params, no_decay_params = [], [] + for name, param in named_params: + # Do not apply weight decay to the bias or any scaling parameters + if 'scaling' in name or name.endswith(".bias"): + no_decay_params.append(param) + else: + decay_params.append(param) + return [{ + 'params': no_decay_params, 'weight_decay': 0.}, { + 'params': decay_params, 'weight_decay': decay}] + + +def create_calibration_dataloader( + dataset: Dataset, batch_size: int, num_workers: int, subset_size: int) -> DataLoader: + + all_indices = np.arange(len(dataset)) + cur_indices = np.random.choice(all_indices, size=subset_size) + subset = Subset(dataset, cur_indices) + loader = DataLoader(subset, batch_size=batch_size, num_workers=num_workers, pin_memory=True) + return loader + + +def get_cifar10_dataloaders( + data_root: str, + batch_size_train: int = 128, + batch_size_test: int = 100, + num_workers: int = 2, + pin_memory: bool = True, + download: bool = False) -> Tuple[Type[DataLoader]]: + + mean, std = [0.491, 0.482, 0.447], [0.247, 0.243, 0.262] + + # create training dataloader + transform_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize(mean=mean, std=std)]) + trainset = torchvision.datasets.CIFAR10( + root=data_root, + train=True, + download=download, + transform=transform_train, + ) + trainloader = DataLoader( + trainset, + batch_size=batch_size_train, + shuffle=True, + num_workers=num_workers, + pin_memory=pin_memory, + persistent_workers=num_workers > 0, + ) + + # creating the validation dataloader + transform_test = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(mean=mean, std=std),]) + testset = 
+    testset = torchvision.datasets.CIFAR10(
+        root=data_root,
+        train=False,
+        download=download,
+        transform=transform_test,
+    )
+    testloader = DataLoader(
+        testset,
+        batch_size=batch_size_test,
+        shuffle=False,
+        num_workers=num_workers,
+        pin_memory=pin_memory,
+        persistent_workers=num_workers > 0,
+    )
+
+    return trainloader, testloader
+
+
+def apply_act_calibrate(calib_loader, model):
+    model.eval()
+    dtype = next(model.parameters()).dtype
+    device = next(model.parameters()).device
+    with torch.no_grad():
+        with calibration_mode(model):
+            for images, _ in tqdm(calib_loader):
+                images = images.to(device)
+                images = images.to(dtype)
+                model(images)
+
+
+def apply_bias_correction(calib_loader, model: nn.Module):
+    model.eval()
+    dtype = next(model.parameters()).dtype
+    device = next(model.parameters()).device
+    with torch.no_grad():
+        with bias_correction_mode(model):
+            for (images, _) in tqdm(calib_loader):
+                images = images.to(device)
+                images = images.to(dtype)
+                model(images)
+
+
+def _prepare_bias_corrected_quant_model(model: nn.Module):
+    model.eval()
+    dtype = next(model.parameters()).dtype
+    device = next(model.parameters()).device
+    images = torch.randn(10, 3, 32, 32)
+    images = images.to(device)
+    images = images.to(dtype)
+    with torch.no_grad():
+        with bias_correction_mode(model):
+            model(images)
+
+
+def train_for_epoch(trainloader, model, criterion, optimizer, reg_weight: float = 1e-3):
+    model.train()
+    model = model.to(device)
+
+    tot_loss, reg_penalty = 0., 0.
+
+    def acc_reg_penalty(module: AccumulatorAwareParameterPreScaling, inp, output):
+        """Accumulate the regularization penalty across constrained layers"""
+        nonlocal reg_penalty
+        (weights, input_bit_width, input_is_signed) = inp
+        s = module.scaling_impl(weights)  # s
+        g = abs_binary_sign_grad(module.restrict_clamp_scaling(module.value))  # g
+        T = module.calc_max_l1_norm(input_bit_width, input_is_signed)  # T
+        # the penalty is nonzero only where the learned l1-norm parameter g
+        # exceeds its accumulator-aware budget T * s
+        cur_penalty = torch.relu(g - (T * s)).sum()
+        reg_penalty += cur_penalty
+        return output
+
+    # Register a forward hook to accumulate the regularization penalty
+    hook_handles = list()
+    for mod in model.modules():
+        if isinstance(mod, AccumulatorAwareParameterPreScaling):
+            hook = mod.register_forward_hook(acc_reg_penalty)
+            hook_handles.append(hook)
+
+    progress_bar = tqdm(trainloader)
+    for _, (images, targets) in enumerate(progress_bar):
+        optimizer.zero_grad()
+        images = images.to(device)
+        targets = targets.to(device)
+        outputs = model(images)
+        task_loss: Tensor = criterion(outputs, targets)
+        loss = task_loss + (reg_weight * reg_penalty)
+        loss.backward()
+        optimizer.step()
+        reg_penalty = 0.  # reset the accumulated regularization penalty
+        tot_loss += task_loss.item() * images.size(0)
+
+    # Remove the registered forward hooks before exiting
+    for hook in hook_handles:
+        hook.remove()
+
+    avg_loss = tot_loss / len(trainloader.dataset)
+    return avg_loss
+
+
+@torch.no_grad()
+def evaluate_topk_accuracies(testloader, model, criterion):
+    model.eval()
+    model = model.to(device)
+
+    progress_bar = tqdm(testloader)
+
+    top_1, top_5, tot_loss = 0., 0., 0.
+ for _, (images, targets) in enumerate(progress_bar): + images = images.to(device) + targets = targets.to(device) + outputs: Tensor = model(images) + loss: Tensor = criterion(outputs, targets) + + # Evaluating Top-1 and Top-5 accuracy + _, y_pred = outputs.topk(5, 1, True, True) + y_pred = y_pred.t() + correct = y_pred.eq(targets.view(1, -1).expand_as(y_pred)) + top_1 += correct[0].float().sum().item() + top_5 += correct.float().sum().item() + tot_loss += loss.item() * images.size(0) + top_1 /= len(testloader.dataset) + top_5 /= len(testloader.dataset) + tot_loss /= len(testloader.dataset) + return top_1, top_5, tot_loss diff --git a/tests/brevitas_examples/test_examples_import.py b/tests/brevitas_examples/test_examples_import.py index e3796d03c..db4f3af74 100644 --- a/tests/brevitas_examples/test_examples_import.py +++ b/tests/brevitas_examples/test_examples_import.py @@ -55,3 +55,19 @@ def test_super_resolution_float_and_quant_models_match(upscale_factor, num_chann float_model = float_espcn(upscale_factor, num_channels) quant_model = quant_espcn(upscale_factor, num_channels, weight_quant=weight_quant) quant_model.load_state_dict(float_model.state_dict()) + + +@pytest.mark.parametrize( + "weight_quant", + [ + Int8WeightPerChannelFloat, + Int8AccumulatorAwareWeightQuant, + Int8AccumulatorAwareZeroCenterWeightQuant]) +def test_image_classification_float_and_quant_models_match(weight_quant): + import brevitas.config as config + from brevitas_examples.imagenet_classification.a2q.resnet import float_resnet18 + from brevitas_examples.imagenet_classification.a2q.resnet import quant_resnet18 + config.IGNORE_MISSING_KEYS = True + float_model = float_resnet18(num_classes=10) + quant_model = quant_resnet18(num_classes=10, weight_quant=weight_quant) + quant_model.load_state_dict(float_model.state_dict())