Skip to content

Commit

Permalink
More compile
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Aug 14, 2024
1 parent 6ccdf7d commit 337196a
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 2 deletions.
5 changes: 4 additions & 1 deletion src/brevitas_examples/bnn_pynq/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,16 @@
from torchvision.datasets import CIFAR10
from torchvision.datasets import MNIST

from brevitas import torch_version

from .logger import EvalEpochMeters
from .logger import Logger
from .logger import TrainingEpochMeters
from .models import model_with_cfg
from .models.losses import SqrHingeLoss

TORCH_GEQ_200 = parse(torch.__version__) >= parse("2.0.0")
TORCH_GEQ_200 = torch_version >= parse("2.0.0")


class MirrorMNIST(MNIST):

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import warnings

import numpy as np
from packaging.version import parse
import torch
import torch.backends.cudnn as cudnn
import torch.nn.parallel
Expand All @@ -16,6 +17,7 @@
import torch.utils.data.distributed
import torchvision

from brevitas import torch_version
from brevitas.export import export_onnx_qcdq
from brevitas.export import export_torch_qcdq
from brevitas.graph.equalize import activation_equalization_mode
Expand All @@ -36,6 +38,8 @@
from brevitas_examples.imagenet_classification.utils import SEED
from brevitas_examples.imagenet_classification.utils import validate

TORCH_GEQ_200 = torch_version >= parse("2.0.0")

# Ignore warnings about __torch_function__
warnings.filterwarnings("ignore")

Expand Down Expand Up @@ -267,6 +271,7 @@ def parse_type(v, default_type):
'uint_sym_act_for_unsigned_values',
default=True,
help='Use unsigned act quant when possible (default: enabled)')
add_bool_arg(parser, 'compile', default=False, help='Enable torch.compile (Requires PyTorch>=2.0)')


def main():
Expand Down Expand Up @@ -469,7 +474,9 @@ def main():
if args.bias_corr:
print("Applying bias correction:")
apply_bias_correction(calib_loader, quant_model)

if args.compile and TORCH_GEQ_200:
print("Applying torch.compile")
model = torch.compile(model)
# Validate the quant_model on the validation dataloader
print("Starting validation:")
validate(val_loader, quant_model, stable=dtype != torch.bfloat16)
Expand Down
8 changes: 8 additions & 0 deletions src/brevitas_examples/llm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,10 @@
'sharded_torchmlir_group_weight',
'sharded_packed_torchmlir_group_weight'],
help='Model export.')
parser.add_argument(
"--compile",
action='store_true',
help="Compile model with `torch.compile` (PyTorch version >=2 only)")


def set_seed(seed):
Expand Down Expand Up @@ -367,6 +371,10 @@ def main():
apply_bias_correction(model, calibration_loader)
print("Bias correction applied.")

if args.compile:
print("Applying compile")
model = torch.compile(model)

if args.eval:
print("Model eval...")
ppl = model_eval(model, val_data, args.seqlen)
Expand Down

0 comments on commit 337196a

Please sign in to comment.