119 changes: 47 additions & 72 deletions onnxscript/function_libs/torch_lib/graph_building.py
@@ -1,9 +1,8 @@
"""Graph building functions for torchscript graph backend."""
from __future__ import annotations

import ctypes
import logging
import os
import tempfile
import typing
import warnings
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union
@@ -65,9 +64,6 @@
None,
]

# Be sure to leave ample room for the rest of the proto fields.
_LARGE_MODEL_SIZE_THRESHOLD = int(2**30 * 1.8) # 1.8GB

# TODO(justinchuby): Build a context manager to handle source information.


@@ -378,12 +374,38 @@ def eval_function( # type: ignore[override]
return self._graph.add_function_call(function, inputs, attributes)


def _add_initializers(
Review comment from @BowenBao (Contributor), Dec 8, 2023:
fyi @kunal-vaishnavi creating onnx tensorproto from torch tensor dataptr

model_proto: onnx.ModelProto, initializers: Mapping[str, torch.Tensor]
) -> None:
"""Add initializers to the model proto."""
tensor_protos = []

for name, tensor in initializers.items():
tensor = tensor.detach().cpu().contiguous()
# Take the raw data directly from the tensor to avoid the overhead of
# data manipulation in onnx.helper.make_tensor
raw_data = bytes(
(ctypes.c_ubyte * tensor.element_size() * tensor.numel()).from_address(
tensor.data_ptr()
)
)
tensor_proto = onnx.helper.make_tensor(
name=name,
data_type=_type_utils.JitScalarType.from_dtype(tensor.dtype).onnx_type(),

Check failure (Code scanning / lintrunner): PYLINT/E0602, MYPY/name-defined, RUFF/F821: undefined name `_type_utils`; a possible fix is sketched after this function.
dims=tensor.shape,
vals=raw_data,
raw=True,
)
tensor_protos.append(tensor_proto)
model_proto.graph.initializer.extend(tensor_protos)

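The lintrunner failures flagged above point out that `_type_utils` is not defined anywhere in this hunk. A minimal sketch of one possible resolution, assuming the intent is PyTorch's internal `torch.onnx._type_utils` module (the actual import is not visible in this diff):

    import torch
    from torch.onnx import _type_utils  # internal torch module; its API may change between releases

    # JitScalarType.from_dtype maps a torch dtype to the corresponding ONNX
    # element type, the value passed as make_tensor's data_type above.
    onnx_dtype = _type_utils.JitScalarType.from_dtype(torch.float32).onnx_type()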

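For contrast with the raw `data_ptr()` route taken by `_add_initializers`, here is a sketch of the conventional path through numpy via `onnx.numpy_helper.from_array`; the removed comment further down notes that this kind of conversion is slow for large weights, which is what the ctypes approach avoids. The tensor and name below are made up for illustration:

    import torch
    from onnx import numpy_helper

    weight = torch.arange(6, dtype=torch.float32).reshape(2, 3)
    # from_array copies the data through a numpy array before building the TensorProto.
    tensor_proto = numpy_helper.from_array(weight.detach().cpu().numpy(), name="weight")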
@runtime_typing.checked
def _add_attribute_to_torchscript_node(
node: torch.Node,
key: str,
value: Union[float, int, str, bytes, Sequence[float], Sequence[int], torch.Tensor],
):
) -> torch.Node:
"""Initializes the right attribute based on type of value."""
if isinstance(value, float):
return node.f_(key, value)
@@ -443,18 +465,6 @@ def _create_op_call_in_torch_graph(
return node_ouputs


def _tensor_rawdata_size(tensor: torch.Tensor) -> int:
"""Estimate the size of a tensor in bytes.

Args:
tensor: The tensor to estimate the size of.

Returns:
The estimated size of the tensor in bytes.
"""
return tensor.numel() * tensor.element_size()


def _shared_functions() -> list[onnx.FunctionProto]:
"""Hack to always include the share ops."""

@@ -935,16 +945,13 @@ def to_model_proto(
# TODO(BowenBao): All local function domain versions are hardcoded as 1.
unique_custom_domains[function_proto.domain] = 1

initializers_size = sum(
_tensor_rawdata_size(tensor) for tensor in self.initializers.values()
)

large_model = initializers_size > _LARGE_MODEL_SIZE_THRESHOLD

export_kwargs: dict[str, Any] = dict(
initializers=self.initializers
if include_initializers and not _flags.EXPERIMENTAL_INITIALIZERS_AS_INPUTS
else {},
(
proto,
_,
_,
_,
) = self._torch_graph._export_onnx( # type: ignore[attr-defined] # pylint: disable=protected-access
initializers={},
onnx_opset_version=opset_version,
dynamic_axes={},
defer_weight_export=False,
@@ -955,46 +962,12 @@
add_node_names=True,
node_attr_to_name={},
)

# We decided to cache the model to disk when the model is large.
# Alternatively, we could build the ONNX `TensorProto`s in memory
# and append them to the model proto.
# We did not do it because it is harder to get right (vs. PyTorch's battle-tested
# implementation) and creating the `TensorProto`s naively (by converting to numpy)
# is slow.
cache_model_to_disk = large_model and include_initializers

if cache_model_to_disk:
with tempfile.TemporaryDirectory() as temp_dir:
onnx_file_path = os.path.join(temp_dir, "exported_model.onnx")
export_kwargs["onnx_file_path"] = onnx_file_path
(
proto,
_,
_,
_,
) = self._torch_graph._export_onnx( # type: ignore[attr-defined] # pylint: disable=protected-access
**export_kwargs
)
onnx_model = onnx.load_from_string(proto)
onnx.load_external_data_for_model(onnx_model, temp_dir)
else:
(
proto,
_,
_,
_,
) = self._torch_graph._export_onnx( # type: ignore[attr-defined] # pylint: disable=protected-access
**export_kwargs
)
onnx_model = onnx.load_from_string(proto)

onnx_model = onnx.load_from_string(proto)
onnx_model.functions.extend(function_proto_dict.values())
onnx_model.functions.extend(_shared_functions())

# Override value_infos with symbolic shapes.
onnx_model = self._override_with_symbolic_value_info_proto(onnx_model)

# `_export_onnx` only exports opset_imports that is visible to it. It does not
# export opset_imports for nested functions, since it does not have access to
# them. We manually add them back and merge with existing opset_imports in the
Expand All @@ -1015,21 +988,23 @@ def to_model_proto(
common_ops.common_opset.domain, common_ops.common_opset.version
)
)

try:
if not cache_model_to_disk:
# Only check the model if it is in memory.
# Otherwise the checker and shape_inference will fail because
# we cannot serialize the model.
onnx_model = onnx.shape_inference.infer_shapes(
onnx_model, check_type=True, strict_mode=False, data_prop=True
)
onnx.checker.check_model(onnx_model, full_check=True)
# Fill in the shape information before adding initializers,
# because the initializers may be too large (>2GB) for the model
# to fit in memory.
onnx_model = onnx.shape_inference.infer_shapes(
onnx_model, check_type=True, strict_mode=False, data_prop=True
)
onnx.checker.check_model(onnx_model, full_check=True)
except (onnx.checker.ValidationError, onnx.shape_inference.InferenceError) as e:
warnings.warn(f"ONNX model is invalid: {e}", stacklevel=1)
logging.debug(
"ONNX model:\n%s\n\nTorchScript graph:\n%s",
onnx.printer.to_text(onnx_model),
self.torch_graph,
)

if include_initializers:
_add_initializers(onnx_model, self.initializers)

return onnx_model
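A hypothetical usage sketch of the new flow (the method name, arguments, and `graph` variable are assumed from this diff rather than verified against the full class): shape inference and the checker now run on an initializer-free model, and the weights are appended afterwards, so a model whose initializers push it past the 2GB protobuf limit never needs to be serialized just for validation:

    import onnx

    # `graph` is assumed to be the exporter's TorchScriptGraph instance.
    model_proto = graph.to_model_proto(opset_version=18, include_initializers=True)
    # Saving with external data keeps the on-disk model protobuf under the 2GB limit.
    onnx.save_model(model_proto, "exported_model.onnx", save_as_external_data=True)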