Adding activation calibration
i-colbert committed Jan 31, 2024
1 parent d60542c commit cf69b97
Showing 3 changed files with 32 additions and 6 deletions.
@@ -91,7 +91,12 @@
     default=False,
     help="If true, save torch model to specified save path.")
 parser.add_argument(
-    "--apply-bias-corr",
+    "--apply-act-calibration",
+    action="store_true",
+    default=False,
+    help="If true, apply activation calibration to the quantized model.")
+parser.add_argument(
+    "--apply-bias-correction",
     action="store_true",
     default=False,
     help="If true, apply bias correction to the quantized model.")
@@ -137,7 +142,7 @@
 model = utils.get_model_by_name(
     args.model_name, init_from_float_checkpoint=args.from_float_checkpoint)
 criterion = nn.CrossEntropyLoss()
-optimizer = optim.Adam(
+optimizer = optim.SGD(
     utils.filter_params(model.named_parameters(), args.weight_decay),
     lr=args.lr_init,
     weight_decay=args.weight_decay)
@@ -146,9 +151,14 @@
 # Calibrate the quant model on the calibration dataset
 if args.apply_ep_init:
     print("Applying EP-init:")
-    utils.apply_ep_init(model, random_inp)
+    model = utils.apply_ep_init(model, random_inp)
 
+# Calibrate the quant model on the calibration dataset
+if args.apply_act_calibration:
+    print("Applying activation calibration:")
+    utils.apply_act_calibrate(calibloader, model)
+
-if args.apply_bias_corr:
+if args.apply_bias_correction:
     print("Applying bias correction:")
     utils.apply_bias_correction(calibloader, model)
 
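For readers driving these utilities directly rather than through the CLI, the same post-training sequence can be reproduced in a few lines. This is a sketch assuming model, random_inp, and calibloader are already constructed as in this script; all three helpers come from this repo's utils module.

    # Same ordering as above: EP-init rewrites weights first, activation
    # calibration then fits activation scales, and bias correction runs last.
    model = utils.apply_ep_init(model, random_inp)
    utils.apply_act_calibrate(calibloader, model)
    utils.apply_bias_correction(calibloader, model)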
@@ -173,7 +183,7 @@
 
 model.load_state_dict(best_weights)
 top_1, top_5, loss = utils.evaluate_topk_accuracies(testloader, model, criterion)
-print(f"Final top_1={top_1:.1%}, top_5={top_5:.1%}, loss={loss:.3f}")
+print(f"Final: top_1={top_1:.1%}, top_5={top_5:.1%}, loss={loss:.3f}")
 
 # save checkpoint
 os.makedirs(args.save_path, exist_ok=True)
2 changes: 2 additions & 0 deletions src/brevitas_examples/imagenet_classification/a2q/ep_init.py
@@ -127,3 +127,5 @@ def register_upper_bound(module: AccumulatorAwareParameterPreScaling, inp, output):
 
     for hook in hook_list:
         hook.remove()
+
+    return model
16 changes: 15 additions & 1 deletion src/brevitas_examples/imagenet_classification/a2q/utils.py
@@ -20,6 +20,7 @@
 from brevitas.core.scaling.pre_scaling import AccumulatorAwareParameterPreScaling
 from brevitas.function import abs_binary_sign_grad
 from brevitas.graph.calibrate import bias_correction_mode
+from brevitas.graph.calibrate import calibration_mode
 
 from .ep_init import apply_ep_init
 from .quant import *
@@ -28,11 +29,12 @@
 
 __all__ = [
     "apply_ep_init",
+    "apply_act_calibrate",
+    "apply_bias_correction",
     "get_model_by_name",
     "filter_params",
     "create_calibration_dataloader",
     "get_cifar10_dataloaders",
-    "apply_bias_correction",
     "train_for_epoch",
     "evaluate_topk_accuracies"]
 
@@ -239,6 +241,18 @@ def get_cifar10_dataloaders(
     return trainloader, testloader
 
 
+def apply_act_calibrate(calib_loader, model):
+    model.eval()
+    dtype = next(model.parameters()).dtype
+    device = next(model.parameters()).device
+    with torch.no_grad():
+        with calibration_mode(model):
+            for images, _ in tqdm(calib_loader):
+                images = images.to(device)
+                images = images.to(dtype)
+                model(images)
+
+
 def apply_bias_correction(calib_loader, model: nn.Module):
     model.eval()
     dtype = next(model.parameters()).dtype
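As a minimal sketch of how the new helper might be exercised on its own, assuming Brevitas is installed: the toy model and synthetic loader below are illustrative only and not part of the commit.

    import torch
    import brevitas.nn as qnn
    from torch.utils.data import DataLoader, TensorDataset

    # Toy quantized model plus synthetic calibration data (illustrative only).
    model = torch.nn.Sequential(qnn.QuantConv2d(3, 8, kernel_size=3), qnn.QuantReLU())
    data = TensorDataset(torch.randn(16, 3, 32, 32), torch.zeros(16, dtype=torch.long))
    calib_loader = DataLoader(data, batch_size=4)

    # Runs forward passes under calibration_mode so activation quantizers
    # can observe value statistics and set their scales.
    apply_act_calibrate(calib_loader, model)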
