From 865972f7a791bf7b42efbcd87c8402bd865b329e Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Thu, 12 Oct 2023 18:51:53 +0800 Subject: [PATCH] Fix unsuitable default value to bundle_root in ckpt_export (#7124) Fixes #7123 ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: KumoLiu --- monai/bundle/scripts.py | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 4607ef65b7..264d3fbf0e 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -1273,7 +1273,6 @@ def ckpt_export( config_file_, filepath_, ckpt_file_, - bundle_root_, net_id_, meta_file_, key_in_ckpt_, @@ -1285,7 +1284,6 @@ def ckpt_export( "config_file", filepath=None, ckpt_file=None, - bundle_root=os.getcwd(), net_id=None, meta_file=None, key_in_ckpt="", @@ -1293,18 +1291,23 @@ def ckpt_export( input_shape=None, converter_kwargs={}, ) + bundle_root = _args.get("bundle_root", os.getcwd()) parser = ConfigParser() - parser.read_config(f=config_file_) - meta_file_ = os.path.join(bundle_root_, "configs", "metadata.json") if meta_file_ is None else meta_file_ - filepath_ = os.path.join(bundle_root_, "models", "model.ts") if filepath_ is None else filepath_ - ckpt_file_ = os.path.join(bundle_root_, "models", "model.pt") if ckpt_file_ is None else ckpt_file_ - if not os.path.exists(ckpt_file_): - raise FileNotFoundError(f'Checkpoint file "{ckpt_file_}" not found, please specify it in argument "ckpt_file".') + meta_file_ = os.path.join(bundle_root, "configs", "metadata.json") if meta_file_ is None else meta_file_ if os.path.exists(meta_file_): parser.read_meta(f=meta_file_) + # the rest key-values in the _args are to override config content + for k, v in _args.items(): + parser[k] = v + + filepath_ = os.path.join(bundle_root, "models", "model.ts") if filepath_ is None else filepath_ + ckpt_file_ = os.path.join(bundle_root, "models", "model.pt") if ckpt_file_ is None else ckpt_file_ + if not os.path.exists(ckpt_file_): + raise FileNotFoundError(f'Checkpoint file "{ckpt_file_}" not found, please specify it in argument "ckpt_file".') + net_id_ = "network_def" if net_id_ is None else net_id_ try: parser.get_parsed_content(net_id_) @@ -1313,10 +1316,6 @@ def ckpt_export( f'Network definition "{net_id_}" cannot be found in "{config_file_}", specify name with argument "net_id".' ) from e - # the rest key-values in the _args are to override config content - for k, v in _args.items(): - parser[k] = v - # When export through torch.jit.trace without providing input_shape, will try to parse one from the parser. if (not input_shape_) and use_trace: input_shape_ = _get_fake_input_shape(parser=parser)