-
Notifications
You must be signed in to change notification settings - Fork 195
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Creating evaluation script and README doc
- Loading branch information
Showing
2 changed files
with
135 additions
and
0 deletions.
There are no files selected for viewing
28 changes: 28 additions & 0 deletions
28
src/brevitas_examples/imagenet_classification/a2q/README.md
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
# Integer-Quantized Image Classification Models Trained on CIFAR10 with Brevitas | ||
|
||
This directory contains scripts demonstrating how to train integer-quantized image classification models using accumulator-aware quantization (A2Q) as proposed in our ICCV 2023 paper "[A2Q: Accumulator-Aware Quantization with Guaranteed Overflow Avoidance](https://arxiv.org/abs/2308.13504)". | ||
Code is also provided to demonstrate A2Q+ as proposed in our arXiv paper "[A2Q+: Improving Accumulator-Aware Weight Quantization](https://arxiv.org/abs/2401.10432)", where we introduce the zero-centered weight quantizer (i.e., `AccumulatorAwareZeroCenterWeightQuant`) as well as the Euclidean projection-based weight initialization (EP-init). | ||
|
||
## Experiments | ||
|
||
All models are trained on the CIFAR10 dataset. | ||
Input images are normalized to have zero mean and unit variance. | ||
During training, random cropping is applied, along with random horizontal flips. | ||
All residual connections are quantized to the specified activation bit width. | ||
|
||
|
||
| Model Name | Weight Quantization | Activation Quantization | Target Accumulator | Top-1 Accuracy (%) | | ||
|-----------------------------|----------------|---------------------|-------------------------|----------------------------| | ||
| [float_resnet18](https://github.com/Xilinx/brevitas/releases/download/ep_init/float_resnet18-1d98d23a.pth) | float32 | float32 | float32 | 95.0 | | ||
|| | ||
| [quant_resnet18_w4a4_a2q_16b](https://github.com/Xilinx/brevitas/releases/download/ep_init/quant_resnet18_w4a4_a2q_16b-d0af41f1.pth) | int4 | uint4 | int16 | 94.2 | | ||
| [quant_resnet18_w4a4_a2q_15b](https://github.com/Xilinx/brevitas/releases/download/ep_init/quant_resnet18_w4a4_a2q_15b-0d5bf266.pth) | int4 | uint4 | int15 | 94.2 | | ||
| [quant_resnet18_w4a4_a2q_14b](https://github.com/Xilinx/brevitas/releases/download/ep_init/quant_resnet18_w4a4_a2q_14b-267f237b.pth) | int4 | uint4 | int14 | 92.6 | | ||
| [quant_resnet18_w4a4_a2q_13b](https://github.com/Xilinx/brevitas/releases/download/ep_init/quant_resnet18_w4a4_a2q_13b-8c31a2b1.pth) | int4 | uint4 | int13 | 89.8 | | ||
| [quant_resnet18_w4a4_a2q_12b](https://github.com/Xilinx/brevitas/releases/download/ep_init/quant_resnet18_w4a4_a2q_12b-8a440436.pth) | int4 | uint4 | int12 | 83.9 | | ||
|| | ||
| [quant_resnet18_w4a4_a2q_plus_16b](https://github.com/Xilinx/brevitas/releases/download/ep_init/quant_resnet18_w4a4_a2q_plus_16b-19973380.pth) | int4 | uint4 | int16 | 94.2 | | ||
| [quant_resnet18_w4a4_a2q_plus_15b](https://github.com/Xilinx/brevitas/releases/download/ep_init/quant_resnet18_w4a4_a2q_plus_15b-3c89551a.pth) | int4 | uint4 | int15 | 94.1 | | ||
| [quant_resnet18_w4a4_a2q_plus_14b](https://github.com/Xilinx/brevitas/releases/download/ep_init/quant_resnet18_w4a4_a2q_plus_14b-5a2d11aa.pth) | int4 | uint4 | int14 | 94.1 | | ||
| [quant_resnet18_w4a4_a2q_plus_13b](https://github.com/Xilinx/brevitas/releases/download/ep_init/quant_resnet18_w4a4_a2q_plus_13b-332aaf81.pth) | int4 | uint4 | int13 | 92.8 | | ||
| [quant_resnet18_w4a4_a2q_plus_12b](https://github.com/Xilinx/brevitas/releases/download/ep_init/quant_resnet18_w4a4_a2q_plus_12b-d69f003b.pth) | int4 | uint4 | int12 | 90.6 | |
107 changes: 107 additions & 0 deletions
107
src/brevitas_examples/imagenet_classification/a2q/a2q_evaluate_models.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
# Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. | ||
# SPDX-License-Identifier: BSD-3-Clause | ||
|
||
import argparse | ||
from hashlib import sha256 | ||
import os | ||
import random | ||
|
||
import numpy as np | ||
import torch | ||
import torch.nn as nn | ||
|
||
import brevitas.config as config | ||
from brevitas.export import export_qonnx | ||
import brevitas_examples.imagenet_classification.a2q.utils as utils | ||
|
||
# Command-line interface of the evaluation script. Every flag is optional except
# `--data-root`; the boolean flags default to False and are enabled by passing them.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--data-root", type=str, required=True, help="Directory where the dataset is stored.")
parser.add_argument(
    "--model-name",
    type=str,
    default="quant_resnet18_w4a4_a2q_32b",
    # This script only evaluates models, so the help text says "evaluate", not "train".
    help="Name of model to evaluate. Default: 'quant_resnet18_w4a4_a2q_32b'",
    choices=utils.model_impl.keys())
parser.add_argument(
    "--save-path",
    type=str,
    default="outputs/",
    # Both the torch checkpoint and the QONNX export are written to this directory.
    help="Directory where to save checkpoints and QONNX exports. Default: 'outputs/'")
parser.add_argument(
    "--load-from-path",
    type=str,
    default=None,
    help="Optional local path to load torch checkpoint from. Default: None")
parser.add_argument(
    "--num-workers",
    type=int,
    default=0,
    help="Number of workers for the dataloader to use. Default: 0")
parser.add_argument(
    "--pin-memory",
    action="store_true",
    default=False,
    help="If true, pin memory for the dataloader.")
parser.add_argument(
    "--batch-size", type=int, default=512, help="Batch size for the dataloader. Default: 512")
parser.add_argument(
    "--save-torch-model",
    action="store_true",
    default=False,
    help="If true, save torch model to specified save path.")
parser.add_argument(
    "--export-to-qonnx", action="store_true", default=False, help="If true, export model to QONNX.")
|
||
# Single seed shared by every random number generator touched by this script,
# so that repeated runs draw the same values.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

# Dummy input handed to the exporter for graph tracing. It is drawn only after
# torch has been seeded above, so its values are reproducible.
random_inp = torch.randn((1, 3, 32, 32))
|
||
if __name__ == "__main__":
    # Script entry point: evaluate a model on the CIFAR10 test split, then optionally
    # save its checkpoint (tagged with a short content hash) and/or export it to QONNX.

    args = parser.parse_args()

    # The JIT flag is turned off whenever a QONNX export is requested.
    # NOTE(review): presumably the export path cannot trace JIT-compiled modules - confirm.
    config.JIT_ENABLED = not args.export_to_qonnx

    # Initialize dataloaders
    print(f"Loading CIFAR10 dataset from {args.data_root}...")
    trainloader, testloader = utils.get_cifar10_dataloaders(
        data_root=args.data_root,
        batch_size_train=args.batch_size,  # does not matter here: only testloader is used
        batch_size_test=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_memory)

    # if load-from-path is not specified, then use the pre-trained checkpoint
    model = utils.get_model_by_name(args.model_name, pretrained=args.load_from_path is None)
    if args.load_from_path is not None:
        # note that if you used bias correction, you may need to prepare the model for the
        # new biases that were introduced. See `utils.get_model_by_name` for more details.
        # NOTE(review): depending on the installed torch version, `torch.load` may unpickle
        # arbitrary objects; only point --load-from-path at checkpoints you trust.
        state_dict = torch.load(args.load_from_path, map_location="cpu")
        model.load_state_dict(state_dict)
    criterion = nn.CrossEntropyLoss()

    top_1, top_5, loss = utils.evaluate_topk_accuracies(testloader, model, criterion)
    print(f"Final top_1={top_1:.1%}, top_5={top_5:.1%}, loss={loss:.3f}")

    # The output directory is created up front because both optional steps write to it.
    os.makedirs(args.save_path, exist_ok=True)

    # Short content hash of the saved checkpoint. It stays None when no checkpoint is
    # written, so the export step below never reads an undefined name.
    model_tag = None

    # save checkpoint
    if args.save_torch_model:
        ckpt_path = os.path.join(args.save_path, f"{args.model_name}.pth")
        torch.save(model.state_dict(), ckpt_path)
        # Hash the file exactly as written to disk; the first 8 hex digits of the SHA-256
        # digest become the suffix of the final file name.
        with open(ckpt_path, "rb") as ckpt_file:
            ckpt_bytes = ckpt_file.read()
        model_tag = sha256(ckpt_bytes).hexdigest()[:8]
        new_ckpt_path = os.path.join(args.save_path, f"{args.model_name}-{model_tag}.pth")
        os.rename(ckpt_path, new_ckpt_path)
        print(f"Saved model checkpoint to: {new_ckpt_path}")

    if args.export_to_qonnx:
        # Reuse the checkpoint tag when one exists so that the .pth and .onnx files pair
        # up; otherwise fall back to the plain model name instead of failing.
        if model_tag is None:
            export_name = args.model_name
        else:
            export_name = f"{args.model_name}-{model_tag}"
        export_qonnx(
            model.cpu(),
            input_t=random_inp.cpu(),
            export_path=os.path.join(args.save_path, f"{export_name}.onnx"))