
Support >2G model export | torchlib(feat) #1003


Merged: 7 commits, merged Aug 11, 2023
Changes from 1 commit
67 changes: 51 additions & 16 deletions onnxscript/function_libs/torch_lib/graph_building.py
@@ -3,6 +3,7 @@

import logging
import os
import tempfile
import typing
import warnings
from typing import Any, Dict, Final, List, Mapping, Optional, Sequence, Tuple, Union
@@ -342,6 +343,18 @@ def _create_op_call_in_torch_graph(
return node_ouputs


def _estimate_tensor_size(tensor: torch.Tensor) -> int:
"""Estimate the size of a tensor in bytes.

Args:
tensor: The tensor to estimate the size of.

Returns:
The estimated size of the tensor in bytes.
"""
return tensor.numel() * tensor.element_size()
Collaborator Author: Thanks GPT
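
A quick sanity check of this estimate (illustrative only, not part of the diff): a float32 tensor with 2**29 elements comes out to 2**29 * 4 bytes = 2 GiB, exactly what the helper reports.

```python
import torch

# Illustrative sketch: a "meta" tensor carries shape and dtype metadata without
# allocating storage, so the 2 GiB estimate can be checked cheaply.
t = torch.empty(1024, 1024, 512, dtype=torch.float32, device="meta")
assert t.numel() * t.element_size() == 2**31  # 2 GiB, same as _estimate_tensor_size(t)
```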



class TorchScriptGraph:
_LOCAL_FUNCTION_DOMAIN_NAME: Final[str] = "torch_export"
"""The domain name for local functions."""
Expand Down Expand Up @@ -683,12 +696,15 @@ def to_model_proto(
# TODO(BowenBao): All local function domain versions are hardcoded as 1.
unique_custom_domains[function_proto.domain] = 1

(
proto,
_,
_,
_,
) = self._torch_graph._export_onnx( # type: ignore[attr-defined] # pylint: disable=protected-access
initializers_size = sum(
_estimate_tensor_size(tensor) for tensor in self.initializers.values()
)

# Treat models > 1GB as large models so that we have ample room
Contributor: Hmm, maybe increase it to 1.8 GB? I have never seen a model larger than 100 MB without initializers.

Collaborator Author: Done

# for the rest of the proto fields.
large_model = initializers_size > (2**30)

export_kwargs: dict[str, Any] = dict(
initializers=self.initializers if include_initializers else {},
onnx_opset_version=opset_version,
# TODO(justinchuby): Figure out how to get the dynamic axes from the inputs
Expand All @@ -699,15 +715,30 @@ def to_model_proto(
keep_initializers_as_inputs=False,
custom_opsets={},
add_node_names=True,
# TODO(#493): Passing in this instead of reading from env.
# User must put the exported model file in the same folder to launch ORT.
onnx_file_path=os.path.join(
os.getenv("EXTERNAL_ONNX_INITIALIZER_FOLDER", ""), "dummy_model_path.onnx"
),
node_attr_to_name={},
)

onnx_model = onnx.load_from_string(proto)
cache_model_to_disk = include_initializers and large_model

if cache_model_to_disk:
Contributor: Whether or not to store initializers should be controlled by a user flag. Suppose I export a 1 GB model on a remote machine and want to visualize it locally: I really don't want to download its initializers over home internet. If such a flag could be turned on, I would be able to download just the structure of the model and debug faster.

Collaborator Author (@justinchuby, Aug 10, 2023): I think once the user gets the model proto, they can do whatever they want with it (e.g. remove all the data)? A user has full control once they get the dynamo export output as an object.

Furthermore, include_initializers is already an argument.

Contributor: Agreed with @wschin's goal and @justinchuby's explanation. Something to consider for ExportOutput.save or ExportOutputSerializer.
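
As a rough sketch of the point above (an assumed helper, not part of this PR): once a caller holds the exported ModelProto, they can strip the initializer payloads themselves and keep only the graph structure for a lightweight download and visualization.

```python
import onnx

def strip_initializer_data(model: onnx.ModelProto) -> onnx.ModelProto:
    """Assumed helper for illustration: drop tensor payloads, keep graph structure."""
    stripped = onnx.ModelProto()
    stripped.CopyFrom(model)
    for tensor in stripped.graph.initializer:
        # Clear every payload field a TensorProto can carry; name, dims, and
        # data_type stay, so the graph remains viewable.
        for field in ("raw_data", "float_data", "int32_data", "int64_data",
                      "double_data", "uint64_data", "string_data"):
            tensor.ClearField(field)
    return stripped
```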

with tempfile.TemporaryDirectory() as temp_dir:
onnx_file_path = os.path.join(temp_dir, "exported_model.onnx")
export_kwargs["onnx_file_path"] = onnx_file_path
_ = self._torch_graph._export_onnx( # type: ignore[attr-defined] # pylint: disable=protected-access
**export_kwargs
)
onnx_model = onnx.load_model(onnx_file_path, load_external_data=True)
else:
(
proto,
_,
_,
_,
) = self._torch_graph._export_onnx( # type: ignore[attr-defined] # pylint: disable=protected-access
**export_kwargs
)
onnx_model = onnx.load_from_string(proto)

onnx_model.functions.extend(function_proto_dict.values())

# `_export_onnx` only exports opset_imports that is visible to it. It does not
@@ -725,10 +756,14 @@
)

try:
onnx_model = onnx.shape_inference.infer_shapes(
onnx_model, check_type=True, strict_mode=False, data_prop=True
)
onnx.checker.check_model(onnx_model, full_check=True)
if not cache_model_to_disk:
Collaborator Author: In a follow-up PR we will remove these checks altogether; per a discussion with Aaron, we should not check the model here.

# Only check the model if it is in memory.
# Otherwise the checker and shape_inference will fail because
Contributor: For shape inference, can we still load the shapes and element types from the model file (not the initializer files) and then run infer_shapes?

Collaborator Author: We could, but we also don't need to, because PyTorch supplies all the shape info.

Contributor: One drawback: due to onnx/onnx#5487, we don't have much inner-node shape info left now that modules are exported as functions.
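
For reference on the thread above (a sketch, assuming the large model has already been written to disk with external data; file names are illustrative): onnx exposes path-based variants of both checks that sidestep the 2 GB in-memory serialization limit.

```python
from onnx import checker, shape_inference

# Sketch only: both calls take a file path, so no >2GB proto has to be
# serialized in memory. "exported_model.onnx" is an assumed path.
checker.check_model("exported_model.onnx", full_check=True)
shape_inference.infer_shapes_path(
    "exported_model.onnx",
    "exported_model_inferred.onnx",
    check_type=True,
    strict_mode=False,
    data_prop=True,
)
```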

# we cannot serialize the model.
onnx_model = onnx.shape_inference.infer_shapes(
onnx_model, check_type=True, strict_mode=False, data_prop=True
)
onnx.checker.check_model(onnx_model, full_check=True)
except (onnx.checker.ValidationError, onnx.shape_inference.InferenceError) as e:
warnings.warn(f"ONNX model is invalid: {e}", stacklevel=1)
logging.debug(
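
To round out the picture (assumed caller-side usage, not part of this diff): a ModelProto whose initializers push it past the 2 GB protobuf limit has to be saved with external data, which onnx.save_model supports directly.

```python
import onnx

# Assumed caller-side sketch: persist a >2GB ModelProto by moving initializer
# payloads into a side file next to the .onnx file. File names are illustrative.
onnx.save_model(
    onnx_model,                       # ModelProto returned by to_model_proto()
    "exported_model.onnx",
    save_as_external_data=True,       # required once the proto exceeds 2 GB
    all_tensors_to_one_file=True,
    location="exported_model.onnx.data",
    size_threshold=1024,              # keep small tensors inline
)
```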