
[ONNX] Update API to torch.onnx.export(..., dynamo=True) #3220


Closed · wants to merge 2 commits
26 changes: 17 additions & 9 deletions beginner_source/onnx/export_simple_model_to_onnx_tutorial.py
@@ -7,12 +7,12 @@
Export a PyTorch model to ONNX
==============================

**Author**: `Thiago Crepaldi <https://github.com/thiagocrepaldi>`_
**Author**: `Ti-Tai Wang <https://github.com/titaiwangms>`_
Contributor:

We typically don't remove the original author even if they left the project. POCs for reviews/updates are managed through this issue and labels.


.. note::
As of PyTorch 2.1, there are two versions of ONNX Exporter.
As of PyTorch 2.5, there are two versions of ONNX Exporter.

* ``torch.onnx.dynamo_export`` is the newest (still in beta) exporter based on the TorchDynamo technology released with PyTorch 2.0
* ``torch.onnx.export(..., dynamo=True)`` is the newest (still in beta) exporter based on the TorchDynamo technology released with PyTorch 2.0
* ``torch.onnx.export`` is based on TorchScript backend and has been available since PyTorch 1.2.0

"""
@@ -21,7 +21,7 @@
# In the `60 Minute Blitz <https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html>`_,
# we had the opportunity to learn about PyTorch at a high level and train a small neural network to classify images.
# In this tutorial, we are going to expand this to describe how to convert a model defined in PyTorch into the
# ONNX format using TorchDynamo and the ``torch.onnx.dynamo_export`` ONNX exporter.
# ONNX format using TorchDynamo and the ``torch.onnx.export(..., dynamo=True)`` ONNX exporter.
#
# While PyTorch is great for iterating on the development of models, the model can be deployed to production
# using different formats, including `ONNX <https://onnx.ai/>`_ (Open Neural Network Exchange)!
@@ -90,7 +90,16 @@ def forward(self, x):

torch_model = MyModel()
torch_input = torch.randn(1, 1, 32, 32)
onnx_program = torch.onnx.dynamo_export(torch_model, torch_input)
onnx_program = torch.onnx.export(torch_model, torch_input, dynamo=True)
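
######################################################################
# A quick optional inspection (a sketch, not part of this diff): the call above
# returns an ONNX program object wrapping the exported graph, and its
# ``model_proto`` property exposes the in-memory ONNX protobuf.

print(type(onnx_program))
print(onnx_program.model_proto.graph.input)  # names and shapes of the graph inputs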

######################################################################
# 3.5. (Optional) Optimize the ONNX model
# ---------------------------------------
#
# The ONNX model can be optimized with constant folding and the elimination of redundant nodes.
# The optimization is done in-place, so the original ONNX model is modified.

onnx_program.optimize()
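
######################################################################
# A minimal follow-up sketch (assumed, not taken from this diff): the optimized
# program can be serialized to disk and validated with the ONNX checker. The
# file name matches the one referenced later in this tutorial, and the check
# assumes the ``onnx`` package is installed.

import onnx

onnx_program.save("my_image_classifier.onnx")
onnx.checker.check_model("my_image_classifier.onnx", full_check=True)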

######################################################################
# As we can see, we didn't need any code change to the model.
@@ -127,7 +136,7 @@ def forward(self, x):
# Once Netron is open, we can drag and drop our ``my_image_classifier.onnx`` file into the browser or select it after
# clicking the **Open model** button.
#
# .. image:: ../../_static/img/onnx/image_clossifier_onnx_modelon_netron_web_ui.png
# .. image:: ../../_static/img/onnx/image_classifier_onnx_modelon_netron_web_ui.png
# :width: 50%
#
#
@@ -155,7 +164,7 @@ def forward(self, x):

import onnxruntime

onnx_input = onnx_program.adapt_torch_inputs_to_onnx(torch_input)
onnx_input = [torch_input]
print(f"Input length: {len(onnx_input)}")
print(f"Sample input: {onnx_input}")

@@ -166,7 +175,7 @@ def to_numpy(tensor):

onnxruntime_input = {k.name: to_numpy(v) for k, v in zip(ort_session.get_inputs(), onnx_input)}

onnxruntime_outputs = ort_session.run(None, onnxruntime_input)
onnxruntime_outputs = ort_session.run(None, onnxruntime_input)[0]
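
######################################################################
# A small optional sketch (assumed, not part of this diff): the same session can
# be queried for the input and output names the exporter assigned to the ONNX
# graph, which is handy when binding inputs by name as done above.

print("ONNX input names:", [inp.name for inp in ort_session.get_inputs()])
print("ONNX output names:", [out.name for out in ort_session.get_outputs()])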

####################################################################
# 7. Compare the PyTorch results with the ones from the ONNX Runtime
@@ -179,7 +188,6 @@ def to_numpy(tensor):
# Before comparing the results, we need to convert PyTorch's output to match ONNX's format.

torch_outputs = torch_model(torch_input)
torch_outputs = onnx_program.adapt_torch_outputs_to_onnx(torch_outputs)

assert len(torch_outputs) == len(onnxruntime_outputs)
for torch_output, onnxruntime_output in zip(torch_outputs, onnxruntime_outputs):
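    # The loop body lies below the fold of this diff; a plausible, hedged
    # completion converts each ONNX Runtime output (a NumPy array) back to a
    # tensor and compares it numerically against the PyTorch output:
    torch.testing.assert_close(torch_output, torch.tensor(onnxruntime_output))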
11 changes: 7 additions & 4 deletions beginner_source/onnx/intro_onnx.py
@@ -7,7 +7,8 @@
====================

Authors:
`Thiago Crepaldi <https://github.com/thiagocrepaldi>`_,
`Ti-Tai Wang <https://github.com/titaiwangms>`_, `Justin Chu <https://github.com/justinchuby>`_,
Contributor:

Is there any particular reason to remove the original author?

Contributor (Author):

He no longer works on our team.

and `Xavier Dupré <https://github.com/xadupre>`_.

`Open Neural Network eXchange (ONNX) <https://onnx.ai/>`_ is an open standard
format for representing machine learning models. The ``torch.onnx`` module provides APIs to
@@ -19,8 +20,10 @@
including Microsoft's `ONNX Runtime <https://www.onnxruntime.ai>`_.

.. note::
Currently, there are two flavors of ONNX exporter APIs,
but this tutorial will focus on the ``torch.onnx.dynamo_export``.
Currently, users can choose to export a model to ONNX through either `TorchScript <https://pytorch.org/docs/stable/jit.html>`_ or
`ExportedProgram <https://pytorch.org/docs/stable/export.html>`_, selected via the boolean ``dynamo`` parameter of
`torch.onnx.export <https://pytorch.org/docs/stable/generated/torch.onnx.export.html>`_.
In this tutorial, we will focus on the ExportedProgram approach.
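
For illustration only, here is a minimal sketch (not part of this file) of the
same toy model exported through both paths; the model, input, and file names
are placeholders:

.. code-block:: python

   import torch

   class TinyModel(torch.nn.Module):
       def forward(self, x):
           return x.relu()

   model, example_input = TinyModel(), torch.randn(2, 3)

   # TorchScript-based exporter (default, dynamo=False): writes the file directly.
   torch.onnx.export(model, (example_input,), "tiny_torchscript.onnx")

   # ExportedProgram-based exporter (dynamo=True): returns an ONNX program object
   # that can be saved, optimized, or inspected before serialization.
   onnx_program = torch.onnx.export(model, (example_input,), dynamo=True)
   onnx_program.save("tiny_exported_program.onnx")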

The TorchDynamo engine is leveraged to hook into Python's frame evaluation API and dynamically rewrite its
bytecode into an `FX graph <https://pytorch.org/docs/stable/fx.html>`_.
@@ -33,7 +36,7 @@
Dependencies
------------

PyTorch 2.1.0 or newer is required.
PyTorch 2.5.0 or newer is required.

The ONNX exporter depends on extra Python packages:
