Support >2G model export | torchlib(feat) #1003
Changes from 1 commit
@@ -3,6 +3,7 @@
import logging
import os
import tempfile
import typing
import warnings
from typing import Any, Dict, Final, List, Mapping, Optional, Sequence, Tuple, Union
@@ -342,6 +343,18 @@ def _create_op_call_in_torch_graph(
    return node_ouputs


def _estimate_tensor_size(tensor: torch.Tensor) -> int:
    """Estimate the size of a tensor in bytes.

    Args:
        tensor: The tensor to estimate the size of.

    Returns:
        The estimated size of the tensor in bytes.
    """
    return tensor.numel() * tensor.element_size()

Review comment: Thanks GPT


class TorchScriptGraph:
    _LOCAL_FUNCTION_DOMAIN_NAME: Final[str] = "torch_export"
    """The domain name for local functions."""
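For intuition, the helper simply multiplies the element count by the per-element byte size. A minimal sanity check of that formula (a hypothetical snippet, not part of the PR):

import torch

# 1,000 x 1,000 float32 elements at 4 bytes each -> 4,000,000 bytes.
tensor = torch.zeros(1000, 1000, dtype=torch.float32)
size_bytes = tensor.numel() * tensor.element_size()  # same formula as _estimate_tensor_size
assert size_bytes == 4_000_000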
@@ -683,12 +696,15 @@ def to_model_proto(
        # TODO(BowenBao): All local function domain versions are hardcoded as 1.
        unique_custom_domains[function_proto.domain] = 1

        (
            proto,
            _,
            _,
            _,
        ) = self._torch_graph._export_onnx(  # type: ignore[attr-defined] # pylint: disable=protected-access
        initializers_size = sum(
            _estimate_tensor_size(tensor) for tensor in self.initializers.values()
        )

        # Treat models > 1GB as large models so that we have ample room
        # for the rest of the proto fields.
        large_model = initializers_size > (2**30)

Review comment: Hmm, maybe increase to 1.8 GB? I never see a model > 100 MB without initializers.
Reply: Done

        export_kwargs: dict[str, Any] = dict(
            initializers=self.initializers if include_initializers else {},
            onnx_opset_version=opset_version,
            # TODO(justinchuby): Figure out how to get the dynamic axes from the inputs
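To make the threshold concrete, here is a standalone sketch of the size check (illustrative names, not the PR's code; 2**30 bytes is 1 GiB, which leaves headroom below protobuf's 2 GiB serialization limit):

from typing import Mapping

import torch

_ONE_GIB = 2**30  # 1 GiB threshold, leaving room for the rest of the proto fields

def _is_large_model(initializers: Mapping[str, torch.Tensor]) -> bool:
    """Return True when the summed initializer bytes exceed the 1 GiB threshold."""
    total_bytes = sum(t.numel() * t.element_size() for t in initializers.values())
    return total_bytes > _ONE_GIB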
@@ -699,15 +715,30 @@ def to_model_proto(
            keep_initializers_as_inputs=False,
            custom_opsets={},
            add_node_names=True,
            # TODO(#493): Passing in this instead of reading from env.
            # User must put the exported model file in the same folder to launch ORT.
            onnx_file_path=os.path.join(
                os.getenv("EXTERNAL_ONNX_INITIALIZER_FOLDER", ""), "dummy_model_path.onnx"
            ),
            node_attr_to_name={},
        )

        onnx_model = onnx.load_from_string(proto)
        cache_model_to_disk = include_initializers and large_model

        if cache_model_to_disk:

Review comment: Whether or not initializers are stored should be controlled by a user flag. Assume that I export a 1 GB model on a remote machine and want to visualize it locally. I really don't want to download its initializers over home internet. If this flag can be turned on, I will be able to download just the structure of the model and debug faster.
Reply: I think once the user gets the model proto, they can do whatever they want (e.g., remove all the data)? A user has full control when they get the dynamo export output as an object. Furthermore, include_initializers is already an argument.
Reply: Agree with @wschin's goal and @justinchuby's explanation. A thing to consider for

            with tempfile.TemporaryDirectory() as temp_dir:
                onnx_file_path = os.path.join(temp_dir, "exported_model.onnx")
                export_kwargs["onnx_file_path"] = onnx_file_path
                _ = self._torch_graph._export_onnx(  # type: ignore[attr-defined] # pylint: disable=protected-access
                    **export_kwargs
                )
                onnx_model = onnx.load_model(onnx_file_path, load_external_data=True)
        else:
            (
                proto,
                _,
                _,
                _,
            ) = self._torch_graph._export_onnx(  # type: ignore[attr-defined] # pylint: disable=protected-access
                **export_kwargs
            )
            onnx_model = onnx.load_from_string(proto)

        onnx_model.functions.extend(function_proto_dict.values())

        # `_export_onnx` only exports opset_imports that is visible to it. It does not
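The disk round trip in the cache_model_to_disk branch relies on ONNX external data. A rough sketch of the same pattern using only public onnx APIs (the function and file names here are illustrative, not the PR's internal _export_onnx path):

import os
import tempfile

import onnx

def save_and_reload_large_model(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
    """Write tensors to an external-data file, then reload them into the proto."""
    with tempfile.TemporaryDirectory() as temp_dir:
        model_path = os.path.join(temp_dir, "exported_model.onnx")
        onnx.save_model(
            onnx_model,
            model_path,
            save_as_external_data=True,
            all_tensors_to_one_file=True,
            location="exported_model.onnx.data",
        )
        # load_external_data=False would instead return just the graph structure,
        # the lightweight "visualize without weights" case discussed in the review thread.
        return onnx.load_model(model_path, load_external_data=True)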
@@ -725,10 +756,14 @@ def to_model_proto(
        )

        try:
            onnx_model = onnx.shape_inference.infer_shapes(
                onnx_model, check_type=True, strict_mode=False, data_prop=True
            )
            onnx.checker.check_model(onnx_model, full_check=True)
            if not cache_model_to_disk:

Review comment: In a follow-up PR we will remove the checks altogether. From a discussion with Aaron: we should not check it here.

                # Only check the model if it is in memory.
                # Otherwise the checker and shape_inference will fail because
                # we cannot serialize the model.
                onnx_model = onnx.shape_inference.infer_shapes(
                    onnx_model, check_type=True, strict_mode=False, data_prop=True
                )
                onnx.checker.check_model(onnx_model, full_check=True)

Review comment: For shape inference, can we still load shapes and element types from the model file (not the initializer files) and then run infer_shapes?
Reply: We could, but we also don't need to because PyTorch supplies all the shape info.
Reply: A drawback: due to onnx/onnx#5487, we don't have much inner-node shape info left now that modules are functions.

        except (onnx.checker.ValidationError, onnx.shape_inference.InferenceError) as e:
            warnings.warn(f"ONNX model is invalid: {e}", stacklevel=1)
            logging.debug(
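For reference, when a model only exists on disk with external data, onnx offers path-based variants that avoid building a >2 GiB proto in memory. A hedged sketch (file paths are illustrative):

import onnx

model_path = "exported_model.onnx"            # model saved with external data
inferred_path = "exported_model.inferred.onnx"

# Shape inference that reads from and writes to files instead of ModelProto objects.
onnx.shape_inference.infer_shapes_path(model_path, inferred_path)

# The checker also accepts a path, which is required for models over 2 GiB.
onnx.checker.check_model(model_path, full_check=True)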