[WIP] proposal for ONNX export feature #20

Open — wants to merge 1 commit into base: master
5 changes: 5 additions & 0 deletions .gitignore
@@ -120,6 +120,11 @@ projects
# ignore both root data dir and root data symlink
# but not lower level data dirs (ie: thelper.data)
/data
# ignore other common runtime directories
/checkpoint[s]
/output[s]
/result[s]
/session[s]
*.tif
*.zip
./workflow[s]
31 changes: 28 additions & 3 deletions docs/src/use-cases.rst
@@ -360,7 +360,7 @@ So, first off, let's start by training a classification model using the followin
The above configuration essentially means that we will be training a ResNet model with
default settings on CIFAR10 using all 10 classes. You can launch the training process via::

-  $ thelper new <PATH_TO_CLASSIF_CIFAR10_CONFIG>.json <PATH_TO_OUTPUT_DIR>
+  $ thelper new -c <PATH_TO_CLASSIF_CIFAR10_CONFIG>.json -d <PATH_TO_OUTPUT_DIR>
**Owner:** Good catch on these, there's probably a ton of other dead CLI examples due to the commandline changes, I should go around and fix them...
See the :ref:`[user guide] <user-guide-cli-new>` for more information on creating training
sessions. Once that's done, you should obtain a folder named ``classif-cifar10`` in your output
@@ -383,9 +383,9 @@ in a new checkpoint, we will use the following session configuration::
This configuration essentially specifies where to find the 'best' checkpoint for the model we
just trained, and how to export a trace of it. For more information on the export operation, refer
to :ref:`[the user guide] <user-guide-cli-export>`. We now provide the configuration as a JSON to
-the CLI one more::
+the CLI once more::

-  $ thelper export <PATH_TO_EXPORT_CONFIG>.json <PATH_TO_OUTPUT_DIR>
+  $ thelper export -c <PATH_TO_EXPORT_CONFIG>.json -d <PATH_TO_OUTPUT_DIR>

If everything goes well, ``<PATH_TO_OUTPUT_DIR>/export-classif-cifar10`` should now contain a checkpoint
with the exported model trace and all metadata required to reinstantiate it. Note that as of 2019/06,
@@ -414,6 +414,31 @@ configuration is given below::
}
}

As with the trace export procedure above, an ONNX export can be requested using ``onnx_``-prefixed parameters
instead of ``trace_``-prefixed ones. The configuration could look like the following::

{
"name": "export-classif-onnx",
"model": {
# if checkpoint was created by thelper framework:
"ckptdata": "<PATH_TO_OUTPUT_DIR>/classif-cifar10/checkpoints/ckpt.best.pth"
# or 'type', 'params' and 'weights' (see above) if checkpoint was created outside the framework
},
"export": {
"onnx_name": "test-export.onnx",
"onnx_input": "torch.rand(1, 3, 224, 224)"
}
}

Calling ``thelper export`` as before, but with this ONNX export configuration, will generate the corresponding
ONNX model under ``<PATH_TO_OUTPUT_DIR>/export-classif-onnx`` if everything goes well. Remember that only
conversions supported by PyTorch's ONNX exporter will work, so be mindful of whether the model you are
trying to export contains any custom or unusual layers.

Also note that, as of this writing, there is still no official way to import ONNX models into
PyTorch (see `[PyTorch #21683 - Import ONNX model to Pytorch] <https://github.com/pytorch/pytorch/issues/21683>`_).
The framework therefore cannot import such checkpoints for the time being.

For more information on model import, refer to the documentation of :func:`thelper.nn.utils.create_model`.

`[to top] <#use-cases>`_
29 changes: 27 additions & 2 deletions thelper/cli.py
@@ -414,7 +414,7 @@ def is_defined_dict(container, section):
tester.test()


-def export_model(config, save_dir):
+def export_model(config, save_dir, onnx=False):
"""Launches a model exportation session.

This function will export a model defined via a configuration file into a new checkpoint that can be
@@ -436,6 +436,9 @@ def export_model(config, save_dir):
save_dir: the path to the root directory where the session directory should be saved. Note that
this is not the path to the session directory itself, but its parent, which may also contain
other session directories.
onnx: indicates whether the model should be exported in ONNX format. If ``False``, ONNX export can
still be requested via the corresponding parameter in the 'export' section, or inferred from a
'.onnx' extension on 'ckpt_name'. The 'onnx_input' parameter must be provided in the 'export'
section to generate the ONNX model.

.. seealso::
| :func:`thelper.nn.utils.create_model`
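The three triggers described in the docstring resolve to a simple boolean OR; a minimal sketch of that resolution logic (the helper name is hypothetical, not part of the PR):

```python
def wants_onnx_export(cli_flag, export_config, ckpt_name):
    """Resolve whether ONNX export was requested: via the CLI flag, via the
    'onnx' entry of the 'export' config section, or inferred from a '.onnx'
    extension on the target checkpoint name."""
    return (bool(cli_flag)
            or bool(export_config.get("onnx", False))
            or ckpt_name.endswith(".onnx"))
```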
@@ -483,6 +486,27 @@ def export_model(config, save_dir):
export_state["model"] = trace_name # will be loaded in thelper.utils.load_checkpoint
else:
export_state["model"] = model.state_dict() if save_raw else model
export_onnx = onnx or export_config.get("onnx", False) or ckpt_name.endswith(".onnx")
if export_onnx:
    logger.info("detected ONNX format requested for export")
    onnx_input = thelper.utils.get_key_def("onnx_input", export_config, default=None)
    if isinstance(onnx_input, str):
        onnx_input = eval(onnx_input)  # note: eval'ed as-is; expects a tensor expression such as "torch.rand(1, 3, 224, 224)"
    if onnx_input is None:
        logger.warning("onnx input required to export as ONNX but was missing from config, ONNX export skipped")
    else:
        ckpt_onnx = thelper.utils.get_key_def("onnx_name", export_config, default=None)
        if not ckpt_onnx:
            ckpt_onnx, _ = os.path.splitext(ckpt_name)
            ckpt_onnx = ckpt_onnx + ".onnx"
        config["export"]["ckpt_name"] = ckpt_onnx
        config["export"]["onnx"] = True
        config["model"]["ckpt_path"] = ckpt_onnx  # for 'resume' and 'infer'
        config["model"]["ckptdata"] = ckpt_onnx  # for 'new'
        export_state["model"] = ckpt_onnx
        torch.onnx.export(model, onnx_input, os.path.join(save_dir, ckpt_onnx),
                          training=torch.onnx.TrainingMode.TRAINING)  # export in training mode to allow re-training
thelper.utils.save_config(config, os.path.join(save_dir, "config.export.json"))
torch.save(export_state, os.path.join(save_dir, ckpt_name))
logger.debug("all done")
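Since the ``onnx_input`` string above is passed straight to ``eval``, a stricter alternative could accept only a plain shape literal and let the caller build the dummy tensor; a hypothetical stdlib-only sketch (not part of this PR):

```python
import ast

def parse_input_shape(spec):
    """Parse a shape spec such as "[1, 3, 224, 224]" into a tuple of ints
    without eval'ing arbitrary code; the caller can then build the dummy
    export input with torch.rand(*shape). Hypothetical helper."""
    shape = ast.literal_eval(spec)  # rejects calls, names, and other non-literals
    if not (isinstance(shape, (list, tuple)) and shape
            and all(isinstance(d, int) and d > 0 for d in shape)):
        raise ValueError(f"invalid onnx_input shape spec: {spec!r}")
    return tuple(shape)
```

The trade-off is that a pure shape spec cannot express arbitrary input structures (e.g. tuples of tensors), which the ``eval`` approach handles for free.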

@@ -526,6 +550,7 @@ def make_argparser():
export_ap = subparsers.add_parser("export", help="launches a model exportation session from a config file")
export_ap.add_argument("-c", "--config", required=True, type=str, help="path to the session configuration file (or session directory)")
export_ap.add_argument("-d", "--save-dir", required=True, type=str, help="path to the session output root directory")
export_ap.add_argument("--onnx", action="store_true", help="enforce model checkpoint export format to ONNX")
infer_ap = subparsers.add_parser("infer", help="creates an inference session from a config file")
infer_ap.add_argument("--ckpt-path", type=str, help="path to the checkpoint (or directory) to use for inference "
"(otherwise uses model checkpoint from configuration)")
@@ -614,7 +639,7 @@ def main(args=None, argparser=None):
elif args.mode == "annot":
annotate_data(config, args.save_dir)
elif args.mode == "export":
-export_model(config, args.save_dir)
+export_model(config, args.save_dir, args.onnx)
else: # if args.mode == "split":
split_data(config, args.save_dir)
return 0
7 changes: 6 additions & 1 deletion thelper/utils.py
@@ -236,7 +236,12 @@ def load_checkpoint(ckpt, # type: thelper.typedefs.Checkpoi
if hasattr(ckpt, "name"):
logger.debug("parsing checkpoint provided via file object")
basepath = os.path.dirname(os.path.abspath(ckpt.name))
-ckptdata = torch.load(ckpt, map_location=map_location)
+if isinstance(ckpt, str) and ckpt.endswith(".onnx"):
+    # FIXME: handle loading ONNX checkpoints into PyTorch here if/when the feature lands upstream
+    # https://github.com/pytorch/pytorch/issues/21683
+    raise NotImplementedError("onnx to pytorch conversion not implemented")
+else:
+    ckptdata = torch.load(ckpt, map_location=map_location)
if not isinstance(ckptdata, dict):
raise AssertionError("unexpected checkpoint data type")
if check_version:
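The extension check above could be factored into a small predicate reusable by other loading paths; a hypothetical sketch (not part of this PR):

```python
import os

def is_onnx_checkpoint(ckpt):
    """Return True when the checkpoint reference is a path string targeting
    an ONNX file; non-string references (e.g. file objects) return False."""
    return isinstance(ckpt, str) and os.path.splitext(ckpt)[1].lower() == ".onnx"
```

Unlike the inline ``endswith`` check, this variant also matches uppercase ``.ONNX`` extensions.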