
Support >2G model export - alternative implementation | torchlib(feat) #1004


Closed
wants to merge 15 commits
119 changes: 47 additions & 72 deletions onnxscript/function_libs/torch_lib/graph_building.py
@@ -1,9 +1,8 @@
"""Graph building functions for torchscript graph backend."""
from __future__ import annotations

import ctypes
import logging
import os
import tempfile
import typing
import warnings
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union
@@ -65,9 +64,6 @@
None,
]

# Be sure to leave ample room for the rest of the proto fields.
_LARGE_MODEL_SIZE_THRESHOLD = int(2**30 * 1.8) # 1.8GB

# TODO(justinchuby): Build a context manager to handle source information.


@@ -378,12 +374,38 @@
return self._graph.add_function_call(function, inputs, attributes)


Contributor @BowenBao commented on Dec 8, 2023:
fyi @kunal-vaishnavi: creating an ONNX TensorProto from a torch tensor data pointer.

def _add_initializers(
    model_proto: onnx.ModelProto, initializers: Mapping[str, torch.Tensor]
) -> None:
    """Add initializers to the model proto."""
    tensor_protos = []

    for name, tensor in initializers.items():
        tensor = tensor.detach().cpu().contiguous()
        # Take the raw data directly from the tensor to avoid the overhead of
        # data manipulation in onnx.helper.make_tensor
        raw_data = bytes(
            (ctypes.c_ubyte * tensor.element_size() * tensor.numel()).from_address(
                tensor.data_ptr()
            )
        )
        tensor_proto = onnx.helper.make_tensor(
            name=name,
            data_type=_type_utils.JitScalarType.from_dtype(tensor.dtype).onnx_type(),
            dims=tensor.shape,
            vals=raw_data,
            raw=True,
        )
        tensor_protos.append(tensor_proto)
    model_proto.graph.initializer.extend(tensor_protos)

Code scanning (lintrunner: PYLINT E0602, MYPY name-defined, RUFF F821) flags `_type_utils` as an undefined name here, and Codecov reports the added lines as not covered by tests.
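Below is a minimal, standalone sketch of the technique the review comment points at: reading a torch tensor's raw bytes through ctypes and wrapping them in a TensorProto without an intermediate numpy copy. It is illustrative only; it sidesteps the undefined `_type_utils` helper by assuming a float32 tensor and passing onnx.TensorProto.FLOAT directly.

import ctypes

import onnx
import torch


def tensor_to_tensor_proto(name: str, tensor: torch.Tensor) -> onnx.TensorProto:
    # Move to CPU and make the layout contiguous so data_ptr() points at a
    # dense buffer of numel() * element_size() bytes.
    tensor = tensor.detach().cpu().contiguous()
    # Read the raw bytes straight out of the tensor's storage.
    raw_data = bytes(
        (ctypes.c_ubyte * tensor.element_size() * tensor.numel()).from_address(
            tensor.data_ptr()
        )
    )
    return onnx.helper.make_tensor(
        name=name,
        data_type=onnx.TensorProto.FLOAT,  # assumption: float32 tensors only
        dims=tensor.shape,
        vals=raw_data,
        raw=True,
    )


# A 2x3 float32 tensor yields 2 * 3 * 4 = 24 bytes of raw data.
proto = tensor_to_tensor_proto("weight", torch.randn(2, 3))
assert len(proto.raw_data) == 24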


@runtime_typing.checked
def _add_attribute_to_torchscript_node(
    node: torch.Node,
    key: str,
    value: Union[float, int, str, bytes, Sequence[float], Sequence[int], torch.Tensor],
):
) -> torch.Node:
    """Initializes the right attribute based on type of value."""
    if isinstance(value, float):
        return node.f_(key, value)
@@ -443,18 +465,6 @@
return node_ouputs


def _tensor_rawdata_size(tensor: torch.Tensor) -> int:
    """Estimate the size of a tensor in bytes.

    Args:
        tensor: The tensor to estimate the size of.

    Returns:
        The estimated size of the tensor in bytes.
    """
    return tensor.numel() * tensor.element_size()


def _shared_functions() -> list[onnx.FunctionProto]:
"""Hack to always include the share ops."""

@@ -935,16 +945,13 @@
# TODO(BowenBao): All local function domain versions are hardcoded as 1.
unique_custom_domains[function_proto.domain] = 1

        initializers_size = sum(
            _tensor_rawdata_size(tensor) for tensor in self.initializers.values()
        )

        large_model = initializers_size > _LARGE_MODEL_SIZE_THRESHOLD

        export_kwargs: dict[str, Any] = dict(
            initializers=self.initializers
            if include_initializers and not _flags.EXPERIMENTAL_INITIALIZERS_AS_INPUTS
            else {},
        (
            proto,
            _,
            _,
            _,
        ) = self._torch_graph._export_onnx( # type: ignore[attr-defined] # pylint: disable=protected-access
            initializers={},
            onnx_opset_version=opset_version,
            dynamic_axes={},
            defer_weight_export=False,
@@ -955,46 +962,12 @@
            add_node_names=True,
            node_attr_to_name={},
        )

        # We decided to cache the model to disk when the model is large.
        # Alternatively, we could build the ONNX `TensorProto`s in memory
        # and append them to the model proto.
        # We did not do it because it is harder to get right (vs. PyTorch's battle-tested
        # implementation) and creating the `TensorProto`s naively (by converting to numpy)
        # is slow.
        cache_model_to_disk = large_model and include_initializers

        if cache_model_to_disk:
            with tempfile.TemporaryDirectory() as temp_dir:
                onnx_file_path = os.path.join(temp_dir, "exported_model.onnx")
                export_kwargs["onnx_file_path"] = onnx_file_path
                (
                    proto,
                    _,
                    _,
                    _,
                ) = self._torch_graph._export_onnx( # type: ignore[attr-defined] # pylint: disable=protected-access
                    **export_kwargs
                )
                onnx_model = onnx.load_from_string(proto)
                onnx.load_external_data_for_model(onnx_model, temp_dir)
        else:
            (
                proto,
                _,
                _,
                _,
            ) = self._torch_graph._export_onnx( # type: ignore[attr-defined] # pylint: disable=protected-access
                **export_kwargs
            )
            onnx_model = onnx.load_from_string(proto)

        onnx_model = onnx.load_from_string(proto)
        onnx_model.functions.extend(function_proto_dict.values())
        onnx_model.functions.extend(_shared_functions())

        # Override value_infos with symbolic shapes.
        onnx_model = self._override_with_symbolic_value_info_proto(onnx_model)

        # `_export_onnx` only exports opset_imports that is visible to it. It does not
        # export opset_imports for nested functions, since it does not have access to
        # them. We manually add them back and merge with existing opset_imports in the
@@ -1015,21 +988,23 @@
                common_ops.common_opset.domain, common_ops.common_opset.version
            )
        )

        try:
            if not cache_model_to_disk:
                # Only check the model if it is in memory.
                # Otherwise the checker and shape_inference will fail because
                # we cannot serialize the model.
                onnx_model = onnx.shape_inference.infer_shapes(
                    onnx_model, check_type=True, strict_mode=False, data_prop=True
                )
                onnx.checker.check_model(onnx_model, full_check=True)
            # Fill in the shape information before adding initializers,
            # because the initializers may be too large (>2gb) for the model
            # to fit in the memory.
            onnx_model = onnx.shape_inference.infer_shapes(
                onnx_model, check_type=True, strict_mode=False, data_prop=True
            )
            onnx.checker.check_model(onnx_model, full_check=True)
        except (onnx.checker.ValidationError, onnx.shape_inference.InferenceError) as e:
            warnings.warn(f"ONNX model is invalid: {e}", stacklevel=1)
            logging.debug(
                "ONNX model:\n%s\n\nTorchScript graph:\n%s",
                onnx.printer.to_text(onnx_model),
                self.torch_graph,
            )

        if include_initializers:
            _add_initializers(onnx_model, self.initializers)

        return onnx_model
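Not part of this diff, but as a hedged usage sketch: once its initializers push the returned ModelProto past protobuf's 2GB serialization limit, a caller would typically write the tensors out as external data rather than serialize a single protobuf. The variable and file names below are illustrative placeholders.

import onnx

# Assumption: `onnx_model` is the ModelProto produced by the graph's
# to_model_proto(); the file names are placeholders chosen by the caller.
onnx.save_model(
    onnx_model,
    "exported_model.onnx",
    save_as_external_data=True,    # keeps the serialized protobuf under 2GB
    all_tensors_to_one_file=True,  # gather all weights into a single side file
    location="exported_model.onnx.data",
    size_threshold=1024,           # only externalize tensors larger than 1KB
)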