Skip to content

Commit

Permalink
Fix: Style to conform to precommit
Browse files Browse the repository at this point in the history
  • Loading branch information
nickfraser committed Mar 21, 2024
1 parent d880e67 commit 9c05693
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
2 changes: 1 addition & 1 deletion src/brevitas_examples/bnn_pynq/bnn_pynq_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def launch(cmd_args):

# Avoid creating new folders etc.
if args.evaluate:
args.dry_run = True # Comment out to export ONNX models from pre-trained
args.dry_run = True # Comment out to export ONNX models from pre-trained

# Init trainer
trainer = Trainer(args)
Expand Down
9 changes: 6 additions & 3 deletions src/brevitas_examples/bnn_pynq/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
from torchvision.datasets import CIFAR10
from torchvision.datasets import MNIST

from brevitas.export import export_onnx_qcdq, export_qonnx
from brevitas.export import export_onnx_qcdq
from brevitas.export import export_qonnx

from .logger import EvalEpochMeters
from .logger import Logger
Expand Down Expand Up @@ -158,7 +159,8 @@ def __init__(self, args):
with open(path, "rb") as f:
bytes = f.read()
readable_hash = sha256(bytes).hexdigest()[:8]
new_path = os.path.join(self.checkpoints_dir_path, "{}-qonnx-{}.onnx".format(name, readable_hash))
new_path = os.path.join(
self.checkpoints_dir_path, "{}-qonnx-{}.onnx".format(name, readable_hash))
os.rename(path, new_path)
self.logger.info("Exporting QONNX to {}".format(new_path))
if args.export_qcdq_onnx:
Expand All @@ -168,7 +170,8 @@ def __init__(self, args):
with open(path, "rb") as f:
bytes = f.read()
readable_hash = sha256(bytes).hexdigest()[:8]
new_path = os.path.join(self.checkpoints_dir_path, "{}-qcdq-{}.onnx".format(name, readable_hash))
new_path = os.path.join(
self.checkpoints_dir_path, "{}-qcdq-{}.onnx".format(name, readable_hash))
os.rename(path, new_path)
self.logger.info("Exporting QCDQ ONNX to {}".format(new_path))

Expand Down

0 comments on commit 9c05693

Please sign in to comment.