Update PyTorch readme (#101)
* update pt readme

* Update README.md

Co-authored-by: Alisa Chen <[email protected]>
mzmssg and AlisaChen98 authored Nov 2, 2020
1 parent 0a40b5a commit 5286462
Showing 2 changed files with 33 additions and 4 deletions.
34 changes: 30 additions & 4 deletions models/pytorch2onnx/README.md
@@ -1,7 +1,33 @@
# Freeze ONNX model from PyTorch
# PyTorch2ONNX

If you want to test NNFusion against PyTorch model, we recommend ONNX as intermediary, because PyTorch [onnx_exporter](https://pytorch.org/docs/stable/onnx.html) handles many dynamics and we don't want to reimplement them.
*Note:* In this section, we assume the nnfusion CLI has been installed per the [Build Guide](https://github.com/microsoft/nnfusion/wiki/Build-Guide).

On PyTorch onnx_exporter, we build a simple wrapper to freeze and check(not fully implemented) whether NNFusion could accept the model. We provide a [VGG example](./vgg16_model.py) for this tool.
NNFusion supports PyTorch by way of ONNX, so this section focuses on how to freeze an ONNX model from PyTorch source code. The ONNX ops supported by NNFusion are listed [here](../../thirdparty/ngraph/src/nnfusion/frontend/onnx_import/ops_bridge.cpp).

Of course, you could freeze ONNX model from thirdparties, like the [transformer example](./bert_model.py).
## Freeze model by PyTorch ONNX exporter

Please refer to the PyTorch [ONNX section](https://pytorch.org/docs/stable/onnx.html) to convert a PyTorch model to ONNX format; the exporter already supports the great majority of deep learning workloads.

## Freeze model by NNFusion pt_freezer

On top of the PyTorch onnx_exporter, we built a simple wrapper called [pt_freezer](./pytorch_freezer.py); it wraps the exporter with control-flow and op-availability checks (the latter not implemented yet). We provide a self-explanatory [VGG example](./vgg16_model.py) for this tool:

```bash
# step0: install prerequisites
sudo apt update && sudo apt install python3-pip
pip3 install onnx torch torchvision

# step1: freeze vgg16 model
python3 vgg16_model.py
```

## Freeze model from thirdparties

Of course, you could also freeze an ONNX model from third parties, such as [huggingface transformers](https://github.com/huggingface/transformers/blob/master/notebooks/04-onnx-export.ipynb), which supports exporting to ONNX format.

## Frozen ONNX models

| model | nnf codegen flags | download link |
| ----------- | ----------- | ----------- |
| VGG16 | -f onnx | [vgg16.onnx](https://nnfusion.blob.core.windows.net/models/onnx/vgg16.onnx) |
| BERT_base | -f onnx -p 'batch:3;sequence:512' | [pt-bert-base-cased.onnx](https://nnfusion.blob.core.windows.net/models/onnx/bert/pt-bert-base-cased.onnx) |
3 changes: 3 additions & 0 deletions models/pytorch2onnx/vgg16_model.py
@@ -8,10 +8,13 @@


def main():
# define pytorch module
model = torchvision.models.vgg16()
# define the pytorch model input/output description
input_desc = [IODescription("data", [1, 3, 224, 224], torch.float32)]
output_desc = [IODescription("logits", [1, 1000], torch.float32)]
model_desc = ModelDescription(input_desc, output_desc)
# save the onnx model somewhere
freezer = PTFreezer(model, model_desc)
freezer.execute("./vgg16.onnx")

