From 5286462d03c7c566ac8e3e7c3b47d95de26f9395 Mon Sep 17 00:00:00 2001
From: Ziming Miao
Date: Mon, 2 Nov 2020 13:07:56 +0800
Subject: [PATCH] Update PyTorch readme (#101)

* update pt readme

* Update README.md

Co-authored-by: Alisa Chen <45112983+AlisaChen98@users.noreply.github.com>
---
 models/pytorch2onnx/README.md      | 34 ++++++++++++++++++++++++++----
 models/pytorch2onnx/vgg16_model.py |  3 +++
 2 files changed, 33 insertions(+), 4 deletions(-)

diff --git a/models/pytorch2onnx/README.md b/models/pytorch2onnx/README.md
index ccd99c8a2..755103769 100644
--- a/models/pytorch2onnx/README.md
+++ b/models/pytorch2onnx/README.md
@@ -1,7 +1,33 @@
-# Freeze ONNX model from PyTorch
+# PyTorch2ONNX
 
-If you want to test NNFusion against PyTorch model, we recommend ONNX as intermediary, because PyTorch [onnx_exporter](https://pytorch.org/docs/stable/onnx.html) handles many dynamics and we don't want to reimplement them.
+*Note:* In this section, we assume the nnfusion CLI has been installed following the [Build Guide](https://github.com/microsoft/nnfusion/wiki/Build-Guide).
 
-On PyTorch onnx_exporter, we build a simple wrapper to freeze and check(not fully implemented) whether NNFusion could accept the model. We provide a [VGG example](./vgg16_model.py) for this tool.
+NNFusion supports PyTorch through ONNX, so this section focuses on how to freeze an ONNX model from PyTorch source code. The ONNX ops supported by NNFusion are listed [here](../../thirdparty/ngraph/src/nnfusion/frontend/onnx_import/ops_bridge.cpp).
 
-Of course, you could freeze ONNX model from thirdparties, like the [transformer example](./bert_model.py).
+## Freeze model with the PyTorch ONNX exporter
+
+Please refer to the PyTorch [onnx section](https://pytorch.org/docs/stable/onnx.html) to convert a PyTorch model to ONNX format; the exporter already covers the great majority of deep learning workloads.
+
+## Freeze model with NNFusion pt_freezer
+
+On top of the PyTorch ONNX exporter, we build a simple wrapper called [pt_freezer](./pytorch_freezer.py), which wraps the exporter with control-flow and op-availability checks (the op-availability check is not implemented yet). We provide a self-explanatory [VGG example](./vgg16_model.py) for this tool:
+
+```bash
+# step0: install prerequisites
+sudo apt update && sudo apt install python3-pip
+pip3 install onnx torch torchvision
+
+# step1: freeze vgg16 model
+python3 vgg16_model.py
+```
+
+## Freeze model from third parties
+
+Of course, you can also freeze ONNX models from third parties, such as [huggingface transformers](https://github.com/huggingface/transformers/blob/master/notebooks/04-onnx-export.ipynb), which supports exporting to ONNX format.
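+
+As a rough sketch of that path (assuming the `transformers` package is installed; the model name, output names, and opset below are illustrative choices that mirror the BERT_base entry in the table that follows), an export through the standard `torch.onnx.export` API could look like:
+
+```python
+import torch
+from transformers import BertModel, BertTokenizer
+
+# load the pretrained model; torchscript=True makes it return plain tuples,
+# which is friendlier to tracing-based export
+tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
+model = BertModel.from_pretrained("bert-base-cased", torchscript=True)
+model.eval()
+
+# build a dummy input for tracing
+dummy = tokenizer("Hello, NNFusion!", return_tensors="pt")
+
+# trace the model and serialize the graph to ONNX
+torch.onnx.export(
+    model,
+    (dummy["input_ids"], dummy["attention_mask"]),
+    "pt-bert-base-cased.onnx",
+    input_names=["input_ids", "attention_mask"],
+    output_names=["last_hidden_state", "pooler_output"],
+    opset_version=11,
+)
+```
+
+The resulting .onnx file can then be compiled with the nnfusion CLI using the codegen flags listed in the table below.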
+
+## Frozen ONNX models
+
+| model | nnf codegen flags | download link |
+| ----------- | ----------- | ----------- |
+| VGG16 | -f onnx | [vgg16.onnx](https://nnfusion.blob.core.windows.net/models/onnx/vgg16.onnx) |
+| BERT_base | -f onnx -p 'batch:3;sequence:512' | [pt-bert-base-cased.onnx](https://nnfusion.blob.core.windows.net/models/onnx/bert/pt-bert-base-cased.onnx) |
diff --git a/models/pytorch2onnx/vgg16_model.py b/models/pytorch2onnx/vgg16_model.py
index f973904ae..f812a1872 100644
--- a/models/pytorch2onnx/vgg16_model.py
+++ b/models/pytorch2onnx/vgg16_model.py
@@ -8,10 +8,13 @@ def main():
+    # define the pytorch module
     model = torchvision.models.vgg16()
+    # define the pytorch model input/output description
     input_desc = [IODescription("data", [1, 3, 224, 224], torch.float32)]
     output_desc = [IODescription("logits", [1, 1000], torch.float32)]
     model_desc = ModelDescription(input_desc, output_desc)
+    # freeze the model and save it as an onnx file
     freezer = PTFreezer(model, model_desc)
     freezer.execute("./vgg16.onnx")
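+    # a minimal optional sanity check (an illustrative sketch, assuming the
+    # `onnx` package from step0 is installed): validate the frozen model file
+    import onnx
+    onnx.checker.check_model(onnx.load("./vgg16.onnx"))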