From 70b4e43700e6f5f6fd5921a84465ad78c62a0364 Mon Sep 17 00:00:00 2001 From: Huilin Qu Date: Wed, 28 Feb 2024 20:08:09 +0100 Subject: [PATCH] Fix onnx export when no dir is specified. --- setup.py | 2 +- weaver/train.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 22328109..e09b714c 100644 --- a/setup.py +++ b/setup.py @@ -16,7 +16,7 @@ install_requires.append(line) setup(name="weaver-core", - version='0.4.13', + version='0.4.14', description="A streamlined deep-learning framework for high energy physics", long_description_content_type="text/markdown", author="H. Qu, C. Li", diff --git a/weaver/train.py b/weaver/train.py index 388d0ef5..3dfdcb1a 100644 --- a/weaver/train.py +++ b/weaver/train.py @@ -334,6 +334,8 @@ def onnx(args): model = model.cpu() model.eval() + if not os.path.dirname(args.export_onnx): + args.export_onnx = os.path.join(os.path.dirname(model_path), args.export_onnx) os.makedirs(os.path.dirname(args.export_onnx), exist_ok=True) inputs = tuple( torch.ones(model_info['input_shapes'][k], dtype=torch.float32) for k in model_info['input_names']) @@ -879,7 +881,7 @@ def _main(args): del test_loader if args.predict_output: - if '/' not in args.predict_output: + if not os.path.dirname(predict_output): predict_output = os.path.join( os.path.dirname(args.model_prefix), 'predict_output', args.predict_output)