117 changes: 47 additions & 70 deletions onnxscript/function_libs/torch_lib/graph_building.py
@@ -1,9 +1,9 @@
 """Graph building functions for torchscript graph backend."""
 from __future__ import annotations
 
+import ctypes
 import logging
-import os
-import tempfile
 import typing
 import warnings
 from typing import Any, Dict, Final, List, Mapping, Optional, Sequence, Tuple, Union
+from torch.onnx import _type_utils  # dtype -> ONNX element type mapping, used by _add_initializers
@@ -64,9 +63,6 @@
     None,
 ]
 
-# Be sure to leave ample room for the rest of the proto fields.
-_LARGE_MODEL_SIZE_THRESHOLD = int(2**30 * 1.8)  # 1.8GB
-
 # TODO(justinchuby): Build a context manager to handle source information.
 
 
@@ -281,12 +277,38 @@ def eval_function( # type: ignore[override]
         return self._graph.add_function_call(function, inputs, attributes)
 
 
+def _add_initializers(

    Review comment from @BowenBao (Contributor), Dec 8, 2023: fyi @kunal-vaishnavi creating onnx tensorproto from torch tensor dataptr

+    model_proto: onnx.ModelProto, initializers: Mapping[str, torch.Tensor]
+) -> None:
+    """Add initializers to the model proto."""
+    tensor_protos = []
+
+    for name, tensor in initializers.items():
+        tensor = tensor.detach().cpu().contiguous()
+        # Take the raw data directly from the tensor to avoid the overhead of
+        # data manipulation in onnx.helper.make_tensor
+        raw_data = bytes(
+            (ctypes.c_ubyte * tensor.element_size() * tensor.numel()).from_address(
+                tensor.data_ptr()
+            )
+        )
+        tensor_proto = onnx.helper.make_tensor(
+            name=name,
+            data_type=_type_utils.JitScalarType.from_dtype(tensor.dtype).onnx_type(),
+            dims=tensor.shape,
+            vals=raw_data,
+            raw=True,
+        )
+        tensor_protos.append(tensor_proto)
+    model_proto.graph.initializer.extend(tensor_protos)
+
+
 @runtime_typing.checked
 def _add_attribute_to_torchscript_node(
     node: torch.Node,
     key: str,
     value: Union[float, int, str, bytes, Sequence[float], Sequence[int], torch.Tensor],
-):
+) -> torch.Node:
     """Initializes the right attribute based on type of value."""
     if isinstance(value, float):
         return node.f_(key, value)
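For readers unfamiliar with the trick in `_add_initializers` above, the following is a minimal, self-contained sketch of the same technique: copying a tensor's backing bytes through its data pointer and wrapping them in a `TensorProto`. The helper name, the hard-coded FLOAT element type, and the sample tensor are illustrative only; the PR instead derives the element type from `tensor.dtype` via `JitScalarType`.

    # Illustrative sketch, not part of the PR: build a TensorProto straight
    # from a torch tensor's data pointer.
    import ctypes

    import onnx
    import onnx.helper
    import torch

    def tensor_to_proto(name: str, tensor: torch.Tensor) -> onnx.TensorProto:
        # detach/cpu/contiguous ensures data_ptr() points at a dense CPU buffer.
        tensor = tensor.detach().cpu().contiguous()
        # Copy numel * element_size bytes out of the tensor's storage, skipping
        # the numpy conversion onnx.helper.make_tensor would otherwise perform.
        raw_data = bytes(
            (ctypes.c_ubyte * tensor.element_size() * tensor.numel()).from_address(
                tensor.data_ptr()
            )
        )
        return onnx.helper.make_tensor(
            name=name,
            data_type=onnx.TensorProto.FLOAT,  # fixed here; the PR maps tensor.dtype
            dims=tensor.shape,
            vals=raw_data,
            raw=True,
        )

    proto = tensor_to_proto("weight", torch.ones(2, 3))
    assert len(proto.raw_data) == 24  # 6 float32 elements * 4 bytes each

Note that the bytes() call copies the buffer, so the proto stays valid even if the source tensor is freed afterwards.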
@@ -346,18 +368,6 @@ def _create_op_call_in_torch_graph(
     return node_ouputs
 
 
-def _tensor_rawdata_size(tensor: torch.Tensor) -> int:
-    """Estimate the size of a tensor in bytes.
-
-    Args:
-        tensor: The tensor to estimate the size of.
-
-    Returns:
-        The estimated size of the tensor in bytes.
-    """
-    return tensor.numel() * tensor.element_size()
-
-
 class TorchScriptGraph:
     _LOCAL_FUNCTION_DOMAIN_NAME: Final[str] = "torch_export"
     """The domain name for local functions."""
@@ -699,14 +709,13 @@ def to_model_proto(
                 # TODO(BowenBao): All local function domain versions are hardcoded as 1.
                 unique_custom_domains[function_proto.domain] = 1
 
-        initializers_size = sum(
-            _tensor_rawdata_size(tensor) for tensor in self.initializers.values()
-        )
-
-        large_model = initializers_size > _LARGE_MODEL_SIZE_THRESHOLD
-
-        export_kwargs: dict[str, Any] = dict(
-            initializers=self.initializers if include_initializers else {},
+        (
+            proto,
+            _,
+            _,
+            _,
+        ) = self._torch_graph._export_onnx(  # type: ignore[attr-defined] # pylint: disable=protected-access
+            initializers={},
             onnx_opset_version=opset_version,
             dynamic_axes={},
             defer_weight_export=False,
@@ -717,42 +726,8 @@
             add_node_names=True,
             node_attr_to_name={},
         )

-        # We decided to cache the model to disk when the model is large.
-        # Alternatively, we could build the ONNX `TensorProto`s in memory
-        # and append them to the model proto.
-        # We did not do it because it is harder to get right (vs. PyTorch's battle-tested
-        # implementation) and creating the `TensorProto`s naively (by converting to numpy)
-        # is slow.
-        cache_model_to_disk = include_initializers and large_model
-
-        if cache_model_to_disk:
-            with tempfile.TemporaryDirectory() as temp_dir:
-                onnx_file_path = os.path.join(temp_dir, "exported_model.onnx")
-                export_kwargs["onnx_file_path"] = onnx_file_path
-                (
-                    proto,
-                    _,
-                    _,
-                    _,
-                ) = self._torch_graph._export_onnx(  # type: ignore[attr-defined] # pylint: disable=protected-access
-                    **export_kwargs
-                )
-                onnx_model = onnx.load_from_string(proto)
-                onnx.load_external_data_for_model(onnx_model, temp_dir)
-        else:
-            (
-                proto,
-                _,
-                _,
-                _,
-            ) = self._torch_graph._export_onnx(  # type: ignore[attr-defined] # pylint: disable=protected-access
-                **export_kwargs
-            )
-            onnx_model = onnx.load_from_string(proto)
-
+        onnx_model = onnx.load_from_string(proto)
         onnx_model.functions.extend(function_proto_dict.values())

         # `_export_onnx` only exports opset_imports that is visible to it. It does not
         # export opset_imports for nested functions, since it does not have access to
         # them. We manually add them back and merge with existing opset_imports in the
@@ -766,21 +741,23 @@
                 for domain, version in unique_custom_domains.items()
             ]
         )
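The comment above (truncated by the collapsed region) describes merging custom-domain opset imports into the model by hand. As a hedged sketch of what such a merge amounts to, with an assumed helper name and domain map (only `onnx.helper.make_opsetid` comes from the onnx API):

    # Assumed-name sketch of merging custom-domain opset imports; the real
    # code gathers unique_custom_domains while walking its function protos.
    import onnx
    import onnx.helper

    def merge_custom_domains(model: onnx.ModelProto, domains: dict[str, int]) -> None:
        existing = {opset.domain for opset in model.opset_import}
        model.opset_import.extend(
            onnx.helper.make_opsetid(domain, version)
            for domain, version in domains.items()
            if domain not in existing  # leave pre-existing entries untouched
        )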

         try:
-            if not cache_model_to_disk:
-                # Only check the model if it is in memory.
-                # Otherwise the checker and shape_inference will fail because
-                # we cannot serialize the model.
-                onnx_model = onnx.shape_inference.infer_shapes(
-                    onnx_model, check_type=True, strict_mode=False, data_prop=True
-                )
-                onnx.checker.check_model(onnx_model, full_check=True)
+            # Fill in the shape information before adding initializers,
+            # because the initializers may be too large (>2GB) for the model
+            # to fit in memory.
+            onnx_model = onnx.shape_inference.infer_shapes(
+                onnx_model, check_type=True, strict_mode=False, data_prop=True
+            )
+            onnx.checker.check_model(onnx_model, full_check=True)
         except (onnx.checker.ValidationError, onnx.shape_inference.InferenceError) as e:
             warnings.warn(f"ONNX model is invalid: {e}", stacklevel=1)
             logging.debug(
                 "ONNX model:\n%s\n\nTorchScript graph:\n%s",
                 onnxscript.proto2text(onnx_model),
                 self.torch_graph,
             )

+        if include_initializers:
+            _add_initializers(onnx_model, self.initializers)
+
         return onnx_model
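Taken together, the new flow validates the lightweight model first and attaches weights last. A condensed sketch of that ordering, assuming the `_add_initializers` helper added earlier in this diff (the `finalize` wrapper itself is hypothetical):

    # Hypothetical wrapper showing the validate-then-attach ordering of this PR.
    import onnx
    import torch

    def finalize(model: onnx.ModelProto, weights: dict[str, torch.Tensor]) -> onnx.ModelProto:
        # The checker and shape inference serialize the proto internally, so run
        # them while the model is still under the 2GB protobuf limit...
        model = onnx.shape_inference.infer_shapes(
            model, check_type=True, strict_mode=False, data_prop=True
        )
        onnx.checker.check_model(model, full_check=True)
        # ...then append the possibly multi-gigabyte initializers.
        _add_initializers(model, weights)  # helper defined in this diff
        return model

This ordering is what lets the PR drop the disk-caching path entirely: the proto only exceeds the serialization limit after validation is done.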