Skip to content

Commit

Permalink
Feat: Added compile arg to BNN-PYNQ examples
Browse files Browse the repository at this point in the history
  • Loading branch information
nickfraser authored and Giuseppe5 committed Aug 14, 2024
1 parent d7cfc04 commit 6ccdf7d
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/brevitas_examples/bnn_pynq/bnn_pynq_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def parse_args(args):
parser.add_argument("--network", default="LFC_1W1A", type=str, help="neural network")
parser.add_argument("--pretrained", action='store_true', help="Load pretrained model")
parser.add_argument("--strict", action='store_true', help="Strict state dictionary loading")
parser.add_argument("--compile", action='store_true', help="Compile model with `torch.compile` (PyTorch version >=2 only)")
parser.add_argument(
"--state_dict_to_pth",
action='store_true',
Expand Down
3 changes: 3 additions & 0 deletions src/brevitas_examples/bnn_pynq/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from .models import model_with_cfg
from .models.losses import SqrHingeLoss

TORCH_GEQ_200 = parse(torch.__version__) >= parse("2.0.0")

class MirrorMNIST(MNIST):

Expand Down Expand Up @@ -64,6 +65,8 @@ class Trainer(object):
def __init__(self, args):

model, cfg = model_with_cfg(args.network, args.pretrained)
if args.compile and TORCH_GEQ_200:
model = torch.compile(model)

# Init arguments
self.args = args
Expand Down

0 comments on commit 6ccdf7d

Please sign in to comment.