diff --git a/onnxscript/function_libs/torch_lib/graph_building.py b/onnxscript/function_libs/torch_lib/graph_building.py
index 7d0fa59cb2..f2317a8c4c 100644
--- a/onnxscript/function_libs/torch_lib/graph_building.py
+++ b/onnxscript/function_libs/torch_lib/graph_building.py
@@ -1,9 +1,8 @@
 """Graph building functions for torchscript graph backend."""
 from __future__ import annotations
 
+import ctypes
 import logging
-import os
-import tempfile
 import typing
 import warnings
 from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union
@@ -65,9 +64,6 @@
     None,
 ]
 
 
-# Be sure to leave ample room for the rest of the proto fields.
-_LARGE_MODEL_SIZE_THRESHOLD = int(2**30 * 1.8)  # 1.8GB
-
 
 # TODO(justinchuby): Build a context manager to handle source information.
@@ -378,12 +374,38 @@ def eval_function(  # type: ignore[override]
         return self._graph.add_function_call(function, inputs, attributes)
 
 
+def _add_initializers(
+    model_proto: onnx.ModelProto, initializers: Mapping[str, torch.Tensor]
+) -> None:
+    """Add initializers to the model proto."""
+    tensor_protos = []
+
+    for name, tensor in initializers.items():
+        tensor = tensor.detach().cpu().contiguous()
+        # Take the raw data directly from the tensor to avoid the overhead of
+        # data manipulation in onnx.helper.make_tensor
+        raw_data = bytes(
+            (ctypes.c_ubyte * tensor.element_size() * tensor.numel()).from_address(
+                tensor.data_ptr()
+            )
+        )
+        tensor_proto = onnx.helper.make_tensor(
+            name=name,
+            data_type=_type_utils.JitScalarType.from_dtype(tensor.dtype).onnx_type(),
+            dims=tensor.shape,
+            vals=raw_data,
+            raw=True,
+        )
+        tensor_protos.append(tensor_proto)
+    model_proto.graph.initializer.extend(tensor_protos)
+
+
 @runtime_typing.checked
 def _add_attribute_to_torchscript_node(
     node: torch.Node,
     key: str,
     value: Union[float, int, str, bytes, Sequence[float], Sequence[int], torch.Tensor],
-):
+) -> torch.Node:
     """Initializes the right attribute based on type of value."""
     if isinstance(value, float):
         return node.f_(key, value)
@@ -443,18 +465,6 @@ def _create_op_call_in_torch_graph(
     return node_ouputs
 
 
-def _tensor_rawdata_size(tensor: torch.Tensor) -> int:
-    """Estimate the size of a tensor in bytes.
-
-    Args:
-        tensor: The tensor to estimate the size of.
-
-    Returns:
-        The estimated size of the tensor in bytes.
-    """
-    return tensor.numel() * tensor.element_size()
-
-
 def _shared_functions() -> list[onnx.FunctionProto]:
     """Hack to always include the share ops."""
@@ -935,16 +945,13 @@ def to_model_proto(
             # TODO(BowenBao): All local function domain versions are hardcoded as 1.
            unique_custom_domains[function_proto.domain] = 1
 
-        initializers_size = sum(
-            _tensor_rawdata_size(tensor) for tensor in self.initializers.values()
-        )
-
-        large_model = initializers_size > _LARGE_MODEL_SIZE_THRESHOLD
-
-        export_kwargs: dict[str, Any] = dict(
-            initializers=self.initializers
-            if include_initializers and not _flags.EXPERIMENTAL_INITIALIZERS_AS_INPUTS
-            else {},
+        (
+            proto,
+            _,
+            _,
+            _,
+        ) = self._torch_graph._export_onnx(  # type: ignore[attr-defined] # pylint: disable=protected-access
+            initializers={},
             onnx_opset_version=opset_version,
             dynamic_axes={},
             defer_weight_export=False,
@@ -955,46 +962,12 @@
             add_node_names=True,
             node_attr_to_name={},
         )
-
-        # We decided to cache the model to disk when the model is large.
-        # Alternatively, we could build the ONNX `TensorProto`s in memory
-        # and append them to the model proto.
-        # We did not do it because it is harder to get right (vs. PyTorch's battle-tested
-        # implementation) and creating the `TensorProto`s naively (by converting to numpy)
-        # is slow.
-        cache_model_to_disk = large_model and include_initializers
-
-        if cache_model_to_disk:
-            with tempfile.TemporaryDirectory() as temp_dir:
-                onnx_file_path = os.path.join(temp_dir, "exported_model.onnx")
-                export_kwargs["onnx_file_path"] = onnx_file_path
-                (
-                    proto,
-                    _,
-                    _,
-                    _,
-                ) = self._torch_graph._export_onnx(  # type: ignore[attr-defined] # pylint: disable=protected-access
-                    **export_kwargs
-                )
-                onnx_model = onnx.load_from_string(proto)
-                onnx.load_external_data_for_model(onnx_model, temp_dir)
-        else:
-            (
-                proto,
-                _,
-                _,
-                _,
-            ) = self._torch_graph._export_onnx(  # type: ignore[attr-defined] # pylint: disable=protected-access
-                **export_kwargs
-            )
-            onnx_model = onnx.load_from_string(proto)
-
+        onnx_model = onnx.load_from_string(proto)
         onnx_model.functions.extend(function_proto_dict.values())
         onnx_model.functions.extend(_shared_functions())
 
         # Override value_infos with symbolic shapes.
         onnx_model = self._override_with_symbolic_value_info_proto(onnx_model)
-
         # `_export_onnx` only exports opset_imports that is visible to it. It does not
         # export opset_imports for nested functions, since it does not have access to
         # them. We manually add them back and merge with existing opset_imports in the
@@ -1015,16 +988,14 @@
                     common_ops.common_opset.domain, common_ops.common_opset.version
                 )
             )
-
         try:
-            if not cache_model_to_disk:
-                # Only check the model if it is in memory.
-                # Otherwise the checker and shape_inference will fail because
-                # we cannot serialize the model.
-                onnx_model = onnx.shape_inference.infer_shapes(
-                    onnx_model, check_type=True, strict_mode=False, data_prop=True
-                )
-                onnx.checker.check_model(onnx_model, full_check=True)
+            # Fill in the shape information before adding initializers,
+            # because the initializers may be too large (>2gb) for the model
+            # to fit in the memory.
+            onnx_model = onnx.shape_inference.infer_shapes(
+                onnx_model, check_type=True, strict_mode=False, data_prop=True
+            )
+            onnx.checker.check_model(onnx_model, full_check=True)
         except (onnx.checker.ValidationError, onnx.shape_inference.InferenceError) as e:
             warnings.warn(f"ONNX model is invalid: {e}", stacklevel=1)
             logging.debug(
@@ -1032,4 +1003,8 @@
                 onnx.printer.to_text(onnx_model),
                 self.torch_graph,
            )
+
+        if include_initializers:
+            _add_initializers(onnx_model, self.initializers)
+
         return onnx_model
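
Note (illustration, not part of the patch): the snippet below is a minimal standalone sketch of the raw-data path that _add_initializers uses. It assumes a float32 tensor so the ONNX data type can be hard-coded instead of being mapped through _type_utils.JitScalarType as the patch does, and the name tensor_to_tensor_proto is invented for this example.

import ctypes

import onnx
import onnx.helper
import torch


def tensor_to_tensor_proto(name: str, tensor: torch.Tensor) -> onnx.TensorProto:
    """Build a TensorProto from a torch.Tensor by copying its raw buffer once."""
    tensor = tensor.detach().cpu().contiguous()
    # Same ctypes construction as in the patch: read element_size() * numel()
    # bytes straight from the tensor's data pointer.
    raw_data = bytes(
        (ctypes.c_ubyte * tensor.element_size() * tensor.numel()).from_address(
            tensor.data_ptr()
        )
    )
    return onnx.helper.make_tensor(
        name=name,
        data_type=onnx.TensorProto.FLOAT,  # assumption: float32 input tensor
        dims=tensor.shape,
        vals=raw_data,
        raw=True,
    )


proto = tensor_to_tensor_proto("weight", torch.ones(2, 3))
print(list(proto.dims), len(proto.raw_data))  # [2, 3] 24

Per the comments removed and added in the patch, this path exists to avoid the naive route the old code describes as slow: converting every initializer through numpy before building the TensorProto.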