From 8d0dcebdd67a596dd49f3023cbd57b554c63d054 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 21 May 2025 15:07:50 -0700 Subject: [PATCH 01/16] Import from ir --- onnxscript/ir/README.md | 23 +- onnxscript/ir/__init__.py | 70 +- onnxscript/ir/_core.py | 3445 ----------------- onnxscript/ir/_core_test.py | 1732 --------- onnxscript/ir/_display.py | 49 - onnxscript/ir/_display_test.py | 22 - onnxscript/ir/_enums.py | 242 -- onnxscript/ir/_enums_test.py | 179 - onnxscript/ir/_graph_comparison.py | 23 - onnxscript/ir/_graph_containers.py | 263 -- onnxscript/ir/_io.py | 97 - onnxscript/ir/_io_test.py | 144 - onnxscript/ir/_linked_list.py | 283 -- onnxscript/ir/_linked_list_test.py | 387 -- onnxscript/ir/_metadata.py | 44 - onnxscript/ir/_name_authority.py | 72 - onnxscript/ir/_name_authority_test.py | 24 - onnxscript/ir/_polyfill.py | 25 - onnxscript/ir/_protocols.py | 610 --- onnxscript/ir/_tape.py | 213 - onnxscript/ir/_tape_test.py | 76 - onnxscript/ir/_type_casting.py | 106 - onnxscript/ir/_type_casting_test.py | 50 - onnxscript/ir/convenience.py | 34 - onnxscript/ir/external_data.py | 396 -- onnxscript/ir/external_data_test.py | 502 --- onnxscript/ir/passes/__init__.py | 12 +- onnxscript/ir/passes/_pass_infra.py | 289 -- onnxscript/ir/passes/_pass_infra_test.py | 39 - onnxscript/ir/passes/common/__init__.py | 16 +- onnxscript/ir/passes/common/_c_api_utils.py | 77 - .../common/clear_metadata_and_docstring.py | 60 - .../clear_metadata_and_docstring_test.py | 107 - .../ir/passes/common/constant_manipulation.py | 215 - .../common/constant_manipulation_test.py | 516 --- onnxscript/ir/passes/common/inliner.py | 331 -- onnxscript/ir/passes/common/inliner_test.py | 205 - onnxscript/ir/passes/common/onnx_checker.py | 57 - .../ir/passes/common/onnx_checker_test.py | 79 - .../ir/passes/common/shape_inference.py | 112 - .../ir/passes/common/shape_inference_test.py | 137 - .../ir/passes/common/topological_sort.py | 33 - .../ir/passes/common/topological_sort_test.py | 85 - onnxscript/ir/passes/common/unused_removal.py | 196 - .../ir/passes/common/unused_removal_test.py | 257 -- onnxscript/ir/serde.py | 1725 --------- onnxscript/ir/serde_test.py | 417 -- onnxscript/ir/tape.py | 15 - onnxscript/ir/tensor_adapters.py | 122 - onnxscript/ir/tensor_adapters_test.py | 85 - onnxscript/ir/traversal.py | 82 - onnxscript/ir/traversal_test.py | 81 - 52 files changed, 42 insertions(+), 14419 deletions(-) delete mode 100644 onnxscript/ir/_core.py delete mode 100644 onnxscript/ir/_core_test.py delete mode 100644 onnxscript/ir/_display.py delete mode 100644 onnxscript/ir/_display_test.py delete mode 100644 onnxscript/ir/_enums.py delete mode 100644 onnxscript/ir/_enums_test.py delete mode 100644 onnxscript/ir/_graph_comparison.py delete mode 100644 onnxscript/ir/_graph_containers.py delete mode 100644 onnxscript/ir/_io.py delete mode 100644 onnxscript/ir/_io_test.py delete mode 100644 onnxscript/ir/_linked_list.py delete mode 100644 onnxscript/ir/_linked_list_test.py delete mode 100644 onnxscript/ir/_metadata.py delete mode 100644 onnxscript/ir/_name_authority.py delete mode 100644 onnxscript/ir/_name_authority_test.py delete mode 100644 onnxscript/ir/_polyfill.py delete mode 100644 onnxscript/ir/_protocols.py delete mode 100644 onnxscript/ir/_tape.py delete mode 100644 onnxscript/ir/_tape_test.py delete mode 100644 onnxscript/ir/_type_casting.py delete mode 100644 onnxscript/ir/_type_casting_test.py delete mode 100644 onnxscript/ir/convenience.py delete mode 100644 onnxscript/ir/external_data.py delete mode 100644 onnxscript/ir/external_data_test.py delete mode 100644 onnxscript/ir/passes/_pass_infra.py delete mode 100644 onnxscript/ir/passes/_pass_infra_test.py delete mode 100644 onnxscript/ir/passes/common/_c_api_utils.py delete mode 100644 onnxscript/ir/passes/common/clear_metadata_and_docstring.py delete mode 100644 onnxscript/ir/passes/common/clear_metadata_and_docstring_test.py delete mode 100644 onnxscript/ir/passes/common/constant_manipulation.py delete mode 100644 onnxscript/ir/passes/common/constant_manipulation_test.py delete mode 100644 onnxscript/ir/passes/common/inliner.py delete mode 100644 onnxscript/ir/passes/common/inliner_test.py delete mode 100644 onnxscript/ir/passes/common/onnx_checker.py delete mode 100644 onnxscript/ir/passes/common/onnx_checker_test.py delete mode 100644 onnxscript/ir/passes/common/shape_inference.py delete mode 100644 onnxscript/ir/passes/common/shape_inference_test.py delete mode 100644 onnxscript/ir/passes/common/topological_sort.py delete mode 100644 onnxscript/ir/passes/common/topological_sort_test.py delete mode 100644 onnxscript/ir/passes/common/unused_removal.py delete mode 100644 onnxscript/ir/passes/common/unused_removal_test.py delete mode 100644 onnxscript/ir/serde.py delete mode 100644 onnxscript/ir/serde_test.py delete mode 100644 onnxscript/ir/tape.py delete mode 100644 onnxscript/ir/tensor_adapters.py delete mode 100644 onnxscript/ir/tensor_adapters_test.py delete mode 100644 onnxscript/ir/traversal.py delete mode 100644 onnxscript/ir/traversal_test.py diff --git a/onnxscript/ir/README.md b/onnxscript/ir/README.md index dae5c09a5b..21d5cd124d 100644 --- a/onnxscript/ir/README.md +++ b/onnxscript/ir/README.md @@ -1,22 +1,3 @@ -# ONNX IR +# Where is the code? -An in-memory IR that supports the full ONNX spec, designed for graph construction, analysis and transformation. - -## Features ✨ - -- Full ONNX spec support: all valid models representable by ONNX protobuf, and a subset of invalid models (so you can load and fix them). -- Low memory footprint: mmap'ed external tensors; unified interface for ONNX TensorProto, Numpy arrays and PyTorch Tensors etc. No tensor size limitation. Zero copies. -- Straightforward access patterns: Access value information and traverse the graph topology at ease. -- Robust mutation: Create as many iterators as you like on the graph while mutating it. -- Speed: Performant graph manipulation, serialization/deserialization to Protobuf. -- Pythonic and familiar APIs: Classes define Pythonic apis and still map to ONNX protobuf concepts in an intuitive way. -- No protobuf dependency: The IR does not require protobuf once the model is converted to the IR representation, decoupling from the serialization format. - -## Code Organization 🗺️ - -- [`_protocols.py`](_protocols.py): Interfaces defined for all entities in the IR. -- [`_core.py`](_core.py): Implementation of the core entities in the IR, including `Model`, `Graph`, `Node`, `Value`, and others. -- [`_enums.py`](_enums.py): Definition of the type enums that correspond to the `DataType` and `AttributeType` in `onnx.proto`. -- [`_name_authority.py`](_name_authority.py): The authority for giving names to entities in the graph, used internally. -- [`_linked_list.py`](_linked_list.py): The data structure as the node container in the graph that supports robust iteration and mutation. Internal. -- [`_metadata.py`](_metadata.py): Metadata store for all entities in the IR. +The ONNX IR has migrated to https://github.com/onnx/ir-py as a standalone project. The original onnxscript APIs are aliased here for compatibility. diff --git a/onnxscript/ir/__init__.py b/onnxscript/ir/__init__.py index b5daebe235..3fa204b405 100644 --- a/onnxscript/ir/__init__.py +++ b/onnxscript/ir/__init__.py @@ -83,14 +83,15 @@ "save", ] -from onnxscript.ir import convenience, external_data, passes, serde, tape, traversal -from onnxscript.ir._convenience._constructors import node, tensor -from onnxscript.ir._core import ( +from onnx_ir import ( + ArrayCompatible, Attr, AttrFloat32, AttrFloat32s, AttrGraph, AttrGraphs, + AttributeProtocol, + AttributeType, AttrInt64, AttrInt64s, AttrSparseTensor, @@ -101,58 +102,53 @@ AttrTensors, AttrTypeProto, AttrTypeProtos, + DataType, + DLPackCompatible, ExternalTensor, Function, + FunctionProtocol, Graph, + GraphProtocol, GraphView, + GraphViewProtocol, Input, LazyTensor, + MapTypeProtocol, Model, + ModelProtocol, Node, + NodeProtocol, + OperatorIdentifier, OptionalType, RefAttr, + ReferenceAttributeProtocol, SequenceType, Shape, + ShapeProtocol, + SparseTensorProtocol, SparseTensorType, StringTensor, SymbolicDim, + SymbolicDimProtocol, Tensor, + TensorProtocol, + TensorProtoTensor, TensorType, TypeAndShape, - Value, -) -from onnxscript.ir._enums import ( - AttributeType, - DataType, -) -from onnxscript.ir._io import load, save -from onnxscript.ir._protocols import ( - ArrayCompatible, - AttributeProtocol, - DLPackCompatible, - FunctionProtocol, - GraphProtocol, - GraphViewProtocol, - MapTypeProtocol, - ModelProtocol, - NodeProtocol, - OperatorIdentifier, - ReferenceAttributeProtocol, - ShapeProtocol, - SparseTensorProtocol, - SymbolicDimProtocol, - TensorProtocol, TypeProtocol, + Value, ValueProtocol, + convenience, + external_data, + from_onnx_text, + from_proto, + load, + node, + passes, + save, + serde, + tape, + tensor, + to_proto, + traversal, ) -from onnxscript.ir.serde import TensorProtoTensor, from_onnx_text, from_proto, to_proto - - -def __set_module() -> None: - """Set the module of all functions in this module to this public module.""" - global_dict = globals() - for name in __all__: - global_dict[name].__module__ = __name__ - - -__set_module() diff --git a/onnxscript/ir/_core.py b/onnxscript/ir/_core.py deleted file mode 100644 index 68f851808c..0000000000 --- a/onnxscript/ir/_core.py +++ /dev/null @@ -1,3445 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -"""data structures for the intermediate representation.""" - -# NOTES for developers: -# NOTE: None of these classes will have a "to_onnx" or "from_protobuf" method because -# We cannot assume that the build tool chain has protoc installed and would like -# to keep this module protobuf free. This way we separate the concerns of the IR -# and the serialization/deserialization. -# -# NOTE: Do not import pathlib in the IR. It is slow. Use os.path methods instead. - -from __future__ import annotations - -import abc -import contextlib -import dataclasses -import heapq -import math -import mmap -import os -import sys -import textwrap -import typing -from collections.abc import Hashable -from typing import ( - AbstractSet, - Any, - Callable, - Collection, - Generic, - Iterable, - Iterator, - MutableMapping, - MutableSequence, - NamedTuple, - OrderedDict, - Sequence, - SupportsInt, - Union, -) - -import ml_dtypes -import numpy as np -from typing_extensions import TypeIs - -import onnxscript -from onnxscript.ir import ( - _display, - _enums, - _graph_containers, - _linked_list, - _metadata, - _name_authority, - _protocols, - _type_casting, -) - -if typing.TYPE_CHECKING: - import numpy.typing as npt - from typing_extensions import TypeGuard - -TArrayCompatible = typing.TypeVar( - "TArrayCompatible", - bound=Union[_protocols.ArrayCompatible, _protocols.DLPackCompatible], -) - -# System is little endian -_IS_LITTLE_ENDIAN = sys.byteorder == "little" -# Data types that are not supported by numpy -_NON_NUMPY_NATIVE_TYPES = frozenset( - ( - _enums.DataType.BFLOAT16, - _enums.DataType.FLOAT8E4M3FN, - _enums.DataType.FLOAT8E4M3FNUZ, - _enums.DataType.FLOAT8E5M2, - _enums.DataType.FLOAT8E5M2FNUZ, - _enums.DataType.INT4, - _enums.DataType.UINT4, - _enums.DataType.FLOAT4E2M1, - ) -) - - -def _compatible_with_numpy(obj: Any) -> TypeGuard[_protocols.ArrayCompatible]: - """Use this function to check if an object is compatible with numpy. - - Avoid isinstance checks with the ArrayCompatible protocol for performance reasons. - """ - return hasattr(obj, "__array__") - - -def _compatible_with_dlpack(obj: Any) -> TypeGuard[_protocols.DLPackCompatible]: - """Use this function to check if an object is compatible with DLPack. - - Avoid isinstance checks with the DLPackCompatible protocol for performance reasons. - """ - return hasattr(obj, "__dlpack__") - - -class TensorBase(abc.ABC, _protocols.TensorProtocol, _display.PrettyPrintable): - """Convenience Shared methods for classes implementing TensorProtocol.""" - - __slots__ = ( - "_doc_string", - "_metadata", - "_metadata_props", - "_name", - ) - - def __init__( - self, - name: str | None = None, - doc_string: str | None = None, - metadata_props: dict[str, str] | None = None, - ) -> None: - self._metadata: _metadata.MetadataStore | None = None - self._metadata_props: dict[str, str] | None = metadata_props - self._name: str | None = name - self._doc_string: str | None = doc_string - - def _printable_type_shape(self) -> str: - """Return a string representation of the shape and data type.""" - return f"{self.dtype},{self.shape}" - - def _repr_base(self) -> str: - """Base string for the repr method. - - Example: Tensor - """ - return f"{self.__class__.__name__}<{self._printable_type_shape()}>" - - @property - def name(self) -> str | None: - """The name of the tensor.""" - return self._name - - @name.setter - def name(self, value: str | None) -> None: - self._name = value - - @property - def doc_string(self) -> str | None: - """The documentation string.""" - return self._doc_string - - @doc_string.setter - def doc_string(self, value: str | None) -> None: - self._doc_string = value - - @property - def size(self) -> int: - """The number of elements in the tensor.""" - return math.prod(self.shape.numpy()) # type: ignore[attr-defined] - - @property - def nbytes(self) -> int: - """The number of bytes in the tensor.""" - # Use math.ceil because when dtype is INT4, the itemsize is 0.5 - return math.ceil(self.dtype.itemsize * self.size) - - @property - def metadata_props(self) -> dict[str, str]: - if self._metadata_props is None: - self._metadata_props = {} - return self._metadata_props - - @property - def meta(self) -> _metadata.MetadataStore: - """The metadata store for intermediate analysis. - - Write to the :attr:`metadata_props` if you would like the metadata to be serialized - to the ONNX proto. - """ - if self._metadata is None: - self._metadata = _metadata.MetadataStore() - return self._metadata - - def display(self, *, page: bool = False) -> None: - rich = _display.require_rich() - - if rich is None: - status_manager = contextlib.nullcontext() - else: - import rich.status # type: ignore[import-not-found, no-redef] # pylint: disable=import-outside-toplevel - - status_manager = rich.status.Status(f"Computing tensor stats for {self!r}") - - from onnxscript._thirdparty import ( # pylint: disable=import-outside-toplevel - asciichartpy, - ) - - with status_manager: - # Construct the text to display - lines = [] - array = self.numpy().flatten() - lines.append(repr(self)) - lines.append("") - nan_values = np.isnan(array) - nan_count = np.count_nonzero(nan_values) - inf_count = np.count_nonzero(np.isinf(array)) - numbers = array[~nan_values] - lines.append( - f"Min: {np.min(numbers)}, Max: {np.max(numbers)}, " - f"NaN count: {nan_count}, " - f"Inf count: {inf_count}" - ) - # Compute sparsity - sparse_threathold = 1e-6 - # NOTE: count_nonzero() is faster than sum() for boolean arrays - sparsity = np.count_nonzero(np.abs(array) < sparse_threathold) / array.size - lines.append(f"Sparsity (abs<{sparse_threathold}): {sparsity:.2f}") - - # Compute histogram - finite_numbers = array[np.isfinite(array)] - lines.append("Histogram:") - hist, bin_edges = np.histogram(finite_numbers, bins=80, density=False) - lines.append( - asciichartpy.plot( - hist, bin_edges=bin_edges, cfg={"height": 8, "format": "{:8.0f}"} - ) - ) - - text = "\n".join(lines) - - if rich is None: - print(text) - elif page: - import rich.console # type: ignore[import-not-found, no-redef] # pylint: disable=import-outside-toplevel - - console = rich.console.Console() - with console.pager(): - console.print(text) - else: - rich.print(text) - - -def _check_numpy_representation_type(array: np.ndarray, dtype: _enums.DataType) -> None: - """Check if the numpy array dtype matches the IR data type. - - When the dtype is not one of the numpy native dtypes, the value needs need to be: - - - ``int8`` or ``uint8`` for int4, with the sign bit extended to 8 bits. - - ``uint8`` for uint4 or float4. - - ``uint8`` for 8-bit data types. - - ``uint16`` for bfloat16 - - or corresponding dtypes from the ``ml_dtype`` package. - """ - if dtype in _NON_NUMPY_NATIVE_TYPES: - if dtype.itemsize == 2 and array.dtype not in (np.uint16, ml_dtypes.bfloat16): - raise TypeError( - f"The numpy array dtype must be uint16 or ml_dtypes.bfloat16 (not {array.dtype}) for IR data type {dtype}." - ) - if dtype.itemsize == 1 and array.dtype not in ( - np.uint8, - ml_dtypes.float8_e4m3fnuz, - ml_dtypes.float8_e4m3fn, - ml_dtypes.float8_e5m2fnuz, - ml_dtypes.float8_e5m2, - ): - raise TypeError( - f"The numpy array dtype must be uint8 or ml_dtypes.float8* (not {array.dtype}) for IR data type {dtype}." - ) - if dtype == _enums.DataType.INT4: - if array.dtype not in (np.int8, np.uint8, ml_dtypes.int4): - raise TypeError( - f"The numpy array dtype must be int8 or uint8 or ml_dtypes.int4 (not {array.dtype}) for IR data type {dtype}." - ) - if dtype == _enums.DataType.UINT4: - if array.dtype not in (np.uint8, ml_dtypes.uint4): - raise TypeError( - f"The numpy array dtype must be uint8 or or ml_dtypes.uint4 (not {array.dtype}) for IR data type {dtype}." - ) - if dtype == _enums.DataType.FLOAT4E2M1: - if array.dtype not in (np.uint8, ml_dtypes.float4_e2m1fn): - raise TypeError( - f"The numpy array dtype must be uint8 or ml_dtypes.float4_e2m1fn (not {array.dtype}) for IR data type {dtype}." - ) - return - - try: - dtype_numpy = _enums.DataType.from_numpy(array.dtype) - except TypeError as e: - raise TypeError( - "Failed to convert the numpy dtype to an IR data type. " - "If you are using a non-native dtype, be sure to specify the corresponding IR dtype when " - "creating a Tensor." - ) from e - - if dtype_numpy != dtype: - raise TypeError( - f"The numpy array dtype {array.dtype} does not match the IR data type {dtype}." - ) - - -def _maybe_view_np_array_with_ml_dtypes( - array: np.ndarray, dtype: _enums.DataType -) -> np.ndarray: - """Reinterpret the array when it is a bit representation of a dtype not supported by numpy. - - Args: - array: The numpy array to reinterpret. - dtype: The data type to reinterpret the array as. - - Returns: - The array reinterpreted as the dtype. - """ - if dtype == _enums.DataType.BFLOAT16: - return array.view(ml_dtypes.bfloat16) - if dtype == _enums.DataType.FLOAT8E4M3FN: - return array.view(ml_dtypes.float8_e4m3fn) - if dtype == _enums.DataType.FLOAT8E4M3FNUZ: - return array.view(ml_dtypes.float8_e4m3fnuz) - if dtype == _enums.DataType.FLOAT8E5M2: - return array.view(ml_dtypes.float8_e5m2) - if dtype == _enums.DataType.FLOAT8E5M2FNUZ: - return array.view(ml_dtypes.float8_e5m2fnuz) - if dtype == _enums.DataType.INT4: - return array.view(ml_dtypes.int4) - if dtype == _enums.DataType.UINT4: - return array.view(ml_dtypes.uint4) - if dtype == _enums.DataType.FLOAT4E2M1: - return array.view(ml_dtypes.float4_e2m1fn) - return array - - -class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]): # pylint: disable=too-many-ancestors - """An immutable concrete tensor. - - This class is a wrapper around the raw tensor data. The raw tensor data can be a numpy array - compatible object (e.g. ``np.ndarray``, ``torch.Tensor``) or a ``DLPack`` compatible object. - The tensor is immutable and the data is not copied at initialization. - - To create a tensor from a numpy array:: - - >>> import numpy as np - >>> array = np.array([1, 2, 3]) - >>> tensor = Tensor(array) - >>> # The tensor itself can be treated as a numpy array because it implements the __array__ method - >>> np.allclose(tensor, array) - True - - To get a numpy array from the tensor, call :meth:`numpy`. To convert the tensor - to a byte string for serialization, call :meth:`tobytes`. - - It is recommended to check the size of the tensor first before accessing the - underlying data, because accessing the data may be expensive and incur IO - overhead. - - Subclass this class to efficiently handle different types of tensors from different frameworks. - - Attributes: - name: The name of the tensor. - shape: The shape of the tensor. - dtype: The data type of the elements of the tensor. It is an :class:`ir.DataType` enum. - doc_string: Documentation string. - raw: The raw data behind this tensor. It can be anything. - size: The number of elements in the tensor. - nbytes: The number of bytes in the tensor. - metadata_props: Metadata that will be serialized to the ONNX file. - meta: Metadata store for graph transform passes. - """ - - __slots__ = ( - "_dtype", - "_raw", - "_shape", - ) - - def __init__( - self, - value: TArrayCompatible, - dtype: _enums.DataType | None = None, - *, - shape: Shape | None = None, - name: str | None = None, - doc_string: str | None = None, - metadata_props: dict[str, str] | None = None, - ) -> None: - """Initialize a tensor. - - Args: - value: The backing data of the tensor. It can be a numpy array compatible object or a DLPack compatible object. - When the dtype is not one of the numpy native dtypes, the value needs - to be ``uint8`` for 4-bit and 8-bit data types, and ``uint16`` for bfloat16 - when the value is a numpy array; ``dtype`` must be specified in this case. - dtype: The data type of the tensor. It can be None only when value is a numpy array. - Users are responsible for making sure the dtype matches the value when value is not a numpy array. - shape: The shape of the tensor. If None, the shape is obtained from the value. - name: The name of the tensor. - doc_string: The documentation string. - metadata_props: The metadata properties. - - Raises: - TypeError: If the value is not a numpy array compatible or a DLPack compatible object. - TypeError: If the value is a numpy array and the dtype is specified but does not match the dtype of the array. - ValueError: If the shape is not specified and the value does not have a shape attribute. - ValueError: If the dtype is not specified and the value is not a numpy array. - """ - super().__init__(name=name, doc_string=doc_string, metadata_props=metadata_props) - # NOTE: We should not do any copying here for performance reasons - if not _compatible_with_numpy(value) and not _compatible_with_dlpack(value): - raise TypeError(f"Expected an array compatible object, got {type(value)}") - if shape is None: - # Obtain the shape from the value - if not hasattr(value, "shape"): - raise ValueError( - f"Expected an object with a shape attribute, but {type(value)} does not have shape. " - "Please specify the shape explicitly." - ) - self._shape = Shape(getattr(value, "shape"), frozen=True) # noqa: B009 - else: - self._shape = shape - self._shape.freeze() - if dtype is None: - if isinstance(value, np.ndarray): - self._dtype = _enums.DataType.from_numpy(value.dtype) - else: - raise ValueError( - "The dtype must be specified when the value is not a numpy array." - ) - else: - if isinstance(value, np.ndarray): - # Make sure the dtype matches the value - _check_numpy_representation_type(value, dtype) - # Users are responsible for making sure the dtype matches the value - # when value is not a numpy array - self._dtype = dtype - - # View the bfloat16, float8 and int4 types using ml_dtypes - if isinstance(value, np.ndarray): - value = _maybe_view_np_array_with_ml_dtypes(value, self._dtype) # type: ignore[assignment] - - self._raw = value - - def __array__(self, dtype: Any = None) -> np.ndarray: - if isinstance(self._raw, np.ndarray) or _compatible_with_numpy(self._raw): - return self._raw.__array__(dtype) - assert _compatible_with_dlpack(self._raw), ( - f"Bug: Expected DLPack or Numpy compatible objects, got {type(self._raw)}" - ) - return np.from_dlpack(self._raw) - - def __dlpack__(self, *, stream: Any = None) -> Any: - if _compatible_with_dlpack(self._raw): - return self._raw.__dlpack__(stream=stream) - return self.__array__().__dlpack__(stream=stream) - - def __dlpack_device__(self) -> tuple[int, int]: - if _compatible_with_dlpack(self._raw): - return self._raw.__dlpack_device__() - return self.__array__().__dlpack_device__() - - def __repr__(self) -> str: - # Avoid multi-line repr - tensor_lines = repr(self._raw).split("\n") - tensor_text = " ".join(line.strip() for line in tensor_lines) - return f"{self._repr_base()}({tensor_text}, name={self.name!r})" - - @property - def dtype(self) -> _enums.DataType: - """The data type of the tensor. Immutable.""" - return self._dtype - - @property - def shape(self) -> Shape: - """The shape of the tensor. Immutable.""" - return self._shape - - @property - def raw(self) -> TArrayCompatible: - """Backing data of the tensor. Immutable.""" - return self._raw # type: ignore[return-value] - - def numpy(self) -> np.ndarray: - """Return the tensor as a numpy array. - - When the data type is not supported by numpy, the dtypes from the ``ml_dtype`` - package are used. The values can be reinterpreted as bit representations - using the ``.view()`` method. - """ - if isinstance(self._raw, np.ndarray): - return self._raw - # We do not cache the value to save memory - return self.__array__() - - def tobytes(self) -> bytes: - """Returns the value as bytes encoded in little endian. - - Override this method for more efficient serialization when the raw - value is not a numpy array. - """ - # TODO(justinchuby): Support DLPack - array = self.numpy() - if self.dtype in { - _enums.DataType.INT4, - _enums.DataType.UINT4, - _enums.DataType.FLOAT4E2M1, - }: - # Pack the array into int4 - array = _type_casting.pack_int4(array) - else: - assert self.dtype.itemsize == array.itemsize, "Bug: The itemsize should match" - if not _IS_LITTLE_ENDIAN: - array = array.view(array.dtype.newbyteorder("<")) - return array.tobytes() - - -class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=too-many-ancestors - """An immutable concrete tensor with its data store on disk. - - This class uses memory mapping to avoid loading the tensor into memory, - when the data type is supported by numpy. Otherwise, the tensor is loaded - into memory lazily when accessed. - - Calling :attr:`shape` does not incur IO. Checking shape before loading - the tensor is recommended if IO overhead and memory usage is a concern. - - To obtain an array, call :meth:`numpy`. To obtain the bytes, - call :meth:`tobytes`. - - The :attr:`location` must be a relative path conforming to the ONNX - specification. Given the correct :attr:`base_dir`, the :attr:`path` is computed - to be the full path to the data file. Users should expect that the :attr:`path` - always leads to the correct file. At initialization, paths are not checked. - It is the user's responsibility to ensure the paths are valid and accessible. - - Attributes: - location: The location of the data file. It is the path relative to the base directory. - base_dir: The base directory for the external data. It is used to resolve relative paths. - At serialization, only the :attr:`location` is serialized into the "location" field of the ``TensorProto``. - path: The path to the data file. This is computed by joining :attr:`base_dir` and :attr:`location`. - offset: The offset in bytes from the start of the file. - length: The length of the data in bytes. - dtype: The data type of the tensor. - shape: The shape of the tensor. - name: The name of the tensor. It must be specified. - doc_string: The documentation string. - metadata_props: The metadata properties. - """ - - __slots__ = ( - "_array", - "_base_dir", - "_dtype", - "_length", - "_location", - "_offset", - "_shape", - "_valid", - "raw", - ) - - def __init__( - self, - location: os.PathLike | str, - offset: int | None, - length: int | None, - dtype: _enums.DataType, - *, - shape: Shape, - name: str, - doc_string: str | None = None, - metadata_props: dict[str, str] | None = None, - base_dir: os.PathLike | str = "", - ) -> None: - """Initialize an external tensor. - - Args: - location: The location of the data file. It is the path relative to the base directory. - offset: The offset in bytes from the start of the file. - length: The length of the data in bytes. - dtype: The data type of the tensor. - shape: The shape of the tensor. - name: The name of the tensor.. - doc_string: The documentation string. - metadata_props: The metadata properties. - base_dir: The base directory for the external data. It is used to resolve relative paths. - """ - super().__init__(name=name, doc_string=doc_string, metadata_props=metadata_props) - # NOTE: Do not verify the location by default. This is because the location field - # in the tensor proto can be anything and we would like deserialization from - # proto to IR to not fail. - if onnxscript.DEBUG: - if os.path.isabs(location): - raise ValueError( - "The location must be a relative path. Please specify base_dir as well." - ) - self._location = location - self._base_dir = base_dir - self._offset: int | None = offset - self._length: int | None = length - self._dtype: _enums.DataType = dtype - self.name: str = name # mutable - self._shape: Shape = shape - self._shape.freeze() - self.doc_string: str | None = doc_string # mutable - self._array: np.ndarray | None = None - self.raw: mmap.mmap | None = None - self._metadata_props = metadata_props - self._metadata: _metadata.MetadataStore | None = None - self._valid = True - - @property - def base_dir(self) -> str | os.PathLike: - # Mutable - return self._base_dir - - @base_dir.setter - def base_dir(self, value: str | os.PathLike) -> None: - self._base_dir = value - - @property - def location(self) -> str | os.PathLike: - # Immutable - return self._location - - @property - def path(self) -> str: - # Immutable, computed - return os.path.join(self._base_dir, self._location) - - @property - def offset(self) -> int | None: - # Immutable - return self._offset - - @property - def length(self) -> int | None: - # Immutable - return self._length - - @property - def dtype(self) -> _enums.DataType: - # Immutable - return self._dtype - - @property - def shape(self) -> Shape: - # Immutable - return self._shape - - def _load(self): - self._check_validity() - assert self._array is None, "Bug: The array should be loaded only once." - if self.size == 0: - # When the size is 0, mmap is impossible and meaningless - self._array = np.empty(self.shape.numpy(), dtype=self.dtype.numpy()) - return - # Map the whole file into the memory - # TODO(justinchuby): Verify if this would exhaust the memory address space - with open(self.path, "rb") as f: - self.raw = mmap.mmap( - f.fileno(), - 0, - access=mmap.ACCESS_READ, - ) - # Handle the byte order correctly by always using little endian - dt = np.dtype(self.dtype.numpy()).newbyteorder("<") - if self.dtype in { - _enums.DataType.INT4, - _enums.DataType.UINT4, - _enums.DataType.FLOAT4E2M1, - }: - # Use uint8 to read in the full byte. Otherwise ml_dtypes.int4 will clip the values - dt = np.dtype(np.uint8).newbyteorder("<") - count = self.size // 2 + self.size % 2 - else: - count = self.size - self._array = np.frombuffer(self.raw, dtype=dt, offset=self.offset or 0, count=count) - shape = self.shape.numpy() - if self.dtype == _enums.DataType.INT4: - # Unpack the int4 arrays - self._array = _type_casting.unpack_int4(self._array, shape) - elif self.dtype == _enums.DataType.UINT4: - self._array = _type_casting.unpack_uint4(self._array, shape) - elif self.dtype == _enums.DataType.FLOAT4E2M1: - self._array = _type_casting.unpack_float4e2m1(self._array, shape) - else: - self._array = self._array.reshape(shape) - - def __array__(self, dtype: Any = None) -> np.ndarray: - self._check_validity() - if self._array is None: - self._load() - assert self._array is not None - return self._array.__array__(dtype) - - def __dlpack__(self, *, stream: Any = None) -> Any: - raise NotImplementedError( - "ExternalTensor does not support DLPack because it uses memory mapping. " - "Call numpy() to get a numpy array instead." - ) - - def __dlpack_device__(self) -> tuple[int, int]: - raise NotImplementedError( - "ExternalTensor does not support DLPack because it uses memory mapping. " - "Call numpy() to get a numpy array instead." - ) - - def __repr__(self) -> str: - return ( - f"{self._repr_base()}(location='{self.location}', name={self.name!r}, " - f"offset={self.offset!r}, length={self.length!r}, base_dir={self.base_dir!r})" - ) - - def numpy(self) -> np.ndarray: - """Return the tensor as a numpy array. - - The data will be memory mapped into memory and will not taken up physical memory space. - """ - self._check_validity() - if self._array is None: - self._load() - assert self._array is not None - return self._array - - def tobytes(self) -> bytes: - """Return the bytes of the tensor. - - This will load the tensor into memory. - """ - self._check_validity() - if self.raw is None: - self._load() - assert self.raw is not None - offset = self._offset or 0 - length = self._length or self.nbytes - return self.raw[offset : offset + length] - - def valid(self) -> bool: - """Check if the tensor is valid. - - The external tensor is valid if it has not been invalidated. - """ - return self._valid - - def _check_validity(self) -> None: - if not self.valid(): - raise ValueError( - f"The external tensor '{self!r}' is invalidated. The data may be corrupted or deleted." - ) - - def invalidate(self) -> None: - """Invalidate the tensor. - - The external tensor is invalidated when the data is known to be corrupted or deleted. - """ - self._valid = False - - def release(self) -> None: - """Delete all references to the memory buffer and close the memory-mapped file.""" - self._array = None - if self.raw is not None: - self.raw.close() - self.raw = None - - -class StringTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=too-many-ancestors - """Multidimensional array of strings (as binary data to match the string_data field in TensorProto).""" - - __slots__ = ( - "_raw", - "_shape", - ) - - def __init__( - self, - value: Sequence[bytes] | npt.NDArray[np.bytes_], - *, - shape: Shape | None = None, - name: str | None = None, - doc_string: str | None = None, - metadata_props: dict[str, str] | None = None, - ) -> None: - """Initialize a tensor. - - Args: - value: The backing data of the tensor. It can be a numpy array or a Sequence of bytes. - shape: The shape of the tensor. If None, the shape is obtained from the value. - name: The name of the tensor. - doc_string: The documentation string. - metadata_props: The metadata properties. - """ - super().__init__(name=name, doc_string=doc_string, metadata_props=metadata_props) - if shape is None: - if not hasattr(value, "shape"): - raise ValueError( - f"Expected an object with a shape attribute, but {type(value)} does not have shape. " - "Please specify the shape explicitly." - ) - self._shape = Shape(getattr(value, "shape"), frozen=True) # noqa: B009 - else: - self._shape = shape - self._shape.freeze() - self._raw = value - - def __array__(self, dtype: Any = None) -> np.ndarray: - if isinstance(self._raw, np.ndarray): - return self._raw - assert isinstance(self._raw, Sequence), ( - f"Bug: Expected a sequence, got {type(self._raw)}" - ) - return np.array(self._raw, dtype=dtype).reshape(self.shape.numpy()) - - def __dlpack__(self, *, stream: Any = None) -> Any: - del stream # unused - raise TypeError("StringTensor does not support DLPack") - - def __dlpack_device__(self) -> tuple[int, int]: - raise TypeError("StringTensor does not support DLPack") - - def __repr__(self) -> str: - return f"{self._repr_base()}({self._raw!r}, name={self.name!r})" - - @property - def dtype(self) -> _enums.DataType: - """The data type of the tensor. Immutable.""" - return _enums.DataType.STRING - - @property - def shape(self) -> Shape: - """The shape of the tensor. Immutable.""" - return self._shape - - @property - def raw(self) -> Sequence[bytes] | npt.NDArray[np.bytes_]: - """Backing data of the tensor. Immutable.""" - return self._raw # type: ignore[return-value] - - def numpy(self) -> npt.NDArray[np.bytes_]: - """Return the tensor as a numpy array.""" - return self.__array__() - - def tobytes(self) -> bytes: - raise ValueError("StringTensor does not support tobytes. Use 'string_data' instead.") - - def string_data(self) -> Sequence[bytes]: - """Return the string data of the tensor.""" - if isinstance(self._raw, np.ndarray): - return self._raw.flatten().tolist() - return self._raw - - -class LazyTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=too-many-ancestors - """A tensor that lazily evaluates a function to get the actual tensor. - - This class takes a function returning an `ir.TensorProtocol`, a dtype, and a shape argument. - The function is lazily evaluated to get the actual tensor when `tobytes()` or `numpy()` is called. - - Example:: - - >>> import numpy as np - >>> from onnxscript import ir - >>> weights = np.array([[1, 2, 3]]) - >>> def create_tensor(): # Delay applying transformations to the weights - ... weights_t = weights.transpose() - ... return ir.tensor(weights_t) - >>> lazy_tensor = ir.LazyTensor(create_tensor, dtype=ir.DataType.INT64, shape=ir.Shape([1, 3])) - >>> print(lazy_tensor.numpy()) - [[1] - [2] - [3]] - - Attributes: - func: The function that returns the actual tensor. - dtype: The data type of the tensor. - shape: The shape of the tensor. - cache: Whether to cache the result of the function. If False, - the function is called every time the tensor content is accessed. - If True, the function is called only once and the result is cached in memory. - Default is False. - name: The name of the tensor. - doc_string: The documentation string. - metadata_props: The metadata properties. - """ - - __slots__ = ( - "_dtype", - "_func", - "_shape", - "_tensor", - "cache", - ) - - def __init__( - self, - func: Callable[[], _protocols.TensorProtocol], - dtype: _enums.DataType, - shape: Shape, - *, - cache: bool = False, - name: str | None = None, - doc_string: str | None = None, - metadata_props: dict[str, str] | None = None, - ) -> None: - """Initialize a lazy tensor. - - Args: - func: The function that returns the actual tensor. - dtype: The data type of the tensor. - shape: The shape of the tensor. - cache: Whether to cache the result of the function. - name: The name of the tensor. - doc_string: The documentation string. - metadata_props: The metadata properties. - """ - super().__init__(name=name, doc_string=doc_string, metadata_props=metadata_props) - self._func = func - self._dtype = dtype - self._shape = shape - self._tensor: _protocols.TensorProtocol | None = None - self.cache = cache - - def _evaluate(self) -> _protocols.TensorProtocol: - """Evaluate the function to get the actual tensor.""" - if not self.cache: - return self._func() - - # Cache the tensor - if self._tensor is None: - self._tensor = self._func() - return self._tensor - - def __array__(self, dtype: Any = None) -> np.ndarray: - return self._evaluate().__array__(dtype) - - def __dlpack__(self, *, stream: Any = None) -> Any: - return self._evaluate().__dlpack__(stream=stream) - - def __dlpack_device__(self) -> tuple[int, int]: - return self._evaluate().__dlpack_device__() - - def __repr__(self) -> str: - return f"{self._repr_base()}(func={self._func!r}, name={self.name!r})" - - @property - def raw(self) -> Callable[[], _protocols.TensorProtocol]: - return self._func - - @property - def dtype(self) -> _enums.DataType: - """The data type of the tensor. Immutable.""" - return self._dtype - - @property - def shape(self) -> Shape: - """The shape of the tensor. Immutable.""" - return self._shape - - def numpy(self) -> np.ndarray: - """Return the tensor as a numpy array.""" - return self._evaluate().numpy() - - def tobytes(self) -> bytes: - """Return the bytes of the tensor.""" - return self._evaluate().tobytes() - - -class SymbolicDim(_protocols.SymbolicDimProtocol, _display.PrettyPrintable): - """Immutable symbolic dimension that can be shared across multiple shapes.""" - - __slots__ = ("_value",) - - def __init__(self, value: str | None) -> None: - """Initialize a symbolic dimension. - - Args: - value: The value of the dimension. It should not be an int. - """ - if isinstance(value, int): - raise TypeError( - "The value of a SymbolicDim cannot be an int. " - "If you are creating a Shape, use int directly instead of SymbolicDim." - ) - self._value = value - - def __eq__(self, other: object) -> bool: - if not isinstance(other, SymbolicDim): - return self.value == other - return self.value == other.value - - def __hash__(self) -> int: - return hash(self.value) - - @property - def value(self) -> str | None: - return self._value - - def __str__(self) -> str: - return f"{self._value}" - - def __repr__(self) -> str: - return f"{self.__class__.__name__}({self._value})" - - -def _is_int_compatible(value: object) -> TypeIs[SupportsInt]: - """Return True if the value is int compatible.""" - if isinstance(value, int): - return True - if hasattr(value, "__int__"): - # For performance reasons, we do not use isinstance(value, SupportsInt) - return True - return False - - -def _maybe_convert_to_symbolic_dim( - dim: int | SupportsInt | SymbolicDim | str | None, -) -> SymbolicDim | int: - """Convert the value to a SymbolicDim if it is not an int.""" - if dim is None or isinstance(dim, str): - return SymbolicDim(dim) - if _is_int_compatible(dim): - return int(dim) - if isinstance(dim, SymbolicDim): - return dim - raise TypeError( - f"Expected int, str, None or SymbolicDim, but value {dim!r} has type '{type(dim)}'" - ) - - -class Shape(_protocols.ShapeProtocol, _display.PrettyPrintable): - """The shape of a tensor, including its dimensions and optional denotations. - - The :class:`Shape` stores the dimensions of a tensor, which can be integers, None (unknown), or - symbolic dimensions. - - A shape can be compared to another shape or plain Python list. - - A shape can be frozen (made immutable). When the shape is frozen, it cannot be - unfrozen, making it suitable to be shared across tensors or values. - Call :method:`freeze` to freeze the shape. - - To update the dimension of a frozen shape, call :method:`copy` to create a - new shape with the same dimensions that can be modified. - - Use :method:`get_denotation` and :method:`set_denotation` to access and modify the denotations. - - Example:: - - >>> from onnxscript import ir - >>> shape = ir.Shape(["B", None, 3]) - >>> shape.rank() - 3 - >>> shape.is_static() - False - >>> shape.is_dynamic() - True - >>> shape.is_static(dim=2) - True - >>> shape[0] = 1 - >>> shape[1] = 2 - >>> shape.dims - (1, 2, 3) - >>> shape == [1, 2, 3] - True - >>> shape.frozen - False - >>> shape.freeze() - >>> shape.frozen - True - - Attributes: - dims: A tuple of dimensions representing the shape. - Each dimension can be an integer, None or a :class:`SymbolicDim`. - frozen: Indicates whether the shape is immutable. When frozen, the shape - cannot be modified or unfrozen. - """ - - __slots__ = ("_dims", "_frozen") - - def __init__( - self, - dims: Iterable[int | SupportsInt | SymbolicDim | str | None], - /, - denotations: Iterable[str | None] | None = None, - frozen: bool = False, - ) -> None: - """Initialize a shape. - - Args: - dims: The dimensions of the shape. Each dimension can be an integer or a - SymbolicDim or any Python object. When a ``dim`` is not an integer or a - SymbolicDim, it is converted to a SymbolicDim. - denotations: The denotations of the dimensions. If None, the denotations are not set. - Standard denotation can optionally be used to denote tensor - dimensions with standard semantic descriptions to ensure - that operations are applied to the correct axis of a tensor. - Refer to https://github.com/onnx/onnx/blob/main/docs/DimensionDenotation.md#denotation-definition - for pre-defined dimension denotations. - frozen: If True, the shape is immutable and cannot be modified. This - is useful when the shape is initialized by a Tensor or when the shape - is shared across multiple tensors. The default is False. - """ - self._dims: list[int | SymbolicDim] = [ - _maybe_convert_to_symbolic_dim(dim) for dim in dims - ] - self._denotations: list[str | None] = ( - list(denotations) if denotations is not None else [None] * len(self._dims) - ) - if len(self._denotations) != len(self._dims): - raise ValueError( - "The number of denotations, when provided, must be equal to the number of dimensions." - ) - self._frozen: bool = frozen - - @property - def dims(self) -> tuple[int | SymbolicDim, ...]: - """All dimensions in the shape. - - This property is read-only. Use __getitem__ and __setitem__ to modify the shape or create a new shape. - """ - return tuple(self._dims) - - @property - def frozen(self) -> bool: - """Whether the shape is frozen. - - When the shape is frozen, it cannot be unfrozen, making it suitable to be shared. - Call :method:`freeze` to freeze the shape. Call :method:`copy` to create a - new shape with the same dimensions that can be modified. - """ - return self._frozen - - def freeze(self) -> None: - """Freeze the shape. - - When the shape is frozen, it cannot be unfrozen, making it suitable to be shared. - """ - self._frozen = True - - def copy(self, frozen: bool = False): - """Return a copy of the shape.""" - return Shape(self._dims, self._denotations, frozen=frozen) - - def rank(self) -> int: - """The rank of the tensor this shape represents.""" - return len(self._dims) - - def numpy(self) -> tuple[int, ...]: - if any(not isinstance(dim, int) for dim in self._dims): - raise ValueError(f"Cannot convert the shape {self} to a tuple of ints") - return tuple(dim for dim in self._dims) # type: ignore - - def __len__(self) -> int: - return len(self._dims) - - def __iter__(self) -> Iterator[int | SymbolicDim]: - return iter(self._dims) - - @typing.overload - def __getitem__(self, index: int) -> int | SymbolicDim: ... - - @typing.overload - def __getitem__(self, index: slice) -> tuple[int | SymbolicDim, ...]: ... - - def __getitem__(self, index): - return tuple(self._dims)[index] - - def __setitem__(self, index: int, value: int | SymbolicDim | str | None) -> None: - """Set the dimension at the index. - - Args: - index: The index of the dimension. - value: The value of the dimension. - - Raises: - TypeError: If the shape is frozen and cannot be modified. - TypeError: If the value is not an int or SymbolicDim. - """ - if self._frozen: - raise TypeError("The shape is frozen and cannot be modified.") - - self._dims[index] = _maybe_convert_to_symbolic_dim(value) - - def get_denotation(self, index: int) -> str | None: - """Return the denotation of the dimension at the index. - - Args: - index: The index of the dimension. - - Returns: - The denotation of the dimension. - """ - return self._denotations[index] - - def set_denotation(self, index: int, denotation: str | None) -> None: - """Set the denotation of the dimension at the index. - - Args: - index: The index of the dimension. - denotation: The denotation of the dimension. - """ - self._denotations[index] = denotation - - def __repr__(self) -> str: - return f"{self.__class__.__name__}({self._dims!r})" - - def __str__(self) -> str: - """Return a string representation of the shape. - - E.g. [n,1,3] - """ - return f"[{','.join([str(dim) for dim in self._dims])}]" - - def __eq__(self, other: object) -> bool: - """Return True if the shapes are equal. - - Two shapes are equal if all their dimensions are equal. - """ - if isinstance(other, Shape): - return self._dims == other._dims - if not isinstance(other, Iterable): - return False - return self._dims == list(other) - - def __ne__(self, other: object) -> bool: - return not self.__eq__(other) - - @typing.overload - def is_static(self, dim: int) -> bool: # noqa: D418 - """Return True if the dimension is static.""" - - @typing.overload - def is_static(self) -> bool: # noqa: D418 - """Return True if all dimensions are static.""" - - def is_static(self, dim=None) -> bool: - """Return True if the dimension is static. If dim is None, return True if all dimensions are static.""" - if dim is None: - return all(isinstance(dim, int) for dim in self._dims) - return isinstance(self[dim], int) - - @typing.overload - def is_dynamic(self, dim: int) -> bool: # noqa: D418 - """Return True if the dimension is dynamic.""" - - @typing.overload - def is_dynamic(self) -> bool: # noqa: D418 - """Return True if any dimension is dynamic.""" - - def is_dynamic(self, dim=None) -> bool: - if dim is None: - return not self.is_static() - return not self.is_static(dim) - - -def _quoted(string: str) -> str: - """Return a quoted string. - - This function is used to quote value/node names in the IR for better readability. - """ - return f'"{string}"' - - -class Usage(NamedTuple): - """A usage of a value in a node. - - Attributes: - node: The node that uses the value. - idx: The input index of the value in the node. - """ - - node: Node - idx: int - - -def _short_tensor_str_for_node(x: Value) -> str: - if x.const_value is None: - return "" - if x.const_value.size <= 10: - try: - data = x.const_value.numpy().tolist() - except Exception: # pylint: disable=broad-except - return "{...}" - return f"{{{data}}}" - return "{...}" - - -def _normalize_domain(domain: str) -> str: - """Normalize 'ai.onnx' to ''""" - return "" if domain == "ai.onnx" else domain - - -class Node(_protocols.NodeProtocol, _display.PrettyPrintable): - """IR Node. - - If the ``graph`` is provided, the node will be added to the graph. Otherwise, - user is responsible to call ``graph.append(node)`` (or other mutation methods - in :class:`Graph`) to add the node to the graph. - - After the node is initialized, it will add itself as a user of the input values. - - The output values of the node are created during node initialization and are immutable. - To change the output values, create a new node and replace the each of the inputs of ``output.uses()`` with - the new output values by calling :meth:`replace_input_with` on the using nodes - of this node's outputs. - - .. note: - When the ``domain`` is `"ai.onnx"`, it is normalized to `""`. - """ - - __slots__ = ( - "_attributes", - "_domain", - "_graph", - "_inputs", - "_metadata", - "_metadata_props", - "_name", - "_op_type", - "_outputs", - "_overload", - "_version", - "doc_string", - ) - - def __init__( - self, - domain: str, - op_type: str, - inputs: Iterable[Value | None], - attributes: Iterable[Attr | RefAttr] = (), - *, - overload: str = "", - num_outputs: int | None = None, - outputs: Sequence[Value] | None = None, - version: int | None = None, - graph: Graph | Function | None = None, - name: str | None = None, - doc_string: str | None = None, - metadata_props: dict[str, str] | None = None, - ): - """Initialize a node and add it as a user of the input values. - - Args: - domain: The domain of the operator. For onnx operators, this is an empty string. - When it is `"ai.onnx"`, it is normalized to `""`. - op_type: The name of the operator. - inputs: The input values. When an input is ``None``, it is an empty input. - attributes: The attributes. RefAttr can be used only when the node is defined in a Function. - overload: The overload name when the node is invoking a function. - num_outputs: The number of outputs of the node. If not specified, the number is 1. - outputs: The output values. If ``None``, the outputs are created during initialization. - version: The version of the operator. If ``None``, the version is unspecified and will follow that of the graph. - graph: The graph that the node belongs to. If ``None``, the node is not added to any graph. - A `Node` must belong to zero or one graph. If a :class:`Function`, the underlying graph - of the function is assigned to the node. - name: The name of the node. If ``None``, the node is anonymous. The name may be - set by a :class:`Graph` if ``graph`` is specified. - doc_string: The documentation string. - metadata_props: The metadata properties. - - Raises: - TypeError: If the attributes are not :class:`Attr` or :class:`RefAttr`. - ValueError: If ``num_outputs``, when not ``None``, is not the same as the length of the outputs. - ValueError: If an output value is ``None``, when outputs is specified. - ValueError: If an output value has a producer set already, when outputs is specified. - """ - self._name = name - self._domain: str = _normalize_domain(domain) - self._op_type: str = op_type - # NOTE: Make inputs immutable with the assumption that they are not mutated - # very often. This way all mutations can be tracked. - # If necessary, we can cache the inputs and outputs as tuples. - self._inputs: tuple[Value | None, ...] = tuple(inputs) - # Values belong to their defining nodes. The values list is immutable - self._outputs: tuple[Value, ...] = self._create_outputs(num_outputs, outputs) - attributes = tuple(attributes) - if attributes and not isinstance(attributes[0], (Attr, RefAttr)): - raise TypeError( - f"Expected the attributes to be Attr or RefAttr, got {type(attributes[0])}. " - "If you are copying the attributes from another node, make sure you call " - "node.attributes.values() because it is a dictionary." - ) - self._attributes: OrderedDict[str, Attr | RefAttr] = OrderedDict( - (attr.name, attr) for attr in attributes - ) - self._overload: str = overload - # TODO(justinchuby): Potentially support a version range - self._version: int | None = version - self._metadata: _metadata.MetadataStore | None = None - self._metadata_props: dict[str, str] | None = metadata_props - # _graph is set by graph.append - self._graph: Graph | None = None - # Add the node to the graph if graph is specified - if graph is not None: - graph.append(self) - self.doc_string = doc_string - - # Add the node as a use of the inputs - for i, input_value in enumerate(self._inputs): - if input_value is not None: - input_value._add_usage(self, i) # pylint: disable=protected-access - - def _create_outputs( - self, num_outputs: int | None, outputs: Sequence[Value] | None - ) -> tuple[Value, ...]: - """Check the parameters and create outputs for the node. - - Args: - num_outputs: The number of outputs of the node. - outputs: The output values of the node. - - Returns: - The output values of the node. - - Raises: - ValueError: If `num_outputs`, when not None, is not the same as the length of the outputs. - ValueError: If an output value is None. - ValueError: If an output value has a producer set already. - """ - # Check num_outputs and outputs are consistent - if num_outputs is not None and outputs is not None and num_outputs != len(outputs): - raise ValueError( - "num_outputs must be the same as len(outputs) when num_outputs is specified." - f"num_outputs: {num_outputs}, outputs: {outputs}" - ) - # 1. If outputs is specified (can be empty []), use the outputs - if outputs is not None: - # Check all output values are valid first - for output in outputs: - if output is None: - raise ValueError(f"Output value cannot be None. All outputs: {outputs}") - if output.producer() is not None: - raise ValueError( - f"Supplied output value cannot have a producer when used for initializing a Node. " - f"Output: {output}. All outputs: {outputs}" - ) - result = [] - for i, output in enumerate(outputs): - output._producer = self # pylint: disable=protected-access - output._index = i # pylint: disable=protected-access - result.append(output) - return tuple(result) - - # 2. If num_outputs is specified, create num_outputs outputs - if num_outputs is None: - # Default to 1 output - num_outputs = 1 - assert num_outputs is not None - return tuple(Value(self, index=i) for i in range(num_outputs)) - - def __str__(self) -> str: - node_type_text = f"{self._domain}::{self._op_type}" + f":{self._overload}" * ( - self._overload != "" - ) - inputs_text = ( - "(" - + ", ".join( - [ - ( - f"%{_quoted(x.name) if x.name else 'anonymous:' + str(id(x))}{_short_tensor_str_for_node(x)}" - if x is not None - else "None" - ) - for x in self._inputs - ] - ) - + ")" - ) - attributes_text = ( - (" {" + ", ".join([f"{k}={v}" for k, v in self._attributes.items()]) + "}") - if self._attributes - else "" - ) - outputs_text = ", ".join(str(x) for x in self._outputs) - - return f"{outputs_text} ⬅️ {node_type_text}{inputs_text}{attributes_text}" - - def __repr__(self) -> str: - return ( - f"{self.__class__.__name__}(name={self._name!r}, domain={self._domain!r}, " - f"op_type={self._op_type!r}, inputs={self._inputs!r}, attributes={self._attributes!r}, " - f"overload={self._overload!r}, outputs={self._outputs!r}, " - f"version={self._version!r}, doc_string={self.doc_string!r})" - ) - - @property - def name(self) -> str | None: - """Optional name of the node.""" - return self._name - - @name.setter - def name(self, value: str | None) -> None: - self._name = value - - @property - def domain(self) -> str: - """The domain of the operator. For onnx operators, this is an empty string. - - .. note: - When domain is `"ai.onnx"`, it is normalized to `""`. - """ - return self._domain - - @domain.setter - def domain(self, value: str) -> None: - self._domain = _normalize_domain(value) - - @property - def version(self) -> int | None: - """Opset version of the operator called. - - If ``None``, the version is unspecified and will follow that of the graph. - This property is special to ONNX IR to allow mixed opset usage in a graph - for supporting more flexible graph transformations. It does not exist in the ONNX - serialization (protobuf) spec. - """ - return self._version - - @version.setter - def version(self, value: int | None) -> None: - self._version = value - - @property - def op_type(self) -> str: - """The name of the operator called.""" - return self._op_type - - @op_type.setter - def op_type(self, value: str) -> None: - self._op_type = value - - @property - def overload(self) -> str: - """The overload name when the node is invoking a function.""" - return self._overload - - @overload.setter - def overload(self, value: str) -> None: - self._overload = value - - @property - def inputs(self) -> Sequence[Value | None]: - """The input values of the node. - - The inputs are immutable. To change the inputs, create a new node and - replace the inputs of the using nodes of this node's outputs by calling - :meth:`replace_input_with` on the using nodes of this node's outputs. - """ - return self._inputs - - @inputs.setter - def inputs(self, _: Any) -> None: - raise AttributeError( - "Directly mutating the input sequence is unsupported. Please use Node.replace_input_with() instead." - ) - - def predecessors(self) -> Sequence[Node]: - """Return the predecessor nodes of the node, deduplicated, in a deterministic order.""" - # Use the ordered nature of a dictionary to deduplicate the nodes - predecessors: dict[Node, None] = {} - for value in self.inputs: - if value is not None and (producer := value.producer()) is not None: - predecessors[producer] = None - return tuple(predecessors) - - def successors(self) -> Sequence[Node]: - """Return the successor nodes of the node, deduplicated, in a deterministic order.""" - # Use the ordered nature of a dictionary to deduplicate the nodes - successors: dict[Node, None] = {} - for value in self.outputs: - assert value is not None, "Bug: Output values are not expected to be None" - for usage in value.uses(): - successors[usage.node] = None - return tuple(successors) - - def replace_input_with(self, index: int, value: Value | None) -> None: - """Replace an input with a new value.""" - if index < 0 or index >= len(self.inputs): - raise ValueError(f"Index out of range: {index}") - old_input = self.inputs[index] - self._inputs = tuple( - value if i == index else old_input for i, old_input in enumerate(self.inputs) - ) - if old_input is not None: - old_input._remove_usage(self, index) # pylint: disable=protected-access - if value is not None: - value._add_usage(self, index) # pylint: disable=protected-access - - def prepend(self, /, nodes: Node | Iterable[Node]) -> None: - """Insert a node before this node in the list of nodes in the graph. - - It is the same as calling ``graph.insert_before(self, nodes)``. - - Example:: - - Before: previous_node -> self - previous_node' -> node -> next_node' - After: previous_node -> node -> self - previous_node' -> next_node' - - Args: - nodes: A node or a sequence of nodes to put before this node. - """ - if self._graph is None: - raise ValueError("The node to prepend to does not belong to any graph.") - self._graph.insert_before(self, nodes) - - def append(self, /, nodes: Node | Iterable[Node]) -> None: - """Insert a node after this node in the list of nodes in the graph. - - It is the same as calling ``graph.insert_after(self, nodes)``. - - Example:: - - Before: previous_node -> self - previous_node' -> node -> next_node' - After: previous_node -> self -> node - previous_node' -> next_node' - - Args: - nodes: A node or a sequence of nodes to put after this node. - """ - if self._graph is None: - raise ValueError("The node to append to does not belong to any graph.") - self._graph.insert_after(self, nodes) - - @property - def outputs(self) -> Sequence[Value]: - """The output values of the node. - - The outputs are immutable. To change the outputs, create a new node and - replace the inputs of the using nodes of this node's outputs by calling - :meth:`replace_input_with` on the using nodes of this node's outputs. - """ - return self._outputs - - @outputs.setter - def outputs(self, _: Sequence[Value]) -> None: - raise AttributeError("outputs is immutable. Please create a new node instead.") - - @property - def attributes(self) -> OrderedDict[str, Attr | RefAttr]: - """The attributes of the node.""" - return self._attributes - - @property - def meta(self) -> _metadata.MetadataStore: - """The metadata store for intermediate analysis. - - Write to the :attr:`metadata_props` if you would like the metadata to be serialized - to the ONNX proto. - """ - if self._metadata is None: - self._metadata = _metadata.MetadataStore() - return self._metadata - - @property - def metadata_props(self) -> dict[str, str]: - """The metadata properties of the node. - - The metadata properties are used to store additional information about the node. - Unlike ``meta``, this property is serialized to the ONNX proto. - """ - if self._metadata_props is None: - self._metadata_props = {} - return self._metadata_props - - @property - def graph(self) -> Graph | None: - """The graph that the node belongs to. - - If the node is not added to any graph, this property is None. - """ - return self._graph - - @graph.setter - def graph(self, value: Graph | None) -> None: - self._graph = value - - def op_identifier(self) -> _protocols.OperatorIdentifier: - """Return the operator identifier of the node. - - The operator identifier is a tuple of the domain, op_type and overload. - """ - return self.domain, self.op_type, self.overload - - def display(self, *, page: bool = False) -> None: - """Pretty print the node. - - This method is used for debugging and visualization purposes. - """ - # Add the node's name to the displayed text - print(f"Node: {self.name!r}") - if self.doc_string: - print(f"Doc: {self.doc_string}") - super().display(page=page) - - -class _TensorTypeBase(_protocols.TypeProtocol, _display.PrettyPrintable, Hashable): - """Tensor types that are non recursive types.""" - - __slots__ = ("_dtype", "denotation") - - def __init__(self, dtype: _enums.DataType, *, denotation: str | None = None) -> None: - self._dtype = dtype - self.denotation = denotation - - @property - def dtype(self) -> _enums.DataType: - return self._dtype - - @dtype.setter - def dtype(self, value: _enums.DataType) -> None: - self._dtype = value - - @property - def elem_type(self) -> _enums.DataType: - """Return the element type of the tensor type""" - return self.dtype - - def __hash__(self) -> int: - return hash(repr(self)) - - def __eq__(self, other: object) -> bool: - if self.__class__ is not other.__class__: - return False - return self.dtype == other.dtype # type: ignore[attr-defined] - - def __repr__(self) -> str: - # Remove "Type" from name for display - short_name = self.__class__.__name__[:-4] - return f"{short_name}({self.dtype!r})" - - -class TensorType(_TensorTypeBase): - """A type that represents a tensor.""" - - def __str__(self) -> str: - return f"{self.dtype}" - - -class SparseTensorType(_TensorTypeBase): - """A type that represents a sparse tensor.""" - - -class _RecursiveTypeBase(_protocols.TypeProtocol, _display.PrettyPrintable, Hashable): - """Base for recursive types like Optional and Sequence.""" - - __slots__ = ("_elem_type", "denotation") - - def __init__( - self, elem_type: _protocols.TypeProtocol, *, denotation: str | None = None - ) -> None: - self._elem_type = elem_type - self.denotation = denotation - - @property - def dtype(self) -> _enums.DataType: - return self._elem_type.dtype - - @dtype.setter - def dtype(self, value: _enums.DataType) -> None: - self._elem_type.dtype = value - - @property - def elem_type(self) -> _protocols.TypeProtocol: - return self._elem_type - - def __hash__(self) -> int: - return hash(repr(self)) - - def __eq__(self, other: object) -> bool: - if not isinstance(other, _RecursiveTypeBase): - return False - if self.__class__ != other.__class__: - return False - # Recursively compare the type of the elements - return self.elem_type == other.elem_type - - def __repr__(self) -> str: - # Remove "Type" from name for display - short_name = self.__class__.__name__[:-4] - return f"{short_name}({self.elem_type!r})" - - -class SequenceType(_RecursiveTypeBase): - """A type that represents a sequence of elements.""" - - -class OptionalType(_RecursiveTypeBase): - """A type that represents an optional element.""" - - -class Value(_protocols.ValueProtocol, _display.PrettyPrintable): - """IR Value. - - A value is a named entity that can be used to represent an input or output of a graph, - a function, or a node. The information it stores generalizes over ``ValueInfoProto`` - in the ONNX specification. - - A :class:`Value` is always not owned or owned by exactly one node. When the value is not - owned, it must be an input of a graph or a function. ``producer`` and ``index`` - are ``None``. - - When the value is owned by a node, it is an output of the node. - The node that produces the value can be accessed with :meth:`producer`. - The index of the output of the node that produces the value can be accessed with - :meth:`index`. - - To find all the nodes that use this value as an input, call :meth:`uses`. - - To check if the value is an is an input, output or initializer of a graph, - use :meth:`is_graph_input`, :meth:`is_graph_output` or :meth:`is_initializer`. - - Use :meth:`graph` to get the graph that owns the value. - """ - - __slots__ = ( - "_const_value", - "_graph", - "_index", - "_is_graph_input", - "_is_graph_output", - "_is_initializer", - "_metadata", - "_metadata_props", - "_name", - "_producer", - "_shape", - "_type", - "_uses", - "doc_string", - ) - - def __init__( - self, - producer: Node | None = None, - *, - index: int | None = None, - name: str | None = None, - shape: Shape | None = None, - type: _protocols.TypeProtocol | None = None, - doc_string: str | None = None, - const_value: _protocols.TensorProtocol | None = None, - ) -> None: - """Initialize a value. - - Args: - producer: The node that produces the value. - It can be ``None`` when the value is initialized first than its producer. - index: The index of the output of the defining node. - name: The name of the value. - shape: The shape of the value. - type: The type of the value. - doc_string: The documentation string. - const_value: The constant tensor if the value is constant. - """ - self._producer: Node | None = producer - self._index: int | None = index - self._metadata: _metadata.MetadataStore | None = None - self._metadata_props: dict[str, str] | None = None - - self._name: str | None = name - self._shape: Shape | None = shape - self._type: _protocols.TypeProtocol | None = type - # TODO(justinchuby): Handle initialization when a const value is provided - # We can get shape and type information from the const value - self._const_value = const_value - # Use a collection of (Node, int) to store uses. This is needed - # because a single use can use the same value multiple times. - # Use a dictionary to preserve insertion order so that the visiting order is deterministic - self._uses: dict[Usage, None] = {} - self.doc_string = doc_string - - # The graph this value belongs to. It is set *only* when the value is added as - # a graph input, output or initializer. - # The four properties can only be set by the Graph class (_GraphIO and GraphInitializers). - self._graph: Graph | None = None - self._is_graph_input: bool = False - self._is_graph_output: bool = False - self._is_initializer: bool = False - - def __repr__(self) -> str: - value_name = self.name if self.name else "anonymous:" + str(id(self)) - type_text = f", type={self.type!r}" if self.type is not None else "" - shape_text = f", shape={self.shape!r}" if self.shape is not None else "" - producer = self.producer() - if producer is None: - producer_text = "" - elif producer.name is not None: - producer_text = f", producer='{producer.name}'" - else: - producer_text = f", producer=anonymous_node:{id(producer)}" - index_text = f", index={self.index()}" if self.index() is not None else "" - const_value_text = self._constant_tensor_part() - if const_value_text: - const_value_text = f", const_value={const_value_text}" - return f"{self.__class__.__name__}(name={value_name!r}{type_text}{shape_text}{producer_text}{index_text}{const_value_text})" - - def __str__(self) -> str: - value_name = self.name if self.name is not None else "anonymous:" + str(id(self)) - shape_text = str(self.shape) if self.shape is not None else "?" - type_text = str(self.type) if self.type is not None else "?" - - # Quote the name because in reality the names can have invalid characters - # that make them hard to read - return ( - f"%{_quoted(value_name)}<{type_text},{shape_text}>{self._constant_tensor_part()}" - ) - - def _constant_tensor_part(self) -> str: - """Display string for the constant tensor attached to str of Value.""" - if self.const_value is not None: - # Only display when the const value is small - if self.const_value.size <= 10: - return f"{{{self.const_value}}}" - else: - return f"{{{self.const_value.__class__.__name__}(...)}}" - return "" - - @property - def graph(self) -> Graph | None: - """Return the graph that defines this value. - - When the value is an input/output/initializer of a graph, the owning graph - is that graph. When the value is an output of a node, the owning graph is the - graph that the node belongs to. When the value is not owned by any graph, - it returns ``None``. - """ - if self._graph is not None: - return self._graph - if self._producer is not None: - return self._producer.graph - return None - - def _owned_by_graph(self) -> bool: - """Return True if the value is owned by a graph.""" - result = self._is_graph_input or self._is_graph_output or self._is_initializer - if result: - assert self._graph is not None - return result - - def producer(self) -> Node | None: - """The node that produces this value. - - When producer is ``None``, the value does not belong to a node, and is - typically a graph input or an initializer. You can use :meth:`graph`` - to find the graph that owns this value. Use :meth:`is_graph_input`, :meth:`is_graph_output` - or :meth:`is_initializer` to check if the value is an input, output or initializer of a graph. - """ - return self._producer - - def consumers(self) -> Sequence[Node]: - """Return the nodes (deduplicated) that consume this value.""" - return tuple({usage.node: None for usage in self._uses}) - - def index(self) -> int | None: - """The index of the output of the defining node.""" - return self._index - - def uses(self) -> Collection[Usage]: - """Return a set of uses of the value. - - The set contains tuples of ``(Node, index)`` where the index is the index of the input - of the node. For example, if ``node.inputs[1] == value``, then the use is ``(node, 1)``. - """ - # Create a tuple for the collection so that iteration on will will not - # be affected when the usage changes during graph mutation. - # This adds a small overhead but is better a user experience than - # having users call tuple(). - return tuple(self._uses) - - def _add_usage(self, use: Node, index: int) -> None: - """Add a usage of this value. - - This is an internal method. It should only be called by the Node class. - """ - self._uses[Usage(use, index)] = None - - def _remove_usage(self, use: Node, index: int) -> None: - """Remove a node from the uses of this value. - - This is an internal method. It should only be called by the Node class. - """ - self._uses.pop(Usage(use, index)) - - @property - def name(self) -> str | None: - return self._name - - @name.setter - def name(self, value: str | None) -> None: - if self._const_value is not None: - self._const_value.name = value - self._name = value - - @property - def type(self) -> _protocols.TypeProtocol | None: - """The type of the tensor. - - Example types can be ``TensorType``, ``SparseTensorType``, ``SequenceType``, ``OptionalType``. - To obtain the data type of the tensor, use ``type.dtype`` or conveniently - :attr:`dtype`. - """ - return self._type - - @type.setter - def type(self, value: _protocols.TypeProtocol | None) -> None: - self._type = value - - @property - def dtype(self) -> _enums.DataType | None: - """The data type of the tensor.""" - if self._type is None: - return None - return self._type.dtype - - @dtype.setter - def dtype(self, value: _enums.DataType) -> None: - """Set the data type of the tensor. - - If the type is not set, it will be initialized to a new TensorType. To - set the type as other types like ``SequenceType``, initialize the type - then set :attr:`type` instead. - """ - if self._type is None: - self._type = TensorType(value) - else: - self._type.dtype = value - - @property - def shape(self) -> Shape | None: - return self._shape - - @shape.setter - def shape(self, value: Shape | None) -> None: - if value is None: - self._shape = None - return - if isinstance(value, Shape): - self._shape = value - return - raise TypeError(f"Expected value to be a Shape or None, got '{type(value)}'") - - @property - def const_value( - self, - ) -> _protocols.TensorProtocol | None: - """A concrete value. - - The value can be backed by different raw data types, such as numpy arrays. - The only guarantee is that it conforms TensorProtocol. - """ - return self._const_value - - @const_value.setter - def const_value( - self, - value: _protocols.TensorProtocol | None, - ) -> None: - if onnxscript.DEBUG: - if value is not None and not isinstance(value, _protocols.TensorProtocol): - raise TypeError( - f"Expected value to be a TensorProtocol or None, got '{type(value)}'" - ) - self._const_value = value - - @property - def meta(self) -> _metadata.MetadataStore: - """The metadata store for intermediate analysis. - - Write to the :attr:`metadata_props` if you would like the metadata to be serialized - to the ONNX proto. - """ - if self._metadata is None: - self._metadata = _metadata.MetadataStore() - return self._metadata - - @property - def metadata_props(self) -> dict[str, str]: - if self._metadata_props is None: - self._metadata_props = {} - return self._metadata_props - - def is_graph_input(self) -> bool: - """Whether the value is an input of a graph.""" - return self._is_graph_input - - def is_graph_output(self) -> bool: - """Whether the value is an output of a graph.""" - return self._is_graph_output - - def is_initializer(self) -> bool: - """Whether the value is an initializer of a graph.""" - return self._is_initializer - - -def Input( - name: str | None = None, - shape: Shape | None = None, - type: _protocols.TypeProtocol | None = None, - doc_string: str | None = None, -) -> Value: - """Create an input of a Graph or a Function. - - This is equivalent to calling ``Value(name=name, shape=shape, type=type, doc_string=doc_string)``. - """ - - # NOTE: The function name is capitalized to maintain API backward compatibility. - - return Value(name=name, shape=shape, type=type, doc_string=doc_string) - - -def _check_node_safe_to_remove( - node: Node, to_remove: AbstractSet[Node], graph_outputs: AbstractSet[Value] -) -> None: - """Check if a node is safe to remove. - - 1. It checks to make sure there are no users of the node that are not - to be removed before removing it. - 2. It checks the node does not contribute to any graph outputs. - - This check is typically O(1) assuming the number of uses of the node is small - - Args: - node: The node to check. - to_remove: A set of nodes that are to be removed. - This set is used to check if the node is still being used by other - nodes that are not to be removed. - graph_outputs: A set of values that are outputs of the graph. - - Raises: - ValueError: If the node does not belong to this graph or if there are users of the node. - ValueError: If the node is still being used by other nodes not to be removed. - """ - for output in node.outputs: - if output in graph_outputs: - raise ValueError( - f"Node '{node!r}' is still an output of the graph and cannot be removed when safe=True." - ) - uses_not_to_remove = [user for user, _ in output.uses() if user not in to_remove] - if uses_not_to_remove: - raise ValueError( - f"Output value '{output!r}' is still being used by other nodes that are not to be " - f"removed. All of its users that is not being removed: {uses_not_to_remove!r}. " - "Please make sure these nodes are no longer using the output value." - ) - - -class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable): - """IR Graph. - - Graph represents a computation graph. In addition to the ONNX specification - specified fields, it also contains a mapping of :attr:`opset_imports`. This - allows different subgraphs to import different opsets. It is the responsibility - of the deserializer to reconcile the different opsets. - - The `nodes` are not guaranteed to be topologically sorted. But the - iteration order should be deterministic across different runs. It is the - responsibility of the user to maintain a topological order of the nodes. - - Note that there is not a ``node`` attribute in the Graph. The Graph can be - seen as a Sequence of nodes and should be used as such. For example, to obtain - all nodes as a list, call ``list(graph)``. - - Attributes: - name: The name of the graph. - inputs: The input values of the graph. - outputs: The output values of the graph. - initializers: The initializers in the graph. - doc_string: Documentation string. - opset_imports: Opsets imported by the graph. - metadata_props: Metadata that will be serialized to the ONNX file. - meta: Metadata store for graph transform passes. - """ - - __slots__ = ( - "_doc_string", - "_initializers", - "_inputs", - "_metadata", - "_metadata_props", - "_name_authority", - "_nodes", - "_opset_imports", - "_outputs", - "name", - ) - - def __init__( - self, - inputs: Sequence[Value], - outputs: Sequence[Value], - *, - nodes: Iterable[Node], - initializers: Sequence[Value] = (), - doc_string: str | None = None, - opset_imports: dict[str, int] | None = None, - name: str | None = None, - metadata_props: dict[str, str] | None = None, - ): - self.name = name - - # Private fields that are not to be accessed by any other classes - self._inputs = _graph_containers.GraphInputs(self, inputs) - self._outputs = _graph_containers.GraphOutputs(self, outputs) - self._initializers = _graph_containers.GraphInitializers(self) - for initializer in initializers: - if isinstance(initializer, str): - raise TypeError( - "Initializer must be a Value, not a string. " - "If you are copying the initializers from another graph, " - "make sure you call graph.initializers.values() because it is a dictionary." - ) - if initializer.name is None: - raise ValueError(f"Initializer must have a name: {initializer}") - self._initializers[initializer.name] = initializer - self._doc_string = doc_string - self._opset_imports = opset_imports or {} - self._metadata: _metadata.MetadataStore | None = None - self._metadata_props: dict[str, str] | None = metadata_props - self._nodes: _linked_list.DoublyLinkedSet[Node] = _linked_list.DoublyLinkedSet() - # Be sure the initialize the name authority before extending the nodes - # because it is used to name the nodes and their outputs - self._name_authority = _name_authority.NameAuthority() - # TODO(justinchuby): Trigger again if inputs or initializers are modified. - self._set_input_and_initializer_value_names_into_name_authority() - # Call self.extend not self._nodes.extend so the graph reference is added to the nodes - self.extend(nodes) - - @property - def inputs(self) -> MutableSequence[Value]: - return self._inputs - - @property - def outputs(self) -> MutableSequence[Value]: - return self._outputs - - @property - def initializers(self) -> MutableMapping[str, Value]: - return self._initializers - - def register_initializer(self, value: Value) -> None: - """Register an initializer to the graph. - - This is a convenience method to register an initializer to the graph with - checks. - - Args: - value: The :class:`Value` to register as an initializer of the graph. - It must have its ``.const_value`` set. - - Raises: - ValueError: If a value of the same name that is not this value - is already registered. - ValueError: If the value does not have a name. - ValueError: If the initializer is produced by a node. - ValueError: If the value does not have its ``.const_value`` set. - """ - if not value.name: - raise ValueError(f"Initializer must have a name: {value!r}") - if value.name in self._initializers: - if self._initializers[value.name] is not value: - raise ValueError( - f"Initializer '{value.name}' is already registered, but" - " it is not the same object: existing={self._initializers[value.name]!r}," - f" new={value!r}" - ) - if value.producer() is not None: - raise ValueError( - f"Value '{value!r}' is produced by a node and cannot be an initializer." - ) - if value.const_value is None: - raise ValueError( - f"Value '{value!r}' must have its const_value set to be an initializer." - ) - self._initializers[value.name] = value - - @property - def doc_string(self) -> str | None: - return self._doc_string - - @doc_string.setter - def doc_string(self, value: str | None) -> None: - self._doc_string = value - - @property - def opset_imports(self) -> dict[str, int]: - return self._opset_imports - - @typing.overload - def __getitem__(self, index: int) -> Node: ... - @typing.overload - def __getitem__(self, index: slice) -> Sequence[Node]: ... - - def __getitem__(self, index): - return self._nodes[index] - - def __len__(self) -> int: - return len(self._nodes) - - def __iter__(self) -> Iterator[Node]: - return iter(self._nodes) - - def __reversed__(self) -> Iterator[Node]: - return reversed(self._nodes) - - def _set_input_and_initializer_value_names_into_name_authority(self): - for value in self.inputs: - self._name_authority.register_or_name_value(value) - for value in self.initializers.values(): - self._name_authority.register_or_name_value(value) - - def _set_node_graph_to_self_and_assign_names(self, node: Node) -> Node: - """Set the graph reference for the node and assign names to it and its outputs if they don't have one.""" - if node.graph is not None and node.graph is not self: - raise ValueError( - f"The node '{node!r}' belongs to another graph. Please remove it first with Graph.remove()." - ) - # Give the node and its output values names if they don't not have one - self._name_authority.register_or_name_node(node) - for value in node._outputs: # pylint: disable=protected-access - self._name_authority.register_or_name_value(value) - node.graph = self - return node - - def node(self, index_or_name: int | str, /) -> Node: - """Get a node by index or name. - - This is an O(n) operation. Getting nodes on the ends of the graph (0 or -1) is O(1). - - .. note:: - If you need repeated random access, consider turning it into a list with ``list(graph)`` . - Or a dictionary for repeated access by name: ``{node.name for node in graph}`` . - - When a name is provided and if there are multiple nodes with the same name, - the first node with the name is returned. - - Args: - index_or_name: The index or name of the node. - - Returns: - The node if found. - - Raises: - IndexError: If the index is out of range. - ValueError: If the node with the given name is not found. - """ - # NOTE: This is a method specific to Graph, not required by the protocol unless proven - if isinstance(index_or_name, int): - return self[index_or_name] - for node in self: - if node.name == index_or_name: - return node - raise ValueError(f"Node with name '{index_or_name}' not found.") - - def num_nodes(self) -> int: - """Get the number of nodes in the graph in O(1) time. - - Note that this method returns the number of nodes this graph directly contains. - It does not count nodes in subgraphs. - - This is an alias for ``len(graph)``. Use this if you prefer a more descriptive - name for readability. - """ - # NOTE: This is a method specific to Graph, not required by the protocol unless proven - return len(self) - - # Mutation methods - def append(self, node: Node, /) -> None: - """Append a node to the graph in O(1) time. - - Unique names will be assigned to the node and its values if any name is ``None``. - - Args: - node: The node to append. - - Raises: - ValueError: If the node belongs to another graph. - """ - self._set_node_graph_to_self_and_assign_names(node) - self._nodes.append(node) - - def extend(self, nodes: Iterable[Node], /) -> None: - """Extend the graph with the given nodes in O(#new_nodes) time. - - Unique names will be assigned to the node and its values if any name is ``None``. - - Args: - nodes: The nodes to extend the graph with. - - Raises: - ValueError: If any node belongs to another graph. - """ - nodes = [self._set_node_graph_to_self_and_assign_names(node) for node in nodes] - self._nodes.extend(nodes) - - def remove(self, nodes: Node | Iterable[Node], /, safe: bool = False) -> None: - """Remove nodes from the graph in O(#num of nodes to remove) time. - - If any errors are raise, to ensure the graph is not left in an inconsistent state, - the graph is not modified. - - Args: - nodes: The node to remove. - safe: If True, performs the following actions before removal: - - 1. It checks to make sure there are no users of the node that are not - to be removed before removing it. - 2. It checks the node does not contribute to any graph outputs. - 3. It removes references to all inputs so it is no longer a user of other nodes. - - Raises: - ValueError: If any node to remove does not belong to this graph. - ValueError: (When ``safe=True``) If the node does not belong to this graph or if there are users of the node. - ValueError: (When ``safe=True``) If the node is still being used by other nodes not to be removed. - """ - if not isinstance(nodes, Iterable): - nodes_set: AbstractSet[Node] = {nodes} - else: - nodes_set = frozenset(nodes) - graph_outputs = frozenset(self.outputs) - for node in nodes_set: - if node.graph is not self: - raise ValueError(f"The node '{node!r}' does not belong to this graph.") - if safe: - # Check 1, 2 - _check_node_safe_to_remove(node, nodes_set, graph_outputs) - for node in nodes_set: - if safe: - # 3. Detach from all inputs so that it is no longer a user of other nodes - for i in range(len(node.inputs)): - node.replace_input_with(i, None) - # Set attributes to remove the node from this graph - node.graph = None - self._nodes.remove(node) - - def insert_after(self, node: Node, new_nodes: Iterable[Node] | Node, /) -> None: - """Insert new nodes after the given node in O(#new_nodes) time. - - Unique names will be assigned to the node and its values if any name is ``None``. - - Args: - node: The node to insert after. - new_nodes: The new nodes to insert. - - Raises: - ValueError: If any node belongs to another graph. - """ - if isinstance(new_nodes, Node): - new_nodes = (new_nodes,) - new_nodes = [self._set_node_graph_to_self_and_assign_names(node) for node in new_nodes] - self._nodes.insert_after(node, new_nodes) - - def insert_before(self, node: Node, new_nodes: Iterable[Node] | Node, /) -> None: - """Insert new nodes before the given node in O(#new_nodes) time. - - Unique names will be assigned to the node and its values if any name is ``None``. - - Args: - node: The node to insert before. - new_nodes: The new nodes to insert. - - Raises: - ValueError: If any node belongs to another graph. - """ - if isinstance(new_nodes, Node): - new_nodes = (new_nodes,) - new_nodes = [self._set_node_graph_to_self_and_assign_names(node) for node in new_nodes] - self._nodes.insert_before(node, new_nodes) - - def sort(self) -> None: - """Perform a topological sort of this graph and all subgraphs in O(#nodes + #values) time. - - This sort is stable. It preserves the original order as much as possible. - - Reference: https://github.com/madelson/MedallionTopologicalSort#stable-sort - - Raises: - ValueError: If the graph contains a cycle, making topological sorting impossible. - """ - # Obtain all nodes from the graph and its subgraphs for sorting - nodes = list(onnxscript.ir.traversal.RecursiveGraphIterator(self)) - # Store the sorted nodes of each subgraph - sorted_nodes_by_graph: dict[Graph, list[Node]] = { - graph: [] for graph in {node.graph for node in nodes if node.graph is not None} - } - # TODO(justinchuby): Explain why we need to store direct predecessors and children and why - # we only need to store the direct ones - - # The depth of a node is defined as the number of direct children it has - node_depth: dict[Node, int] = dict.fromkeys(nodes, 0) - # Direct predecessors of a node - node_predecessors: dict[Node, list[Node]] = {node: [] for node in nodes} - # Store the negative index of the nodes because heapq is a min heap and we - # want to pop the node with largest index value first, effectively turning - # it to a max heap - neg_node_index: dict[Node, int] = {node: -i for i, node in enumerate(nodes)} - - def add_predecessor(child: Node, predecessor: Node | None) -> None: - """Add a predecessor of a node, and increment the depth of the predecessor.""" - if predecessor is None: - return - node_predecessors[child].append(predecessor) - node_depth[predecessor] += 1 - - # 1. Build the direct predecessors of each node and the depth of each node - # for sorting topologically using Kahn's algorithm. - # Note that when a node contains graph attributes (aka. has subgraphs), - # we consider all nodes in the subgraphs *predecessors* of this node. This - # way we ensure the implicit dependencies of the subgraphs are captured - # as predecessors of the node. - for node in nodes: - # All producers of input values are considered as direct predecessors. - for input_value in node.inputs: - if input_value is None: - continue - predecessor_node = input_value.producer() - add_predecessor(node, predecessor_node) - # All nodes in attribute graphs are considered as direct predecessors. - for attr in node.attributes.values(): - if not isinstance(attr, Attr): - continue - # A nice thing about this algorithm is that we only need to record - # direct predecessors. This continues to be true even with subgraphs. - # When a node in a subgraph (a) contains its own subgraphs (b), the - # node in subgraphs (b) are guranteed to appear before the node - # in (a). - if attr.type == _enums.AttributeType.GRAPH: - for predecessor_node in attr.value: - add_predecessor(node, predecessor_node) - elif attr.type == _enums.AttributeType.GRAPHS: - for attribute_graph in attr.value: - for predecessor_node in attribute_graph: - add_predecessor(node, predecessor_node) - - # 2. Priority Queue: Track nodes with zero direct children in a priority queue, - # using NEGATIVE original index for ordering. - # This ensures nodes appearing LATER in the original order are processed EARLIER. - # We get REVERSED topological order of each subgraph. - priority_queue: list[tuple[int, Node]] = [ - (neg_node_index[node], node) for node in nodes if node_depth[node] == 0 - ] - heapq.heapify(priority_queue) - - # 3. Topological Sort: - num_of_sorted_nodes = 0 - while priority_queue: - # Pop the node with the most negative index and add it to the sorted nodes by subgraph. - _, current_node = heapq.heappop(priority_queue) - assert current_node.graph is not None - sorted_nodes_by_graph[current_node.graph].append(current_node) - num_of_sorted_nodes += 1 - # Decrement the depth of its predecessors. If any predecessor node has zero direct children, push it into the queue. - for predecessor_node in node_predecessors[current_node]: - node_depth[predecessor_node] -= 1 - if node_depth[predecessor_node] == 0: - heapq.heappush( - priority_queue, (neg_node_index[predecessor_node], predecessor_node) - ) - - # 4. Cycle Check: Ensure all nodes are processed. If not, raise a ValueError indicating a cycle. - if num_of_sorted_nodes != len(nodes): - raise ValueError("Graph contains a cycle, topological sort is not possible.") - - # 5. Reverse: Reverse the sorted nodes of each subgraph to get the topological order. - for graph, sorted_nodes in sorted_nodes_by_graph.items(): - # The graph container ensures all the nodes are unique so we can safely extend - graph.extend(reversed(sorted_nodes)) - - # End of mutation methods - - @property - def meta(self) -> _metadata.MetadataStore: - """The metadata store for intermediate analysis. - - Write to the :attr:`metadata_props` if you would like the metadata to be serialized - to the ONNX proto. - """ - if self._metadata is None: - self._metadata = _metadata.MetadataStore() - return self._metadata - - @property - def metadata_props(self) -> dict[str, str]: - if self._metadata_props is None: - self._metadata_props = {} - return self._metadata_props - - def __str__(self) -> str: - return _graph_str(self) - - def __repr__(self) -> str: - return _graph_repr(self) - - -def _graph_str(graph: Graph | GraphView) -> str: - """Return a string representation of the graph.""" - # TODO(justinchuby): Show docstrings and metadata - inputs_text = "\n" + ",\n".join(str(x) for x in graph.inputs) - outputs_text = "\n" + ",\n".join(str(x) for x in graph.outputs) - initializers_text = ",\n".join(str(x) for x in graph.initializers.values()) - if initializers_text: - initializers_text = ( - "\ninitializers=(\n" + textwrap.indent(initializers_text, " " * 4) + "\n)," - ) - signature = f"""\ -graph( - name={graph.name or "anonymous_graph:" + str(id(graph))}, - inputs=({textwrap.indent(inputs_text, " " * 8)} - ), - outputs=({textwrap.indent(outputs_text, " " * 8)} - ),{textwrap.indent(initializers_text, " " * 4)} -)""" - node_count = len(graph) - number_width = len(str(node_count)) - node_lines = [] - for i, node in enumerate(graph): - node_name = node.name if node.name else f":anonymous_node:{id(node)}" - node_text = f"# {node_name}\n{node}" - indented_node_text = textwrap.indent(node_text, " " * (number_width + 4)) - # Remove the leading spaces - indented_node_text = indented_node_text.strip() - node_lines.append(f"{i:>{number_width}} | {indented_node_text}") - returns = ", ".join(str(x) for x in graph.outputs) - body = ( - "{\n" - + textwrap.indent("\n".join(node_lines), " " * 4) - + textwrap.indent(f"\nreturn {returns}", " " * 4) - + "\n}" - ) - - return f"{signature} {body}" - - -def _graph_repr(graph: Graph | GraphView) -> str: - """Return an repr string of the graph.""" - inputs_text = "\n" + ",\n".join(str(x) for x in graph.inputs) - outputs_text = "\n" + ",\n".join(str(x) for x in graph.outputs) - initializers_text = ",\n".join(str(x) for x in graph.initializers.values()) - if initializers_text: - initializers_text = ( - "\ninitializers=(\n" + textwrap.indent(initializers_text, " " * 4) + "\n)," - ) - return f"""\ -{graph.__class__.__name__}( - name={graph.name or "anonymous_graph:" + str(id(graph))!r}, - inputs=({textwrap.indent(inputs_text, " " * 8)} - ), - outputs=({textwrap.indent(outputs_text, " " * 8)} - ),{textwrap.indent(initializers_text, " " * 4)} - len()={len(graph)} -)""" - - -class GraphView(Sequence[Node], _display.PrettyPrintable): - """A read-only view on a graph. - - The GraphView is useful for analysis of a subgraph. It can be initialized - with a subset of nodes from a :class:`Graph`. Creating GraphView does not - change the ownership of the nodes, and so it is possible to create multiple - GraphViews that contain the same nodes. If the underlying nodes / connections - are mutated, the mutation will be reflected in all views as well. - - The graph view can be serialized to ONNX:: - - graph_proto = ir.serde.serialize_graph(graph_view) - - It can also be used to create a model:: - - model = ir.Model(graph_view, ir_version=8) - model_proto = ir.serde.serialize_model(model) - - The model created with a GraphView will have a fixed topology, and its graph - will remain read-only as a GraphView. No copying will be done during the - initialization process. - - Attributes: - name: The name of the graph. - inputs: The input values of the graph. - outputs: The output values of the graph. - initializers: The initializers in the graph. - doc_string: Documentation string. - opset_imports: Opsets imported by the graph. - metadata_props: Metadata that will be serialized to the ONNX file. - meta: Metadata store for graph transform passes. - """ - - __slots__ = ( - "_metadata", - "_metadata_props", - "doc_string", - "initializers", - "inputs", - "name", - "nodes", - "opset_imports", - "outputs", - ) - - def __init__( - self, - inputs: Sequence[Value], - outputs: Sequence[Value], - *, - nodes: Iterable[Node], - initializers: Sequence[_protocols.ValueProtocol] = (), - doc_string: str | None = None, - opset_imports: dict[str, int] | None = None, - name: str | None = None, - metadata_props: dict[str, str] | None = None, - ): - self.name = name - self.inputs = tuple(inputs) - self.outputs = tuple(outputs) - for initializer in initializers: - if initializer.name is None: - raise ValueError(f"Initializer must have a name: {initializer}") - self.initializers = {tensor.name: tensor for tensor in initializers} - self.doc_string = doc_string - self.opset_imports = opset_imports or {} - self._metadata: _metadata.MetadataStore | None = None - self._metadata_props: dict[str, str] | None = metadata_props - self._nodes: tuple[Node, ...] = tuple(nodes) - - @typing.overload - def __getitem__(self, index: int) -> Node: ... - @typing.overload - def __getitem__(self, index: slice) -> Sequence[Node]: ... - - def __getitem__(self, index): - return self._nodes[index] - - def __len__(self) -> int: - return len(self._nodes) - - def __iter__(self) -> Iterator[Node]: - return iter(self._nodes) - - def __reversed__(self) -> Iterator[Node]: - return reversed(self._nodes) - - @property - def meta(self) -> _metadata.MetadataStore: - """The metadata store for intermediate analysis. - - Write to the :attr:`metadata_props` if you would like the metadata to be serialized - to the ONNX proto. - """ - if self._metadata is None: - self._metadata = _metadata.MetadataStore() - return self._metadata - - @property - def metadata_props(self) -> dict[str, str]: - if self._metadata_props is None: - self._metadata_props = {} - return self._metadata_props - - def __str__(self) -> str: - return _graph_str(self) - - def __repr__(self) -> str: - return _graph_repr(self) - - -class Model(_protocols.ModelProtocol, _display.PrettyPrintable): - __slots__ = ( - "_functions", - "_metadata", - "_metadata_props", - "doc_string", - "domain", - "graph", - "ir_version", - "model_version", - "producer_name", - "producer_version", - ) - """IR Model. - - A model is a container for a graph and metadata. - - Attributes: - graph: The graph of the model. - ir_version: The version of the IR. - producer_name: The name of the producer. - producer_version: The version of the producer. - domain: The domain of the model. - model_version: The version of the model. - doc_string: Documentation string. - functions: The functions defined in the model. - metadata_props: Metadata. - """ - - def __init__( - self, - graph: Graph, - *, - ir_version: int, - producer_name: str | None = None, - producer_version: str | None = None, - domain: str | None = None, - model_version: int | None = None, - doc_string: str | None = None, - functions: Sequence[Function] = (), - meta_data_props: dict[str, str] | None = None, - ) -> None: - self.graph: Graph = graph - self.ir_version = ir_version - self.producer_name = producer_name - self.producer_version = producer_version - self.domain = domain - self.model_version = model_version - self.doc_string = doc_string - self._functions = {func.identifier(): func for func in functions} - self._metadata: _metadata.MetadataStore | None = None - self._metadata_props: dict[str, str] | None = meta_data_props - - @property - def functions(self) -> dict[_protocols.OperatorIdentifier, Function]: - return self._functions - - @property - def opset_imports(self) -> dict[str, int]: - return self.graph.opset_imports - - @property - def meta(self) -> _metadata.MetadataStore: - """The metadata store for intermediate analysis. - - Write to the :attr:`metadata_props` if you would like the metadata to be serialized - to the ONNX proto. - """ - if self._metadata is None: - self._metadata = _metadata.MetadataStore() - return self._metadata - - @property - def metadata_props(self) -> dict[str, str]: - if self._metadata_props is None: - self._metadata_props = {} - return self._metadata_props - - def __str__(self) -> str: - # TODO(justinchuby): Show docstrings and metadata - signature = f"""\ -< - ir_version={self.ir_version!r}, - opset_imports={self.opset_imports!r}, - producer_name={self.producer_name!r}, - producer_version={self.producer_version!r}, - domain={self.domain!r}, - model_version={self.model_version!r}, ->""" - graph_text = str(self.graph) - functions_text = "\n\n".join(str(func) for func in self.functions.values()) - return f"{signature}\n{graph_text}" + f"\n\n{functions_text}" - - def __repr__(self) -> str: - return f"""\ -Model( - ir_version={self.ir_version!r}, - opset_imports={self.opset_imports!r}, - producer_name={self.producer_name!r}, - producer_version={self.producer_version!r}, - domain={self.domain!r}, - model_version={self.model_version!r}, - functions={self.functions!r}, - graph={textwrap.indent(repr(self.graph), " " * 4).strip()} -)""" - - def graphs(self) -> Iterable[Graph]: - """Get all graphs and subgraphs in the model. - - This is a convenience method to traverse the model. Consider using - `onnxscript.ir.traversal.RecursiveGraphIterator` for more advanced - traversals on nodes. - """ - # NOTE(justinchuby): Given - # (1) how useful the method is - # (2) I couldn't find an appropriate name for it in `traversal.py` - # (3) Users familiar with onnxruntime optimization tools expect this method - # I created this method as a core method instead of an iterator in - # `traversal.py`. - seen_graphs: set[Graph] = set() - for node in onnxscript.ir.traversal.RecursiveGraphIterator(self.graph): - if node.graph is not None and node.graph not in seen_graphs: - seen_graphs.add(node.graph) - yield node.graph - - -class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrintable): - """IR functions. - - Like a graph, a function can have nodes that are not topologically sorted. It is - the responsibility of the user to maintain a topological order of the nodes. - - Note that there is not a ``node`` attribute in the Function. The Function can be - seen as a Sequence of nodes and should be used as such. For example, to obtain - all nodes as a list, call ``list(function)``. - - Attributes: - name: The function name. - domain: The domain this function is defined in. - overload: The overload name when the function is overloaded. - inputs: The input values of the function. - attributes: The attributes this function defines. - outputs: The output values of the function. - opset_imports: Opsets imported by the function. - doc_string: Documentation string. - meta: Metadata store for graph transform passes. - metadata_props: Metadata that will be serialized to the ONNX file. - """ - - __slots__ = ( - "_attributes", - "_domain", - "_graph", - "_name", - "_overload", - ) - - def __init__( - self, - domain: str, - name: str, - overload: str = "", - *, - # Ensure the inputs and outputs of the function belong to a graph - # and not from an outer scope - graph: Graph, - attributes: Sequence[Attr], - ) -> None: - self._domain = domain - self._name = name - self._overload = overload - self._graph = graph - self._attributes = OrderedDict((attr.name, attr) for attr in attributes) - - def identifier(self) -> _protocols.OperatorIdentifier: - return self.domain, self.name, self.overload - - @property - def name(self) -> str: - return self._name - - @name.setter - def name(self, value: str) -> None: - self._name = value - - @property - def domain(self) -> str: - return self._domain - - @domain.setter - def domain(self, value: str) -> None: - self._domain = _normalize_domain(value) - - @property - def overload(self) -> str: - return self._overload - - @overload.setter - def overload(self, value: str) -> None: - self._overload = value - - @property - def inputs(self) -> MutableSequence[Value]: - return self._graph.inputs - - @property - def outputs(self) -> MutableSequence[Value]: - return self._graph.outputs - - @property - def attributes(self) -> OrderedDict[str, Attr]: - return self._attributes - - @typing.overload - def __getitem__(self, index: int) -> Node: ... - @typing.overload - def __getitem__(self, index: slice) -> Sequence[Node]: ... - - def __getitem__(self, index): - return self._graph.__getitem__(index) - - def __len__(self) -> int: - return self._graph.__len__() - - def __iter__(self) -> Iterator[Node]: - return self._graph.__iter__() - - def __reversed__(self) -> Iterator[Node]: - return self._graph.__reversed__() - - @property - def doc_string(self) -> str | None: - return self._graph.doc_string - - @doc_string.setter - def doc_string(self, value: str | None) -> None: - self._graph.doc_string = value - - @property - def opset_imports(self) -> dict[str, int]: - return self._graph.opset_imports - - @property - def meta(self) -> _metadata.MetadataStore: - """The metadata store for intermediate analysis. - - Write to the :attr:`metadata_props` if you would like the metadata to be serialized - to the ONNX proto. - """ - return self._graph.meta - - @property - def metadata_props(self) -> dict[str, str]: - return self._graph.metadata_props - - # Mutation methods - def append(self, node: Node, /) -> None: - """Append a node to the function in O(1) time.""" - self._graph.append(node) - - def extend(self, nodes: Iterable[Node], /) -> None: - """Extend the function with the given nodes in O(#new_nodes) time.""" - self._graph.extend(nodes) - - def remove(self, nodes: Node | Iterable[Node], /, safe: bool = False) -> None: - """Remove nodes from the graph in O(#num of nodes) time. - - If any errors are raise, to ensure the graph is not left in an inconsistent state, - the graph is not modified. - - Args: - nodes: The node to remove. - safe: If True, performs the following actions before removal: - - 1. It checks to make sure there are no users of the node that are not - to be removed before removing it. - 2. It checks the node does not contribute to any graph outputs. - 3. It removes references to all inputs so it is no longer a user of other nodes. - - Raises: - ValueError: If any node to remove does not belong to this graph. - ValueError: (When ``safe=True``) If the node does not belong to this graph or if there are users of the node. - ValueError: (When ``safe=True``) If the node is still being used by other nodes not to be removed. - """ - self._graph.remove(nodes, safe=safe) - - def insert_after(self, node: Node, new_nodes: Iterable[Node] | Node, /) -> None: - """Insert new nodes after the given node in O(#new_nodes) time.""" - self._graph.insert_after(node, new_nodes) - - def insert_before(self, node: Node, new_nodes: Iterable[Node] | Node, /) -> None: - """Insert new nodes before the given node in O(#new_nodes) time.""" - self._graph.insert_before(node, new_nodes) - - def sort(self) -> None: - """Perform a topological sort of this graph and all subgraphs in O(#nodes + #values) time.""" - self._graph.sort() - - # End of mutation methods - - def __str__(self) -> str: - full_name = f"{self.domain}::{self.name}" + f":{self.overload}" * (self.overload != "") - inputs_text = ",\n".join(str(x) for x in self.inputs) - outputs_text = ",\n".join(str(x) for x in self.outputs) - attributes_text = ",\n".join( - f"{attr.name}: {attr.type}" + f" = {attr.value}" * (attr.value is not None) - for attr in self.attributes.values() - ) - if attributes_text: - attributes_text = ( - "\nattributes={\n" + textwrap.indent(attributes_text, " " * 4) + "\n}" - ) - signature = f"""\ -< - opset_imports={self.opset_imports!r}, -> -def {full_name}( - inputs=( -{textwrap.indent(inputs_text, " " * 8)} - ),{textwrap.indent(attributes_text, " " * 4)} - outputs=( -{textwrap.indent(outputs_text, " " * 8)} - ), -)""" - node_count = len(self) - number_width = len(str(node_count)) - node_lines = [] - for i, node in enumerate(self): - node_name = node.name if node.name else f":anonymous_node:{id(node)}" - node_text = f"# {node_name}\n{node}" - indented_node_text = textwrap.indent(node_text, " " * (number_width + 4)) - # Remove the leading spaces - indented_node_text = indented_node_text.strip() - node_lines.append(f"{i:>{number_width}} | {indented_node_text}") - returns = ", ".join(str(x) for x in self.outputs) - body = ( - "{\n" - + textwrap.indent("\n".join(node_lines), " " * 4) - + textwrap.indent(f"\nreturn {returns}", " " * 4) - + "\n}" - ) - - return f"{signature} {body}" - - def __repr__(self) -> str: - return f"{self.__class__.__name__}({self.domain!r}, {self.name!r}, {self.overload!r}, inputs={self.inputs!r}, attributes={self.attributes!r}), outputs={self.outputs!r})" - - -class RefAttr(_protocols.ReferenceAttributeProtocol, _display.PrettyPrintable): - """Reference attribute.""" - - __slots__ = ("_name", "_ref_attr_name", "_type", "doc_string") - - def __init__( - self, - name: str, - ref_attr_name: str, - type: _enums.AttributeType, - *, - doc_string: str | None = None, - ) -> None: - self._name = name - self._ref_attr_name = ref_attr_name - self._type = type - self.doc_string = doc_string - - @property - def name(self) -> str: - return self._name - - @name.setter - def name(self, value: str) -> None: - self._name = value - - @property - def ref_attr_name(self) -> str: - return self._ref_attr_name - - @ref_attr_name.setter - def ref_attr_name(self, value: str) -> None: - self._ref_attr_name = value - - @property - def type(self) -> _enums.AttributeType: - return self._type - - @type.setter - def type(self, value: _enums.AttributeType) -> None: - self._type = value - - def __repr__(self) -> str: - return f"{self.__class__.__name__}({self._name!r}, {self._type!r}, ref_attr_name={self.ref_attr_name!r})" - - -class Attr(_protocols.AttributeProtocol, _display.PrettyPrintable): - """Base class for ONNX attributes.""" - - __slots__ = ("doc_string", "name", "type", "value") - - def __init__( - self, - name: str, - type: _enums.AttributeType, - value: Any, - *, - doc_string: str | None = None, - ): - self.name = name - self.type = type - self.value = value - self.doc_string = doc_string - - def __eq__(self, other: object) -> bool: - if not isinstance(other, _protocols.AttributeProtocol): - return False - - if self.name != other.name: - return False - if self.type != other.type: - return False - if self.value != other.value: - return False - if self.doc_string != other.doc_string: - return False - return True - - def __str__(self) -> str: - if self.type == _enums.AttributeType.GRAPH: - return textwrap.indent("\n" + str(self.value), " " * 4) - return str(self.value) - - def __repr__(self) -> str: - return f"{self.__class__.__name__}({self.name!r}, {self.type!r}, {self.value!r})" - - # Well typed getters - def as_float(self) -> float: - """Get the attribute value as a float.""" - # Do not use isinstance check because it may prevent np.float32 etc. from being used - return float(self.value) - - def as_int(self) -> int: - """Get the attribute value as an int.""" - # Do not use isinstance check because it may prevent np.int32 etc. from being used - return int(self.value) - - def as_string(self) -> str: - """Get the attribute value as a string.""" - if not isinstance(self.value, str): - raise TypeError(f"Value of attribute '{self!r}' is not a string.") - return self.value - - def as_tensor(self) -> _protocols.TensorProtocol: - """Get the attribute value as a tensor.""" - if not isinstance(self.value, _protocols.TensorProtocol): - raise TypeError(f"Value of attribute '{self!r}' is not a tensor.") - return self.value - - def as_graph(self) -> Graph: - """Get the attribute value as a graph.""" - if not isinstance(self.value, Graph): - raise TypeError(f"Value of attribute '{self!r}' is not a graph.") - return self.value - - def as_floats(self) -> Sequence[float]: - """Get the attribute value as a sequence of floats.""" - if not isinstance(self.value, Sequence): - raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.") - # Do not use isinstance check on elements because it may prevent np.int32 etc. from being used - # Create a copy of the list to prevent mutation - return [float(v) for v in self.value] - - def as_ints(self) -> Sequence[int]: - """Get the attribute value as a sequence of ints.""" - if not isinstance(self.value, Sequence): - raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.") - # Do not use isinstance check on elements because it may prevent np.int32 etc. from being used - # Create a copy of the list to prevent mutation - return list(self.value) - - def as_strings(self) -> Sequence[str]: - """Get the attribute value as a sequence of strings.""" - if not isinstance(self.value, Sequence): - raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.") - if onnxscript.DEBUG: - if not all(isinstance(x, str) for x in self.value): - raise TypeError(f"Value of attribute '{self!r}' is not a Sequence of strings.") - # Create a copy of the list to prevent mutation - return list(self.value) - - def as_tensors(self) -> Sequence[_protocols.TensorProtocol]: - """Get the attribute value as a sequence of tensors.""" - if not isinstance(self.value, Sequence): - raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.") - if onnxscript.DEBUG: - if not all(isinstance(x, _protocols.TensorProtocol) for x in self.value): - raise TypeError(f"Value of attribute '{self!r}' is not a Sequence of tensors.") - # Create a copy of the list to prevent mutation - return list(self.value) - - def as_graphs(self) -> Sequence[Graph]: - """Get the attribute value as a sequence of graphs.""" - if not isinstance(self.value, Sequence): - raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.") - if onnxscript.DEBUG: - if not all(isinstance(x, Graph) for x in self.value): - raise TypeError(f"Value of attribute '{self!r}' is not a Sequence of graphs.") - # Create a copy of the list to prevent mutation - return list(self.value) - - -# NOTE: The following functions are just for convenience -def AttrFloat32(name: str, value: float, doc_string: str | None = None) -> Attr: - """Create a float attribute.""" - # NOTE: The function name is capitalized to maintain API backward compatibility. - return Attr( - name, - _enums.AttributeType.FLOAT, - value, - doc_string=doc_string, - ) - - -def AttrInt64(name: str, value: int, doc_string: str | None = None) -> Attr: - """Create an int attribute.""" - # NOTE: The function name is capitalized to maintain API backward compatibility. - return Attr( - name, - _enums.AttributeType.INT, - value, - doc_string=doc_string, - ) - - -def AttrString(name: str, value: str, doc_string: str | None = None) -> Attr: - """Create a str attribute.""" - # NOTE: The function name is capitalized to maintain API backward compatibility. - return Attr( - name, - _enums.AttributeType.STRING, - value, - doc_string=doc_string, - ) - - -def AttrTensor( - name: str, value: _protocols.TensorProtocol, doc_string: str | None = None -) -> Attr: - """Create a tensor attribute.""" - # NOTE: The function name is capitalized to maintain API backward compatibility. - return Attr( - name, - _enums.AttributeType.TENSOR, - value, - doc_string=doc_string, - ) - - -def AttrGraph(name: str, value: Graph, doc_string: str | None = None) -> Attr: - """Create a graph attribute.""" - # NOTE: The function name is capitalized to maintain API backward compatibility. - return Attr( - name, - _enums.AttributeType.GRAPH, - value, - doc_string=doc_string, - ) - - -def AttrFloat32s(name: str, value: Sequence[float], doc_string: str | None = None) -> Attr: - """Create a float sequence attribute.""" - # NOTE: The function name is capitalized to maintain API backward compatibility. - return Attr( - name, - _enums.AttributeType.FLOATS, - value, - doc_string=doc_string, - ) - - -def AttrInt64s(name: str, value: Sequence[int], doc_string: str | None = None) -> Attr: - """Create an int sequence attribute.""" - # NOTE: The function name is capitalized to maintain API backward compatibility. - return Attr( - name, - _enums.AttributeType.INTS, - value, - doc_string=doc_string, - ) - - -def AttrStrings(name: str, value: Sequence[str], doc_string: str | None = None) -> Attr: - """Create a string sequence attribute.""" - # NOTE: The function name is capitalized to maintain API backward compatibility. - return Attr( - name, - _enums.AttributeType.STRINGS, - value, - doc_string=doc_string, - ) - - -def AttrTensors( - name: str, value: Sequence[_protocols.TensorProtocol], doc_string: str | None = None -) -> Attr: - """Create a tensor sequence attribute.""" - # NOTE: The function name is capitalized to maintain API backward compatibility. - return Attr( - name, - _enums.AttributeType.TENSORS, - value, - doc_string=doc_string, - ) - - -def AttrGraphs(name: str, value: Sequence[Graph], doc_string: str | None = None) -> Attr: - """Create a graph sequence attribute.""" - # NOTE: The function name is capitalized to maintain API backward compatibility. - return Attr( - name, - _enums.AttributeType.GRAPHS, - value, - doc_string=doc_string, - ) - - -# NOTE: SparseTensor should be a sparse tensor proto -def AttrSparseTensor( - name: str, value: _protocols.SparseTensorProtocol, doc_string: str | None = None -) -> Attr: - """Create a sparse tensor attribute.""" - # NOTE: The function name is capitalized to maintain API backward compatibility. - return Attr( - name, - _enums.AttributeType.SPARSE_TENSOR, - value, - doc_string=doc_string, - ) - - -def AttrSparseTensors( - name: str, value: Sequence[_protocols.SparseTensorProtocol], doc_string: str | None = None -) -> Attr: - """Create a sparse tensor sequence attribute.""" - # NOTE: The function name is capitalized to maintain API backward compatibility. - return Attr( - name, - _enums.AttributeType.SPARSE_TENSORS, - value, - doc_string=doc_string, - ) - - -@dataclasses.dataclass -class TypeAndShape: - """Type and shape. - - Useful for constructing a type proto. - """ - - type: _protocols.TypeProtocol | None - shape: Shape | None - - -def AttrTypeProto(name: str, value: TypeAndShape, doc_string: str | None = None) -> Attr: - """Create a type attribute.""" - # NOTE: The function name is capitalized to maintain API backward compatibility. - return Attr( - name, - _enums.AttributeType.TYPE_PROTO, - value, - doc_string=doc_string, - ) - - -def AttrTypeProtos( - name: str, value: Sequence[TypeAndShape], doc_string: str | None = None -) -> Attr: - """Create a type sequence attribute.""" - # NOTE: The function name is capitalized to maintain API backward compatibility. - return Attr( - name, - _enums.AttributeType.TYPE_PROTOS, - value, - doc_string=doc_string, - ) diff --git a/onnxscript/ir/_core_test.py b/onnxscript/ir/_core_test.py deleted file mode 100644 index 2af10646de..0000000000 --- a/onnxscript/ir/_core_test.py +++ /dev/null @@ -1,1732 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -from __future__ import annotations - -import copy -import pathlib -import tempfile -import unittest -from typing import Any - -import ml_dtypes -import numpy as np -import onnx -import onnx.external_data_helper -import parameterized -import torch - -from onnxscript import ir -from onnxscript.ir import _core - - -class TensorTest(unittest.TestCase): - def test_initialize(self): - tensor = _core.Tensor( - np.random.rand(1, 2).astype(np.float32), - dtype=ir.DataType.FLOAT, - shape=_core.Shape((1, 2)), - name="test", - ) - self.assertEqual(tensor.name, "test") - self.assertEqual(tensor.dtype, ir.DataType.FLOAT) - self.assertEqual(tensor.shape, _core.Shape((1, 2))) - np.testing.assert_array_equal(tensor, tensor) - - def test_init_raises_when_value_is_not_array(self): - with self.assertRaises(TypeError): - _core.Tensor(42) - - def test_init_requires_type_when_value_is_not_np_array(self): - torch_tensor = torch.tensor(42) - with self.assertRaises(ValueError): - _core.Tensor(torch_tensor) - - @parameterized.parameterized.expand( - [ - ("bfloat16", np.uint16, ir.DataType.BFLOAT16), - ( - "float8e4m3fn", - np.dtype((np.uint8, {"e4m3fn": (np.uint8, 0)})), - ir.DataType.FLOAT8E4M3FN, - ), - ("float8e4m3fnuz", np.uint8, ir.DataType.FLOAT8E4M3FNUZ), - ("float8e5m2", np.uint8, ir.DataType.FLOAT8E5M2), - ("float8e5m2fnuz", np.uint8, ir.DataType.FLOAT8E5M2FNUZ), - ("int4", np.int8, ir.DataType.INT4), - ("int4_uint8", np.uint8, ir.DataType.INT4), - ("uint4", np.uint8, ir.DataType.UINT4), - ("float4e2m1", np.uint8, ir.DataType.FLOAT4E2M1), - ] - ) - def test_init_with_non_native_numpy_dtype(self, _: str, np_dtype, dtype: ir.DataType): - array = np.array([0b1, 0b11], dtype=np_dtype) - tensor = _core.Tensor(array, dtype=dtype) - self.assertEqual(tensor.dtype, dtype) - np.testing.assert_array_equal(tensor, array.view(dtype.numpy())) - - def test_initialize_with_just_np_array(self): - array = np.random.rand(1, 2) - tensor = _core.Tensor(array) - np.testing.assert_array_equal(tensor, array) - - def test_initialize_raises_when_numpy_dtype_doesnt_match(self): - array = np.random.rand(1, 2).astype(np.float32) - with self.assertRaises(TypeError): - _core.Tensor(array, dtype=ir.DataType.INT64) - - def test_initialize_supports_custom_dtype(self): - custom_dtype = np.dtype((np.uint8, {"e4m3fn": (np.uint8, 0)})) - array = np.random.rand(1, 2).astype(custom_dtype) - _core.Tensor(array, dtype=ir.DataType.FLOAT8E4M3FN) - - def test_initialize_raises_when_numpy_dtype_doesnt_match_custom_dtype(self): - custom_dtype = np.dtype((np.uint8, {"e4m3fn": (np.uint8, 0)})) - array = np.random.rand(1, 2).astype(custom_dtype) - with self.assertRaises(TypeError): - _core.Tensor(array, dtype=ir.DataType.BFLOAT16) - - def test_initialize_with_torch_tensor(self): - array = np.random.rand(1, 2).astype(np.int64) - np_tensor = _core.Tensor(array) - torch_tensor = _core.Tensor(torch.tensor(array), dtype=ir.DataType.INT64) - np.testing.assert_array_equal(torch_tensor, array) - np.testing.assert_array_equal(torch_tensor, np_tensor) - - def test_dlpack_np_to_torch(self): - array = np.random.rand(1, 2).astype(np.float32) - tensor = _core.Tensor(array) - torch_tensor = torch.from_dlpack(tensor) - np.testing.assert_array_equal(torch_tensor, array) - - def test_dlpack_torch_to_np(self): - torch_tensor = torch.rand(1, 2) - tensor = _core.Tensor(torch_tensor, dtype=ir.DataType.FLOAT) - array = np.from_dlpack(tensor) - np.testing.assert_array_equal(array, torch_tensor) - - def test_repr(self): - tensor = _core.Tensor(np.random.rand(1, 2).astype(np.float32)) - self.assertIsInstance(repr(tensor), str) - - def test_dtype_returns_data_type_enum(self): - tensor = _core.Tensor(np.random.rand(1, 2).astype(np.float32)) - self.assertEqual(tensor.dtype, ir.DataType.FLOAT) - - def test_shape(self): - tensor = _core.Tensor(np.random.rand(1, 2).astype(np.float32)) - self.assertEqual(tensor.shape, _core.Shape((1, 2))) - - def test_numpy_returns_np_array(self): - array = np.random.rand(1, 2).astype(np.float32) - tensor = _core.Tensor(array) - np.testing.assert_equal(tensor.numpy(), array) - - def test_numpy_returns_data_when_dtype_is_not_supported(self): - array = np.array([1], dtype=np.uint8) - tensor = _core.Tensor(array, dtype=ir.DataType.INT4) - np.testing.assert_equal(tensor.numpy(), array) - - def test_tobytes(self): - array = np.random.rand(1, 2).astype(np.float32) - torch_tensor = torch.tensor(array) - tensor = _core.Tensor(torch_tensor, dtype=ir.DataType.FLOAT) - self.assertEqual(tensor.tobytes(), array.tobytes()) - - def test_tobytes_returns_packed_data_for_int4(self): - array = np.array([-8, -1, 0, 1, 2, 7, 1], dtype=np.int8) - # Test odd sized array - assert len(array) % 2 == 1 - tensor = _core.Tensor(array, dtype=ir.DataType.INT4) - self.assertEqual(tensor.tobytes(), b"\xf8\x10r\x01") - - def test_tobytes_returns_packed_data_for_int4_ml_dtypes(self): - array = np.array([-8, -1, 0, 1, 2, 7, 1], dtype=ml_dtypes.int4) - # Test odd sized array - assert len(array) % 2 == 1 - tensor = _core.Tensor(array, dtype=ir.DataType.INT4) - self.assertEqual(tensor.tobytes(), b"\xf8\x10r\x01") - - def test_tobytes_returns_packed_data_for_uint4(self): - array = np.array([0, 1, 2, 7, 15], dtype=np.uint8) - # Test odd sized array - assert len(array) % 2 == 1 - tensor = _core.Tensor(array, dtype=ir.DataType.UINT4) - self.assertEqual(tensor.tobytes(), b"\x10r\x0f") - - def test_tobytes_returns_packed_data_for_uint4_ml_dtypes(self): - array = np.array([0, 1, 2, 7, 15], dtype=ml_dtypes.uint4) - # Test odd sized array - assert len(array) % 2 == 1 - tensor = _core.Tensor(array, dtype=ir.DataType.UINT4) - self.assertEqual(tensor.tobytes(), b"\x10r\x0f") - - def test_tobytes_returns_packed_data_for_float4e2m1(self): - array = np.array([0, 1, 2, 7, 15], dtype=np.uint8) - # Test odd sized array - assert len(array) % 2 == 1 - tensor = _core.Tensor(array, dtype=ir.DataType.FLOAT4E2M1) - self.assertEqual(tensor.tobytes(), b"\x10r\x0f") - - def test_tobytes_returns_packed_data_for_float4e2m1_ml_dtypes(self): - array = np.array([0, 1, 2, 7, 15], dtype=np.uint8) - # Test odd sized array - assert len(array) % 2 == 1 - tensor = _core.Tensor(array, dtype=ir.DataType.FLOAT4E2M1) - self.assertEqual(tensor.tobytes(), b"\x10r\x0f") - - def test_metadata(self): - array = np.random.rand(1, 2).astype(np.float32) - tensor = _core.Tensor(array) - tensor.meta["test"] = 1 - self.assertEqual(tensor.meta["test"], 1) - tensor.metadata_props["test"] = "any string" - self.assertEqual(tensor.metadata_props["test"], "any string") - - -def _to_external_tensor(tensor_proto, dir: str, filename: str): - onnx.external_data_helper.set_external_data(tensor_proto, location=filename) - path = pathlib.Path(dir) / filename - with open(path, "wb") as f: - f.write(tensor_proto.raw_data) - tensor_proto.ClearField("raw_data") - tensor_proto.data_location = onnx.TensorProto.EXTERNAL - - -class ExternalTensorTest(unittest.TestCase): - """Test the memory mapped external tensor class.""" - - def setUp(self): - self.temp_dir = tempfile.TemporaryDirectory() # pylint: disable=consider-using-with - self.external_data_name = "test_model.bin" - self.base_path = self.temp_dir.name - self.data = np.random.rand(2, 42).astype(np.float32) - self.data_float16 = np.random.rand(2, 42).astype(np.float16) - self.model = self._simple_model_with_external( - self.base_path, self.external_data_name, self.data - ) - - def tearDown(self) -> None: - self.temp_dir.cleanup() - - def _simple_model_with_external( - self, base_path: str, external_data_name: str, data: np.ndarray - ) -> onnx.ModelProto: - input = onnx.helper.make_tensor_value_info("input", onnx.TensorProto.FLOAT, [None]) - output = onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, [None]) - raw_data = data.tobytes() - tensor = onnx.helper.make_tensor( - "input", onnx.TensorProto.FLOAT, data.shape, raw_data, raw=True - ) - raw_data2 = self.data_float16.tobytes() - tensor2 = onnx.helper.make_tensor( - "input2", onnx.TensorProto.FLOAT16, data.shape, raw_data2, raw=True - ) - onnx.external_data_helper.set_external_data( - tensor, external_data_name, offset=0, length=len(raw_data) - ) - onnx.external_data_helper.set_external_data( - tensor2, external_data_name, offset=len(raw_data), length=len(raw_data2) - ) - - node = onnx.helper.make_node("Identity", inputs=["input"], outputs=["output"]) - model = onnx.helper.make_model( - onnx.helper.make_graph( - [node], "test_graph", [input], [output], initializer=[tensor, tensor2] - ) - ) - tensor.ClearField("raw_data") - tensor2.ClearField("raw_data") - # Save the data to disk - with open(pathlib.Path(base_path) / external_data_name, "wb") as f: - f.write(raw_data) - f.write(raw_data2) - return model - - def test_initialize(self): - external_tensor = self.model.graph.initializer[0] - external_info = onnx.external_data_helper.ExternalDataInfo(external_tensor) - tensor = _core.ExternalTensor( - external_info.location, - offset=external_info.offset, - length=external_info.length, - dtype=ir.DataType.FLOAT, - base_dir=self.base_path, - name="input", - shape=_core.Shape(external_tensor.dims), - ) - self.assertEqual(tensor.dtype, ir.DataType.FLOAT) - np.testing.assert_equal(tensor, self.data) - # Ensure repeated reads are consistent - np.testing.assert_equal(tensor, self.data) - - def test_release_does_not_invalidate_tensor(self): - external_tensor = self.model.graph.initializer[0] - external_info = onnx.external_data_helper.ExternalDataInfo(external_tensor) - tensor = _core.ExternalTensor( - external_info.location, - offset=external_info.offset, - length=external_info.length, - dtype=ir.DataType.FLOAT, - base_dir=self.base_path, - name="input", - shape=_core.Shape(external_tensor.dims), - ) - self.assertEqual(tensor.dtype, ir.DataType.FLOAT) - self.assertEqual(tensor.tobytes(), self.data.tobytes()) - # Release tensor - tensor.release() - self.assertEqual(tensor.raw, None) - # Tensor can be re-loaded after release - self.assertEqual(tensor.tobytes(), self.data.tobytes()) - - def test_initialize_with_relative_path(self): - external_tensor = self.model.graph.initializer[0] - external_info = onnx.external_data_helper.ExternalDataInfo(external_tensor) - tensor = _core.ExternalTensor( - external_info.location, - offset=external_info.offset, - length=external_info.length, - dtype=ir.DataType.FLOAT, - name="input", - shape=_core.Shape(external_tensor.dims), - base_dir=pathlib.Path(self.base_path), - ) - self.assertEqual(tensor.dtype, ir.DataType.FLOAT) - np.testing.assert_equal(tensor, self.data) - # Ensure repeated reads are consistent - np.testing.assert_equal(tensor, self.data) - - def test_totypes_returns_correct_data_in(self): - external_tensor = self.model.graph.initializer[0] - external_info = onnx.external_data_helper.ExternalDataInfo(external_tensor) - tensor = _core.ExternalTensor( - external_info.location, - offset=external_info.offset, - length=external_info.length, - dtype=ir.DataType.FLOAT, - base_dir=self.base_path, - name="input", - shape=_core.Shape(external_tensor.dims), - ) - external_tensor2 = self.model.graph.initializer[1] - external_info2 = onnx.external_data_helper.ExternalDataInfo(external_tensor2) - tensor2 = _core.ExternalTensor( - external_info2.location, - offset=external_info2.offset, - length=external_info2.length, - dtype=ir.DataType.FLOAT16, - base_dir=self.base_path, - name="input", - shape=_core.Shape(external_tensor2.dims), - ) - self.assertEqual(tensor.tobytes(), self.data.tobytes()) - self.assertEqual(tensor2.tobytes(), self.data_float16.tobytes()) - # Ensure repeated reads are consistent - self.assertEqual(tensor.tobytes(), self.data.tobytes()) - self.assertEqual(tensor2.tobytes(), self.data_float16.tobytes()) - - @parameterized.parameterized.expand( - [ - ("FLOAT", ir.DataType.FLOAT), - ("BOOL", ir.DataType.BOOL), - ("FLOAT16", ir.DataType.FLOAT16), - ("DOUBLE", ir.DataType.DOUBLE), - ] - ) - def test_external_tensor(self, _: str, dtype: ir.DataType): - expected_array = np.array( - [[-3.0, -1.0, -0.5, -0.0, +0.0, 0.5, 1.0, 42.0, 2.0]] - ).astype(dtype.numpy()) - tensor_proto = ir.serde.serialize_tensor(ir.Tensor(expected_array, dtype=dtype)) - with tempfile.TemporaryDirectory() as temp_dir: - _to_external_tensor(tensor_proto, temp_dir, "tensor.bin") - tensor = ir.serde.deserialize_tensor(tensor_proto, temp_dir) - np.testing.assert_array_equal(tensor.numpy(), expected_array) - # Close the mmap file by deleting the reference to tensor so Windows doesn't complain - # about permission errors - del tensor - - def test_external_tensor_bfloat16(self): - expected_array = np.array( - [[-3.0, -1.0, -0.5, -0.0, +0.0, 0.5, 1.0, 42.0, 2.0]] - ).astype(ml_dtypes.bfloat16) - tensor_proto = ir.serde.serialize_tensor( - ir.Tensor(expected_array.view(np.uint16), dtype=ir.DataType.BFLOAT16) - ) - with tempfile.TemporaryDirectory() as temp_dir: - _to_external_tensor(tensor_proto, temp_dir, "tensor.bin") - tensor = ir.serde.deserialize_tensor(tensor_proto, temp_dir) - np.testing.assert_array_equal( - tensor.numpy().view(ml_dtypes.bfloat16), expected_array - ) - # Close the mmap file by deleting the reference to tensor so Windows doesn't complain - # about permission errors - del tensor - - @parameterized.parameterized.expand( - [ - ( - "FLOAT8E4M3FN", - ir.DataType.FLOAT8E4M3FN, - ml_dtypes.float8_e4m3fn, - ), - ( - "FLOAT8E4M3FNUZ", - ir.DataType.FLOAT8E4M3FNUZ, - ml_dtypes.float8_e4m3fnuz, - ), - ( - "FLOAT8E5M2", - ir.DataType.FLOAT8E5M2, - ml_dtypes.float8_e5m2, - ), - ( - "FLOAT8E5M2FNUZ", - ir.DataType.FLOAT8E5M2FNUZ, - ml_dtypes.float8_e5m2fnuz, - ), - ] - ) - def test_external_tensor_float8(self, _: str, dtype: ir.DataType, np_dtype): - expected_array = np.array( - [[-3.0, -1.0, -0.5, -0.0, +0.0, 0.5, 1.0, 40.0, 2.0]] - ).astype(np_dtype) - tensor_proto = ir.serde.serialize_tensor( - ir.Tensor(expected_array.view(np.uint8), dtype=dtype) - ) - with tempfile.TemporaryDirectory() as temp_dir: - _to_external_tensor(tensor_proto, temp_dir, "tensor.bin") - tensor = ir.serde.deserialize_tensor(tensor_proto, temp_dir) - np.testing.assert_array_equal(tensor.numpy().view(np_dtype), expected_array) - # Close the mmap file by deleting the reference to tensor so Windows doesn't complain - # about permission errors - del tensor - - @parameterized.parameterized.expand( - [ - ("INT8", ir.DataType.INT8), - ("INT16", ir.DataType.INT16), - ("INT32", ir.DataType.INT32), - ("INT64", ir.DataType.INT64), - ("INT4", ir.DataType.INT4), - ] - ) - def test_external_tensor_int(self, _: str, dtype: ir.DataType): - expected_array = np.array([[-8, 0, 1, 7]]).astype(dtype.numpy()) - tensor_proto = ir.serde.serialize_tensor(ir.Tensor(expected_array, dtype=dtype)) - with tempfile.TemporaryDirectory() as temp_dir: - _to_external_tensor(tensor_proto, temp_dir, "tensor.bin") - tensor = ir.serde.deserialize_tensor(tensor_proto, temp_dir) - np.testing.assert_array_equal(tensor.numpy(), expected_array) - # Close the mmap file by deleting the reference to tensor so Windows doesn't complain - # about permission errors - del tensor - - @parameterized.parameterized.expand( - [ - ("UINT8", ir.DataType.UINT8), - ("UINT16", ir.DataType.UINT16), - ("UINT32", ir.DataType.UINT32), - ("UINT64", ir.DataType.UINT64), - ("UINT4", ir.DataType.UINT4), - ] - ) - def test_external_tensor_uint(self, _: str, dtype: ir.DataType): - expected_array = np.array([[0, 1, 15]]).astype(dtype.numpy()) - tensor_proto = ir.serde.serialize_tensor(ir.Tensor(expected_array, dtype=dtype)) - with tempfile.TemporaryDirectory() as temp_dir: - _to_external_tensor(tensor_proto, temp_dir, "tensor.bin") - tensor = ir.serde.deserialize_tensor(tensor_proto, temp_dir) - np.testing.assert_array_equal(tensor.numpy(), expected_array) - # Close the mmap file by deleting the reference to tensor so Windows doesn't complain - # about permission errors - del tensor - - @parameterized.parameterized.expand( - [ - ("COMPLEX64", np.complex64), - ("COMPLEX128", np.complex128), - ] - ) - def test_external_tensor_complex(self, _: str, np_dtype: np.dtype): - expected_array = np.array([[0.0 + 1j, 0.2 - 1j, 0.3]], dtype=np_dtype) - tensor_proto = ir.serde.serialize_tensor(ir.Tensor(expected_array)) - with tempfile.TemporaryDirectory() as temp_dir: - _to_external_tensor(tensor_proto, temp_dir, "tensor.bin") - tensor = ir.serde.deserialize_tensor(tensor_proto, temp_dir) - np.testing.assert_array_equal(tensor.numpy(), expected_array) - # Close the mmap file by deleting the reference to tensor so Windows doesn't complain - # about permission errors - del tensor - - def test_external_tensor_float4e2m1(self): - expected_array = np.array([0, 1, 2, 7, 15]).view(ml_dtypes.float4_e2m1fn) - tensor_proto = ir.serde.serialize_tensor( - ir.Tensor(expected_array, dtype=ir.DataType.FLOAT4E2M1) - ) - with tempfile.TemporaryDirectory() as temp_dir: - _to_external_tensor(tensor_proto, temp_dir, "tensor.bin") - tensor = ir.serde.deserialize_tensor(tensor_proto, temp_dir) - np.testing.assert_array_equal(tensor.numpy(), expected_array) - # Close the mmap file by deleting the reference to tensor so Windows doesn't complain - # about permission errors - del tensor - - def test_external_tensor_empty_tensor(self): - expected_array = np.array([], dtype=np.float32) - tensor_proto = ir.serde.serialize_tensor(ir.Tensor(expected_array)) - with tempfile.TemporaryDirectory() as temp_dir: - _to_external_tensor(tensor_proto, temp_dir, "tensor.bin") - tensor = ir.serde.deserialize_tensor(tensor_proto, temp_dir) - np.testing.assert_array_equal(tensor.numpy(), expected_array) - # Close the mmap file by deleting the reference to tensor so Windows doesn't complain - # about permission errors - del tensor - - -class SymbolicDimTest(unittest.TestCase): - def test_init_raises_when_value_is_int(self): - # Static dimensions should be python integers - with self.assertRaises(TypeError): - _core.SymbolicDim(42) - - @parameterized.parameterized.expand([("str", "any string"), ("None", None)]) - def test_equality_with_other_dimensions(self, _: str, value: Any): - dim1 = _core.SymbolicDim(value) - dim2 = _core.SymbolicDim(value) - self.assertEqual(dim1, dim2) - - @parameterized.parameterized.expand([("str", "any string"), ("None", None)]) - def test_equality_with_python_values(self, _: str, value: Any): - dim = _core.SymbolicDim(value) - self.assertEqual(dim, value) - self.assertIn(value, [dim]) - self.assertIn(dim, [value]) - - @parameterized.parameterized.expand([("str", "any string"), ("None", None)]) - def test_it_is_hashable(self, _: str, value: Any): - dim = _core.SymbolicDim(value) - self.assertEqual(hash(dim), hash(value)) - self.assertIn(dim, {dim}) - self.assertIn(dim, {value}) - - -class ShapeTest(unittest.TestCase): - def test_init_raises_when_denotations_and_dims_have_different_lengths(self): - with self.assertRaisesRegex(ValueError, "denotations"): - _core.Shape([42], ["DATA_CHANNEL", "BATCH"]) - - def test_int_dimensions_are_python_ints(self): - shape = _core.Shape([42]) - self.assertIsInstance(shape[0], int) - - def test_str_dimensions_are_symbolic_dims(self): - shape = _core.Shape(["any string"]) - self.assertIsInstance(shape[0], _core.SymbolicDim) - - def test_none_dimensions_are_symbolic_dims(self): - shape = _core.Shape([None]) - self.assertIsInstance(shape[0], _core.SymbolicDim) - - def test_init_raises_when_dims_is_not_a_list(self): - with self.assertRaises(TypeError): - _core.Shape(42) - - def test_init_converts_np_shape_to_tuple(self): - dims = np.array([42, 42]) - shape = _core.Shape(dims) - self.assertEqual(shape.dims, tuple(dims)) - - def test_init_converts_np_int_to_python_int(self): - dims = [np.int32(42)] - shape = _core.Shape(dims) - self.assertIsInstance(shape[0], int) - self.assertNotIsInstance(shape[0], np.int32) - self.assertIsInstance(shape.dims[0], int) - - @parameterized.parameterized.expand( - [ - ("empty", (), ()), - ("1d", (42,), (42,)), - ("int", (42, 42), (42, 42)), - ("str", ("any string", "any string"), ("any string", "any string")), - ("None", (None, None), (None, None)), - ] - ) - def test_eq_with_other_shapes( - self, _: str, dims_1: tuple[Any, ...], dims_2: tuple[Any, ...] - ): - shape_1 = _core.Shape(dims_1) - shape_2 = _core.Shape(dims_2) - self.assertEqual(shape_1, shape_2) - - @parameterized.parameterized.expand( - [ - ("empty", ()), - ("1d", (42,)), - ("int", (42, 42)), - ("str", ("any string", "any string")), - ("None", (None, None)), - ] - ) - def test_eq_with_tuple(self, _: str, dims: tuple[Any, ...]): - shape = _core.Shape(dims) - self.assertEqual(shape, dims) - - @parameterized.parameterized.expand( - [ - ("empty", []), - ( - "1d", - [ - 42, - ], - ), - ("int", [42, 42]), - ("str", ["any string", "any string"]), - ("None", [None, None]), - ] - ) - def test_eq_with_list(self, _: str, dims: list[Any]): - shape = _core.Shape(dims) - self.assertEqual(shape, dims) - - def test_eq_with_np_shape(self): - dims = (42,) - array = np.zeros(dims) - shape = _core.Shape(dims) - self.assertEqual(shape, array.shape) - - @parameterized.parameterized.expand( - [ - ("empty", (), (1,)), - ("d", (42,), (0,)), - ("rank", (42, 42), (42, 42, 42)), - ("str", ("any string",), (42,)), - ("None", (None, None), (None, 42)), - ] - ) - def test_ne_with_other_shapes( - self, _: str, dims_1: tuple[Any, ...], dims_2: tuple[Any, ...] - ): - shape_1 = _core.Shape(dims_1) - shape_2 = _core.Shape(dims_2) - self.assertNotEqual(shape_1, shape_2) - - def test_ne_with_random_object(self): - shape = _core.Shape((42,)) - self.assertNotEqual(shape, 42) - - def test_setitem_raises_when_shape_is_frozen(self): - shape = _core.Shape([42], denotations=("DATA_CHANNEL",), frozen=True) - with self.assertRaisesRegex(TypeError, "frozen"): - shape[0] = 1 - - with self.assertRaisesRegex(TypeError, "frozen"): - shape[0] = "some_string" - - def test_getitem(self): - shape = _core.Shape([42], denotations=("DATA_CHANNEL",)) - self.assertEqual(shape[0], 42) - - def test_getitem_accepts_a_slice(self): - shape = _core.Shape([1, 2, 3, 4]) - self.assertEqual(shape[1:3], (2, 3)) - - @parameterized.parameterized.expand( - [ - ("int", 42), - ("str", "any string"), - ("None", None), - ("SymbolicDim", _core.SymbolicDim("any string")), - ] - ) - def test_setitem(self, _: str, value): - shape = _core.Shape([0]) - shape[0] = value - dim = shape[0] - if isinstance(dim, _core.SymbolicDim): - self.assertEqual(dim.value, value) - else: - self.assertEqual(dim, value) - - def test_len(self): - shape = _core.Shape([42, "any string"]) - self.assertEqual(len(shape), 2) - - def test_get_denotation(self): - shape = _core.Shape([42], denotations=("DATA_CHANNEL",)) - self.assertEqual(shape.get_denotation(0), "DATA_CHANNEL") - - def test_set_denotation(self): - shape = _core.Shape([42, 0], ["DATA_CHANNEL", "BATCH"]) - shape.set_denotation(1, "UPDATED") - self.assertEqual(shape.get_denotation(1), "UPDATED") - - def test_set_denotation_is_still_possible_when_shape_is_frozen(self): - shape = _core.Shape([42], denotations=("DATA_CHANNEL",), frozen=True) - shape.set_denotation(0, "UPDATED") - self.assertEqual(shape.get_denotation(0), "UPDATED") - - def test_is_static(self): - dim_from_numpy = np.array([42]).shape[0] - np_int = np.int32(42) - shape = _core.Shape([42, "any string", dim_from_numpy, np_int]) - self.assertTrue(shape.is_static(0)) - self.assertFalse(shape.is_static(1)) - self.assertTrue(shape.is_static(2)) - self.assertTrue(shape.is_static(3)) - self.assertFalse(shape.is_static()) - - def test_is_static_raises_when_index_out_of_range(self): - shape = _core.Shape([42]) - with self.assertRaises(IndexError): - shape.is_static(1) - - def test_is_static_on_whole_shape(self): - shape = _core.Shape([42, "any string"]) - self.assertFalse(shape.is_static()) - shape = _core.Shape([42, 42]) - self.assertTrue(shape.is_static()) - - def test_is_static_on_empty_shape(self): - shape = _core.Shape(()) - self.assertTrue(shape.is_static()) - - def test_is_dynamic(self): - dim_from_numpy = np.array([42]).shape[0] - np_int = np.int32(42) - shape = _core.Shape([42, "any string", dim_from_numpy, np_int]) - self.assertFalse(shape.is_dynamic(0)) - self.assertTrue(shape.is_dynamic(1)) - self.assertFalse(shape.is_dynamic(2)) - self.assertFalse(shape.is_dynamic(3)) - self.assertTrue(shape.is_dynamic()) - - def test_is_dynamic_raises_when_index_out_of_range(self): - shape = _core.Shape([42]) - with self.assertRaises(IndexError): - shape.is_dynamic(1) - - def test_is_dynamic_on_whole_shape(self): - shape = _core.Shape([42, "any string"]) - self.assertTrue(shape.is_dynamic()) - shape = _core.Shape([42, 42]) - self.assertFalse(shape.is_dynamic()) - - def test_is_dynamic_on_empty_shape(self): - shape = _core.Shape(()) - self.assertFalse(shape.is_dynamic()) - - -class ValueTest(unittest.TestCase): - def setUp(self) -> None: - self.v0 = _core.Value(name="v0") - self.v1 = _core.Value(name="v1") - self.node = _core.Node( - "test", "TestOp", inputs=(self.v0, self.v1, self.v1), num_outputs=2 - ) - - def test_initialize(self): - _ = _core.Value() - - def test_it_is_hashable(self): - value = _core.Value() - self.assertIsInstance(hash(value), int) - self.assertIn(value, {value}) - - def test_meta(self): - value = _core.Value() - value.meta["test"] = 1 - self.assertEqual(value.meta["test"], 1) - value.metadata_props["test"] = "any string" - self.assertEqual(value.metadata_props["test"], "any string") - - def test_producer(self): - self.assertEqual(self.v0.producer(), None) - self.assertEqual(self.v1.producer(), None) - self.assertEqual(self.node.outputs[0].producer(), self.node) - self.assertEqual(self.node.outputs[1].producer(), self.node) - - def test_consumers(self): - self.assertEqual(self.v0.consumers(), (self.node,)) - self.assertEqual(self.v1.consumers(), (self.node,)) - self.assertEqual(self.node.outputs[0].consumers(), ()) - self.assertEqual(self.node.outputs[1].consumers(), ()) - - # TODO(justinchuby): Test all methods - - -class NodeTest(unittest.TestCase): - def setUp(self) -> None: - self.v0 = _core.Value(name="v0") - self.v1 = _core.Value(name="v1") - self.node = _core.Node( - "test", "TestOp", inputs=(self.v0, self.v1, self.v1), num_outputs=3 - ) - self.node_a = _core.Node("test", "TestOpA", inputs=[self.node.outputs[0]]) - self.node_b = _core.Node("test", "TestOpB", inputs=self.node.outputs) - - def test_it_is_hashable(self): - self.assertIsInstance(hash(self.node), int) - self.assertIn(self.node, {self.node}) - - def test_init_with_values(self): - self.assertEqual(self.node.domain, "test") - self.assertEqual(self.node.op_type, "TestOp") - self.assertEqual(self.node.inputs, (self.v0, self.v1, self.v1)) - self.assertEqual(len(self.node.outputs), 3) - self.assertEqual(self.node.attributes, {}) - - def test_init_with_preinitialized_outputs(self): - out_1 = _core.Value( - name="out_1", - shape=_core.Shape([1]), - type=_core.TensorType(ir.DataType.BFLOAT16), - ) - out_2 = _core.Value( - name="out_2", - shape=_core.Shape([2]), - type=_core.TensorType(ir.DataType.INT4), - ) - node = _core.Node("test", "TestOp", inputs=(self.v0, self.v1), outputs=[out_1, out_2]) - self.assertEqual(node.outputs[0].name, "out_1") - self.assertEqual(node.outputs[0].shape, _core.Shape([1])) - self.assertEqual(node.outputs[0].dtype, ir.DataType.BFLOAT16) - self.assertEqual(node.outputs[1].name, "out_2") - self.assertEqual(node.outputs[1].shape, _core.Shape([2])) - self.assertEqual(node.outputs[1].dtype, ir.DataType.INT4) - self.assertIs(node.outputs[0], out_1) - self.assertIs(node.outputs[1], out_2) - self.assertIs(node.outputs[0].producer(), node) - self.assertIs(node.outputs[1].producer(), node) - self.assertIs(node.outputs[0].index(), 0) - self.assertIs(node.outputs[1].index(), 1) - - def test_init_raises_when_num_outputs_does_not_match_outputs(self): - with self.assertRaisesRegex(ValueError, "outputs"): - _core.Node("test", "TestOp", inputs=(self.v0, self.v1), num_outputs=2, outputs=[]) - - def test_init_with_zero_num_outputs(self): - node = _core.Node("test", "TestOp", inputs=(self.v0, self.v1), num_outputs=0) - self.assertEqual(node.outputs, ()) - - def test_init_with_empty_outputs(self): - node = _core.Node("test", "TestOp", inputs=(self.v0, self.v1), outputs=[]) - self.assertEqual(node.outputs, ()) - - def test_init_produces_one_output_with_unspecified_output_argument(self): - node = _core.Node("test", "TestOp", inputs=(self.v0, self.v1)) - self.assertEqual(len(node.outputs), 1) - - def test_metadata(self): - self.node.meta["test"] = 1 - self.assertEqual(self.node.meta["test"], 1) - self.node.metadata_props["test"] = "any string" - self.assertEqual(self.node.metadata_props["test"], "any string") - - def test_it_is_added_to_a_graph_if_specified(self): - graph = _core.Graph( - (self.v0, self.v1), # type: ignore - self.node.outputs, - nodes=(self.node,), - ) - self.assertIn(self.node, graph) - - def test_predecessors(self): - self.assertEqual(self.node.predecessors(), ()) - self.assertEqual(self.node_a.predecessors(), (self.node,)) - self.assertEqual(self.node_b.predecessors(), (self.node,)) - - def test_predecessors_are_unique(self): - # node_b has three inputs from node, but only one predecessor - self.assertEqual(self.node_b.predecessors(), (self.node,)) - - def test_successors(self): - self.assertEqual(self.node.successors(), (self.node_a, self.node_b)) - self.assertEqual(self.node_a.successors(), ()) - self.assertEqual(self.node_b.successors(), ()) - - def test_successors_are_unique(self): - self.assertEqual(self.node.successors(), (self.node_a, self.node_b)) - - def test_domain_normalizes_ai_onnx(self): - # Node domain is always normalized to "" if it is "ai.onnx" - node = _core.Node("ai.onnx", "TestOp", inputs=()) - self.assertEqual(node.domain, "") - - node.domain = "" - self.assertEqual(node.domain, "") - - node.domain = "ai.onnx" - self.assertEqual(node.domain, "") - - # TODO(justinchuby): Test all methods - - -class GraphTest(unittest.TestCase): - def setUp(self) -> None: - self.v0 = _core.Value(name="v0") - self.v1 = _core.Value(name="v1") - self.node = _core.Node( - "", "Add", inputs=(self.v0, self.v1), num_outputs=1, name="node_add" - ) - self.graph = _core.Graph( - (self.v0, self.v1), - self.node.outputs, - nodes=(self.node,), - opset_imports={"": 1}, - ) - - def test_initialize(self): - self.assertEqual(self.graph.inputs, [self.v0, self.v1]) - self.assertEqual(self.graph.outputs, [*self.node.outputs]) - self.assertEqual(self.graph.opset_imports, {"": 1}) - self.assertEqual(self.graph.initializers, {}) - self.assertIsNone(self.graph.doc_string) - - def test_it_is_hashable(self): - self.assertIsInstance(hash(self.graph), int) - self.assertIn(self.graph, {self.graph}) - - def test_it_is_iterable_of_nodes(self): - self.assertEqual(list(self.graph), [self.node]) - - def test_node_returns_node_by_name(self): - self.assertIs(self.graph.node("node_add"), self.node) - - def test_node_returns_node_by_index(self): - self.assertIs(self.graph.node(0), self.node) - - def test_node_raises_when_node_does_not_exist(self): - with self.assertRaisesRegex(ValueError, "not found"): - self.graph.node("non_existent") - - def test_node_raises_when_index_out_of_range(self): - with self.assertRaises(IndexError): - self.graph.node(1) - - def test_num_nodes_returns_the_count_of_nodes(self): - self.assertEqual(self.graph.num_nodes(), 1) - self.assertEqual(self.graph.num_nodes(), len(self.graph)) - - def test_metadata(self): - self.graph.meta["test"] = 1 - self.assertEqual(self.graph.meta["test"], 1) - self.graph.metadata_props["test"] = "any string" - self.assertEqual(self.graph.metadata_props["test"], "any string") - - def test_remove_removes_node_from_graph(self): - self.graph.remove(self.node) - self.assertEqual(list(self.graph), []) - self.assertIsNone(self.node.graph) - - def test_remove_does_not_change_input_users(self): - self.graph.remove(self.node) - self.assertEqual(tuple(self.v0.uses()), ((self.node, 0),)) - self.assertEqual(tuple(self.v1.uses()), ((self.node, 1),)) - - def test_remove_does_not_change_graph_in_out(self): - self.graph.remove(self.node) - self.assertEqual(self.graph.inputs, [self.v0, self.v1]) - self.assertEqual(self.graph.outputs, list(self.node.outputs)) - - def test_remove_raises_when_node_does_not_belong_to_graph(self): - node = _core.Node("", "Add", inputs=(self.v0, self.v1), num_outputs=1) - with self.assertRaisesRegex(ValueError, "graph"): - self.graph.remove(node) - - def test_remove_safe_raises_when_node_output_is_graph_output(self): - with self.assertRaisesRegex(ValueError, "output"): - self.graph.remove(self.node, safe=True) - - def test_remove_safe_raises_when_node_has_users(self): - v0 = _core.Value(name="v0") - v1 = _core.Value(name="v1") - add_node = _core.Node("", "Add", inputs=(v0, v1), num_outputs=1) - identity_node = _core.Node("", "Identity", inputs=add_node.outputs, num_outputs=1) - graph = _core.Graph( - (v0, v1), - identity_node.outputs, - nodes=(add_node, identity_node), - opset_imports={"": 1}, - ) - with self.assertRaisesRegex(ValueError, "used by other nodes"): - graph.remove(add_node, safe=True) - - def test_remove_safe_removes_uses_of_removed_nodes(self): - v0 = _core.Value(name="v0") - v1 = _core.Value(name="v1") - add_node = _core.Node("", "Add", inputs=(v0, v1), num_outputs=1) - identity_node = _core.Node("", "Identity", inputs=add_node.outputs, num_outputs=1) - graph = _core.Graph( - (v0, v1), - identity_node.outputs, - nodes=(add_node, identity_node), - opset_imports={"": 1}, - ) - # Remove add_node and check that it is no longer a consumer of v0 and v1 - sub_node = _core.Node("", "Sub", inputs=(v0, v1), num_outputs=1) - identity_node.replace_input_with(0, sub_node.outputs[0]) - graph.insert_before(identity_node, sub_node) - graph.remove(add_node, safe=True) - self.assertEqual(tuple(v0.uses()), ((sub_node, 0),)) - self.assertEqual(tuple(v1.uses()), ((sub_node, 1),)) - self.assertEqual(tuple(graph), (sub_node, identity_node)) - self.assertEqual(add_node.inputs, (None, None)) - - def test_register_initializer(self): - self.v1.const_value = ir.tensor([1, 2, 3]) - self.graph.register_initializer(self.v1) - self.assertEqual(self.graph.initializers, {self.v1.name: self.v1}) - - def test_register_initializer_raises_when_value_is_not_constant(self): - with self.assertRaises(ValueError): - self.graph.register_initializer(self.v0) - - def test_register_initializer_raises_when_a_different_value_is_already_registered(self): - self.v1.const_value = ir.tensor([1, 2, 3]) - self.graph.register_initializer(self.v1) - # This is fine - self.graph.register_initializer(self.v1) - self.v0.name = "v1" - with self.assertRaisesRegex(ValueError, "already registered"): - # Registering a different value with the same name should raise - self.graph.register_initializer(self.v0) - - def test_register_initializer_raises_when_value_does_not_have_a_name(self): - self.v1.name = None - with self.assertRaises(ValueError): - self.graph.register_initializer(self.v1) - - # TODO(justinchuby): Test graph mutation methods - - # Test topological sort. - # Graph structure: - # nodes: [node, ...] - # edges: [(predecessor_node, successor_node), ...] - # subgraphs: {node: [subgraph, ...]} - - def test_topological_sort_empty_graph(self): - graph = _core.Graph( - inputs=(), - outputs=(), - nodes=(), - ) - graph.sort() - self.assertEqual(tuple(graph), ()) - - def test_topological_sort_linear_dependencies(self): - # nodes=[1,2,3], edges=[(1,2),(2,3)] - v0 = _core.Value(name="v0") - node1 = _core.Node("", "Node1", inputs=(v0,), num_outputs=1) - node2 = _core.Node("", "Node2", inputs=(node1.outputs[0],), num_outputs=1) - node3 = _core.Node("", "Node3", inputs=(node2.outputs[0],), num_outputs=1) - graph = _core.Graph( - (v0,), - node3.outputs, - nodes=(node3, node2, node1), - ) - graph.sort() - sorted_nodes = tuple(graph) - expected_order = (node1, node2, node3) - self.assertEqual(sorted_nodes, expected_order) - - def test_topological_sort_independent_subgraphs(self): - # nodes=[1,2,3,4], edges=[(1,3),(2,4)] - v0 = _core.Value(name="v0") - v1 = _core.Value(name="v1") - node1 = _core.Node("", "Node1", inputs=(v0,), num_outputs=1) - node2 = _core.Node("", "Node2", inputs=(v1,), num_outputs=1) - node3 = _core.Node("", "Node3", inputs=(node1.outputs[0],), num_outputs=1) - node4 = _core.Node("", "Node4", inputs=(node2.outputs[0],), num_outputs=1) - graph = _core.Graph( - (v0, v1), - (node3.outputs[0], node4.outputs[0]), - nodes=(node4, node3, node2, node1), - ) - graph.sort() - sorted_nodes = tuple(graph) - expected_order = (node2, node4, node1, node3) - self.assertEqual(sorted_nodes, expected_order) - - def test_topological_sort_shared_successor(self): - # nodes=[1,2,3], edges=[(1,3),(2,3)] - v0 = _core.Value(name="v0") - node1 = _core.Node("", "Node1", inputs=(v0,), num_outputs=1) - node2 = _core.Node("", "Node2", inputs=(v0,), num_outputs=1) - node3 = _core.Node( - "", "Node3", inputs=(node1.outputs[0], node2.outputs[0]), num_outputs=1 - ) - graph = _core.Graph( - (v0,), - (node3.outputs[0],), - nodes=(node3, node2, node1), - ) - graph.sort() - sorted_nodes = tuple(graph) - expected_order = (node2, node1, node3) - self.assertEqual(sorted_nodes, expected_order) - - def _create_shared_predecessor_nodes( - self, - ) -> tuple[_core.Value, tuple[_core.Node, _core.Node, _core.Node]]: - # nodes=[0,1,2], edges=[(0,1),(0,2)] - v0 = _core.Value(name="v0") - node0 = _core.Node("", "Node0", inputs=(v0,), num_outputs=1) - node1 = _core.Node("", "Node1", inputs=(node0.outputs[0],), num_outputs=1) - node2 = _core.Node("", "Node2", inputs=(node0.outputs[0],), num_outputs=1) - return v0, (node0, node1, node2) - - @parameterized.parameterized.expand( - [ - ("012", (0, 1, 2), (0, 1, 2)), - ("021", (0, 2, 1), (0, 2, 1)), - ("102", (1, 0, 2), (0, 1, 2)), - ("120", (1, 2, 0), (0, 1, 2)), - ("201", (2, 0, 1), (0, 2, 1)), - ("210", (2, 1, 0), (0, 2, 1)), - ] - ) - def test_topological_sort_shared_predecessor( - self, _: str, initial_order: tuple[int], expected_order: tuple[int] - ): - v0, nodes = self._create_shared_predecessor_nodes() - graph = _core.Graph((v0,), (), nodes=[nodes[i] for i in initial_order]) - graph.sort() - sorted_nodes = list(graph) - self.assertEqual(sorted_nodes, [nodes[i] for i in expected_order]) - - def test_topological_sort_cycle_detection(self): - # nodes=[1,2,3], edges=[(1,2),(2,3),(3,2)] - v0 = _core.Value(name="v0") - node1 = _core.Node("", "Node1", inputs=(v0,), num_outputs=1) - node2 = _core.Node("", "Node2", inputs=(node1.outputs[0], v0), num_outputs=1) - node3 = _core.Node("", "Node3", inputs=(node2.outputs[0],), num_outputs=1) - node2.replace_input_with(1, node3.outputs[0]) - graph = _core.Graph( - (v0,), - (node3.outputs[0],), - nodes=(node1, node2, node3), - ) - with self.assertRaises(ValueError): - graph.sort() - - def test_topological_sort_subgraph(self): - # main_graph: nodes=[a,b,c,d,>,if], edges=[(a,>),(b,>),(>,if)], subgraphs={if:[then_graph,else_graph]} - # then_graph: nodes=[sub], edges=[(c,sub),(d,sub)] - # else_graph: nodes=[add], edges=[(c,add),(d,add)] - v0 = _core.Value(name="va") - v1 = _core.Value(name="vb") - v2 = _core.Value(name="vc") - v3 = _core.Value(name="vd") - node0 = _core.Node("", "a", inputs=(v0,), num_outputs=1) - node1 = _core.Node("", "b", inputs=(v1,), num_outputs=1) - node2 = _core.Node("", "c", inputs=(v2,), num_outputs=1) - node3 = _core.Node("", "d", inputs=(v3,), num_outputs=1) - node4 = _core.Node( - "", "sub", inputs=(node2.outputs[0], node3.outputs[0]), num_outputs=1 - ) - node5 = _core.Node( - "", "add", inputs=(node2.outputs[0], node3.outputs[0]), num_outputs=1 - ) - node6 = _core.Node("", ">", inputs=(node0.outputs[0], node1.outputs[0]), num_outputs=1) - then_graph = _core.Graph( - inputs=(), - outputs=(node4.outputs[0],), - nodes=(node4,), - name="then_graph", - ) - else_graph = _core.Graph( - inputs=(), - outputs=(node5.outputs[0],), - nodes=(node5,), - name="else_graph", - ) - node7 = _core.Node( - "", - "if", - inputs=(node6.outputs[0],), - num_outputs=1, - attributes=[ - ir.AttrGraph("then_branch", then_graph), - ir.AttrGraph("else_branch", else_graph), - ], - ) - main_graph_rev = _core.Graph( - inputs=(v0, v1, v2, v3), - outputs=(node7.outputs[0],), - nodes=(node7, node6, node3, node2, node1, node0), # if, >, d, c, b, a - name="main_graph_rev", - ) - main_graph_rev.sort() - self.assertEqual( - tuple(node.op_type for node in tuple(main_graph_rev)), - ("d", "c", "b", "a", ">", "if"), - ) - - -class GraphContainersTest(unittest.TestCase): - """Test containers for input, output and initializers of a graph.""" - - def setUp(self): - self.graph = _core.Graph(inputs=(), outputs=(), nodes=()) - self.value1 = _core.Value(name="input1") - self.value2 = _core.Value(name="output1") - self.value3 = _core.Value(name="initializer1", const_value=ir.tensor([1, 2, 3])) - - def test_initialize(self): - graph = _core.Graph( - inputs=(self.value1,), - outputs=(self.value2,), - nodes=(), - initializers=(self.value3,), - ) - self.assertEqual(graph.inputs, [self.value1]) - self.assertTrue(self.value1.is_graph_input()) - self.assertIs(self.value1.graph, graph) - self.assertFalse(self.value1.is_graph_output()) - self.assertFalse(self.value1.is_initializer()) - self.assertEqual(graph.outputs, [self.value2]) - self.assertTrue(self.value2.is_graph_output()) - self.assertIs(self.value2.graph, graph) - self.assertFalse(self.value2.is_graph_input()) - self.assertFalse(self.value2.is_initializer()) - self.assertEqual(graph.initializers, {self.value3.name: self.value3}) - self.assertTrue(self.value3.is_initializer()) - self.assertIs(self.value3.graph, graph) - self.assertFalse(self.value3.is_graph_input()) - self.assertFalse(self.value3.is_graph_output()) - - def test_append_to_inputs(self): - self.graph.inputs.append(self.value1) - self.assertIn(self.value1, self.graph.inputs) - self.assertTrue(self.value1.is_graph_input()) - self.assertIs(self.value1.graph, self.graph) - self.assertFalse(self.value1.is_graph_output()) - self.assertFalse(self.value1.is_initializer()) - - def test_append_input_raises_when_input_belongs_to_another_graph(self): - other_graph = _core.Graph(inputs=(), outputs=(), nodes=()) - other_graph.inputs.append(self.value1) - with self.assertRaisesRegex(ValueError, "is already owned by a different graph"): - self.graph.inputs.append(self.value1) - # Append is ok after the value is removed from the old graph - other_graph.inputs.clear() - self.graph.inputs.append(self.value1) - self.assertTrue(self.value1.is_graph_input()) - self.assertIs(self.value1.graph, self.graph) - - def test_extend_inputs(self): - self.graph.inputs.extend([self.value1, self.value2]) - self.assertIn(self.value1, self.graph.inputs) - self.assertIn(self.value2, self.graph.inputs) - self.assertTrue(self.value1.is_graph_input()) - self.assertTrue(self.value2.is_graph_input()) - self.assertIs(self.value1.graph, self.graph) - self.assertIs(self.value2.graph, self.graph) - - def test_pop_from_inputs(self): - self.graph.inputs.append(self.value1) - popped = self.graph.inputs.pop() - self.assertIs(popped, self.value1) - self.assertNotIn(self.value1, self.graph.inputs) - self.assertFalse(self.value1.is_graph_input()) - self.assertIsNone(self.value1.graph) - - def test_pop_from_duplicated_inputs(self): - self.graph.inputs.extend([self.value1, self.value1]) - popped = self.graph.inputs.pop() - self.assertIs(popped, self.value1) - self.assertIn(self.value1, self.graph.inputs) - self.assertTrue(self.value1.is_graph_input()) - self.assertIs(self.value1.graph, self.graph) - - def test_pop_from_inputs_raises_when_empty(self): - with self.assertRaises(IndexError): - self.graph.inputs.pop() - - def test_insert_into_inputs(self): - self.graph.inputs.insert(0, self.value1) - self.assertIs(self.graph.inputs[0], self.value1) - self.assertTrue(self.value1.is_graph_input()) - self.assertIs(self.value1.graph, self.graph) - - def test_remove_from_inputs(self): - self.graph.inputs.append(self.value1) - self.graph.inputs.remove(self.value1) - self.assertNotIn(self.value1, self.graph.inputs) - self.assertFalse(self.value1.is_graph_input()) - self.assertIsNone(self.value1.graph) - - def test_clear_inputs(self): - self.graph.inputs.extend([self.value1, self.value2]) - self.graph.inputs.clear() - self.assertEqual(len(self.graph.inputs), 0) - self.assertFalse(self.value1.is_graph_input()) - self.assertIsNone(self.value1.graph) - self.assertFalse(self.value2.is_graph_input()) - self.assertIsNone(self.value2.graph) - - def test_clear_duplicated_inputs(self): - self.graph.inputs.extend([self.value1, self.value1]) - self.graph.inputs.clear() - self.assertEqual(len(self.graph.inputs), 0) - self.assertFalse(self.value1.is_graph_input()) - self.assertIsNone(self.value1.graph) - - def test_inputs_set_items(self): - self.graph.inputs.append(self.value1) - self.graph.inputs[-1] = self.value2 - self.assertNotIn(self.value1, self.graph.inputs) - self.assertIn(self.value2, self.graph.inputs) - self.assertIs(self.graph.inputs[0], self.value2) - self.assertTrue(self.value2.is_graph_input()) - self.assertIs(self.value2.graph, self.graph) - self.assertFalse(self.value1.is_graph_input()) - self.assertIsNone(self.value1.graph) - - def test_inputs_set_items_slices(self): - self.graph.inputs.extend([self.value1, self.value2]) - # Replace with one existing and one new input - self.graph.inputs[0:2] = [self.value2, self.value3] - self.assertNotIn(self.value1, self.graph.inputs) - self.assertIn(self.value2, self.graph.inputs) - self.assertIn(self.value3, self.graph.inputs) - self.assertIs(self.value2.graph, self.graph) - self.assertIs(self.value3.graph, self.graph) - self.assertTrue(self.value2.is_graph_input()) - self.assertTrue(self.value3.is_graph_input()) - self.assertFalse(self.value1.is_graph_input()) - self.assertIsNone(self.value1.graph) - - def test_take_inputs(self): - self.graph.inputs.extend([self.value1, self.value2, self.value3]) - inputs = self.graph.inputs[:2] - self.graph.inputs.clear() - self.graph.inputs.extend(inputs) - self.assertEqual(len(self.graph.inputs), 2) - self.assertEqual(self.graph.inputs, [self.value1, self.value2]) - self.assertTrue(self.value1.is_graph_input()) - self.assertTrue(self.value2.is_graph_input()) - self.assertFalse(self.value3.is_graph_input()) - self.assertIs(self.value1.graph, self.graph) - self.assertIs(self.value2.graph, self.graph) - self.assertIsNone(self.value3.graph) - - def test_append_to_outputs(self): - self.graph.outputs.append(self.value2) - self.assertIn(self.value2, self.graph.outputs) - self.assertTrue(self.value2.is_graph_output()) - - def test_append_output_raises_when_output_belongs_to_another_graph(self): - other_graph = _core.Graph(inputs=(), outputs=(), nodes=()) - other_graph.outputs.append(self.value2) - with self.assertRaisesRegex(ValueError, "is already an output of a different graph"): - self.graph.outputs.append(self.value2) - # Append is ok after the value is removed from the old graph - other_graph.outputs.clear() - self.graph.outputs.append(self.value2) - self.assertTrue(self.value2.is_graph_output()) - self.assertIs(self.value2.graph, self.graph) - - def test_extend_outputs(self): - self.graph.outputs.extend([self.value1, self.value2]) - self.assertIn(self.value1, self.graph.outputs) - self.assertIn(self.value2, self.graph.outputs) - - def test_pop_from_outputs(self): - self.graph.outputs.append(self.value2) - popped = self.graph.outputs.pop() - self.assertIs(popped, self.value2) - self.assertNotIn(self.value2, self.graph.outputs) - self.assertFalse(self.value2.is_graph_output()) - self.assertIsNone(self.value2.graph) - - def test_pop_from_duplicated_outputs(self): - self.graph.outputs.extend([self.value1, self.value1]) - popped = self.graph.outputs.pop() - self.assertIs(popped, self.value1) - self.assertIn(self.value1, self.graph.outputs) - self.assertTrue(self.value1.is_graph_output()) - self.assertIs(self.value1.graph, self.graph) - - def test_pop_from_outputs_raises_when_empty(self): - with self.assertRaises(IndexError): - self.graph.outputs.pop() - - def test_insert_into_outputs(self): - self.graph.outputs.insert(0, self.value2) - self.assertIs(self.graph.outputs[0], self.value2) - self.assertTrue(self.value2.is_graph_output()) - self.assertIs(self.value2.graph, self.graph) - - def test_remove_from_outputs(self): - self.graph.outputs.append(self.value2) - self.graph.outputs.remove(self.value2) - self.assertNotIn(self.value2, self.graph.outputs) - self.assertFalse(self.value2.is_graph_output()) - self.assertIsNone(self.value2.graph) - - def test_clear_outputs(self): - self.graph.outputs.extend([self.value1, self.value2]) - self.graph.outputs.clear() - self.assertEqual(len(self.graph.outputs), 0) - self.assertFalse(self.value1.is_graph_output()) - self.assertIsNone(self.value1.graph) - self.assertFalse(self.value2.is_graph_output()) - self.assertIsNone(self.value2.graph) - - def test_clear_duplicated_outputs(self): - self.graph.outputs.extend([self.value1, self.value1]) - self.graph.outputs.clear() - self.assertEqual(len(self.graph.outputs), 0) - self.assertFalse(self.value1.is_graph_output()) - self.assertIsNone(self.value1.graph) - - def test_outputs_set_items(self): - self.graph.outputs.append(self.value1) - self.graph.outputs[-1] = self.value2 - self.assertNotIn(self.value1, self.graph.outputs) - self.assertIn(self.value2, self.graph.outputs) - self.assertIs(self.graph.outputs[0], self.value2) - self.assertTrue(self.value2.is_graph_output()) - self.assertIs(self.value2.graph, self.graph) - self.assertFalse(self.value1.is_graph_output()) - self.assertIsNone(self.value1.graph) - - def test_outputs_set_items_slices(self): - self.graph.outputs.extend([self.value1, self.value2]) - # Replace with one existing and one new output - self.graph.outputs[0:2] = [self.value2, self.value3] - self.assertNotIn(self.value1, self.graph.outputs) - self.assertIn(self.value2, self.graph.outputs) - self.assertIn(self.value3, self.graph.outputs) - self.assertIs(self.value2.graph, self.graph) - self.assertIs(self.value3.graph, self.graph) - self.assertTrue(self.value2.is_graph_output()) - self.assertTrue(self.value3.is_graph_output()) - self.assertFalse(self.value1.is_graph_output()) - self.assertIsNone(self.value1.graph) - - def test_take_outputs(self): - self.graph.outputs.extend([self.value1, self.value2, self.value3]) - outputs = self.graph.outputs[:2] - self.graph.outputs.clear() - self.graph.outputs.extend(outputs) - self.assertEqual(len(self.graph.outputs), 2) - self.assertEqual(self.graph.outputs, [self.value1, self.value2]) - self.assertTrue(self.value1.is_graph_output()) - self.assertTrue(self.value2.is_graph_output()) - self.assertFalse(self.value3.is_graph_output()) - self.assertIs(self.value1.graph, self.graph) - self.assertIs(self.value2.graph, self.graph) - self.assertIsNone(self.value3.graph) - - def test_set_initializers(self): - self.graph.initializers["initializer1"] = self.value3 - self.assertIn("initializer1", self.graph.initializers) - self.assertTrue(self.value3.is_initializer()) - self.assertIs(self.value3.graph, self.graph) - # Replace initializer - self.value1.name = "initializer1" - self.graph.initializers["initializer1"] = self.value1 - self.assertIn("initializer1", self.graph.initializers) - self.assertTrue(self.value1.is_initializer()) - self.assertIs(self.value1.graph, self.graph) - self.assertFalse(self.value3.is_initializer()) - self.assertIsNone(self.value3.graph) - - def test_set_initializers_raises_when_key_does_not_match(self): - with self.assertRaisesRegex(ValueError, "does not match the name of the value"): - self.graph.initializers["some_key"] = self.value3 - - def test_set_initializers_raises_when_it_belongs_to_another_graph(self): - other_graph = _core.Graph(inputs=(), outputs=(), nodes=()) - other_graph.initializers["initializer1"] = self.value3 - with self.assertRaisesRegex( - ValueError, "is already an initializer of a different graph" - ): - self.graph.initializers["initializer1"] = self.value3 - # Set is ok after the value is removed from the old graph - other_graph.initializers.clear() - self.graph.initializers["initializer1"] = self.value3 - self.assertIn("initializer1", self.graph.initializers) - self.assertTrue(self.value3.is_initializer()) - self.assertIs(self.value3.graph, self.graph) - - def test_set_initializers_raises_when_value_does_not_have_a_name(self): - self.value3.name = None - with self.assertRaises(TypeError): - self.graph.initializers[None] = self.value3 - - def test_delete_initializer(self): - self.graph.initializers["initializer1"] = self.value3 - del self.graph.initializers["initializer1"] - self.assertNotIn("initializer1", self.graph.initializers) - self.assertFalse(self.value3.is_initializer()) - self.assertIsNone(self.value3.graph) - - def test_delete_initializer_raises_when_key_does_not_exist(self): - with self.assertRaises(KeyError): - del self.graph.initializers["non_existent"] - - def test_clear_initializers(self): - self.graph.initializers["initializer1"] = self.value3 - self.graph.initializers.clear() - self.assertEqual(len(self.graph.initializers), 0) - self.assertFalse(self.value3.is_initializer()) - self.assertIsNone(self.value3.graph) - - def test_pop_initializer(self): - self.graph.initializers["initializer1"] = self.value3 - popped = self.graph.initializers.pop("initializer1") - self.assertEqual(popped, self.value3) - self.assertNotIn("initializer1", self.graph.initializers) - self.assertFalse(self.value3.is_initializer()) - self.assertIsNone(self.value3.graph) - - def test_update_initializers(self): - self.graph.initializers["initializer1"] = self.value3 - new_initializer = _core.Value(name="initializer2") - self.graph.initializers.update({new_initializer.name: new_initializer}) - self.assertIn(new_initializer.name, self.graph.initializers) - self.assertTrue(new_initializer.is_initializer()) - self.assertEqual(new_initializer.graph, self.graph) - self.assertIn("initializer1", self.graph.initializers) - self.assertTrue(self.value3.is_initializer()) - self.assertEqual(self.value3.graph, self.graph) - - def test_iter_initializers(self): - self.graph.initializers["initializer1"] = self.value3 - initializers = list(self.graph.initializers.values()) - self.assertEqual(len(initializers), 1) - self.assertEqual(initializers[0].name, "initializer1") - self.assertTrue(initializers[0].is_initializer()) - self.assertEqual(initializers[0].graph, self.graph) - - def test_contains_initializer(self): - self.graph.initializers["initializer1"] = self.value3 - self.assertIn("initializer1", self.graph.initializers) - self.assertTrue(self.value3.is_initializer()) - self.assertEqual(self.value3.graph, self.graph) - - def test_not_contains_initializer(self): - self.assertNotIn("non_existent", self.graph.initializers) - self.assertFalse(self.value3.is_initializer()) - self.assertIsNone(self.value3.graph) - - def test_initializer_can_be_added_as_input(self): - self.graph.initializers["initializer1"] = self.value3 - self.graph.inputs.append(self.value3) - self.assertIn(self.value3, self.graph.inputs) - self.assertTrue(self.value3.is_graph_input()) - self.assertIs(self.value3.graph, self.graph) - self.assertFalse(self.value3.is_graph_output()) - self.assertTrue(self.value3.is_initializer()) - - def test_initializer_can_be_added_as_output(self): - self.graph.initializers["initializer1"] = self.value3 - self.graph.outputs.append(self.value3) - self.assertIn(self.value3, self.graph.outputs) - self.assertTrue(self.value3.is_graph_output()) - self.assertIs(self.value3.graph, self.graph) - self.assertFalse(self.value3.is_graph_input()) - self.assertTrue(self.value3.is_initializer()) - - -class ModelTest(unittest.TestCase): - def test_graphs_returns_all_subgraphs(self): - # main_graph: nodes=[a,b,c,d,>,if], edges=[(a,>),(b,>),(>,if)], subgraphs={if:[then_graph,else_graph]} - # then_graph: nodes=[sub], edges=[(c,sub),(d,sub)] - # else_graph: nodes=[add], edges=[(c,add),(d,add)] - v0 = _core.Value(name="va") - v1 = _core.Value(name="vb") - v2 = _core.Value(name="vc") - v3 = _core.Value(name="vd") - node0 = _core.Node("", "a", inputs=(v0,), num_outputs=1) - node1 = _core.Node("", "b", inputs=(v1,), num_outputs=1) - node2 = _core.Node("", "c", inputs=(v2,), num_outputs=1) - node3 = _core.Node("", "d", inputs=(v3,), num_outputs=1) - node4 = _core.Node( - "", "sub", inputs=(node2.outputs[0], node3.outputs[0]), num_outputs=1 - ) - node5 = _core.Node( - "", "add", inputs=(node2.outputs[0], node3.outputs[0]), num_outputs=1 - ) - node6 = _core.Node("", ">", inputs=(node0.outputs[0], node1.outputs[0]), num_outputs=1) - then_graph = _core.Graph( - inputs=(), - outputs=(node4.outputs[0],), - nodes=(node4,), - name="then_graph", - ) - else_graph = _core.Graph( - inputs=(), - outputs=(node5.outputs[0],), - nodes=(node5,), - name="else_graph", - ) - node7 = _core.Node( - "", - "if", - inputs=(node6.outputs[0],), - num_outputs=1, - attributes=[ - ir.AttrGraph("then_branch", then_graph), - ir.AttrGraph("else_branch", else_graph), - ], - ) - main_graph = _core.Graph( - inputs=(v0, v1, v2, v3), - outputs=(node7.outputs[0],), - nodes=(node0, node1, node2, node6, node7), - name="main_graph", - ) - model = _core.Model(main_graph, ir_version=10) - self.assertEqual( - tuple(model.graphs()), - (main_graph, then_graph, else_graph), - ) - - -class TypeTest(unittest.TestCase): - @parameterized.parameterized.expand( - [ - ("tensor", _core.TensorType(ir.DataType.FLOAT)), - ("sequence", _core.SequenceType(_core.TensorType(ir.DataType.BOOL))), - ("optional", _core.OptionalType(_core.TensorType(ir.DataType.FLOAT16))), - ( - "sequence_optional", - _core.SequenceType(_core.OptionalType(_core.TensorType(ir.DataType.INT8))), - ), - ( - "optional_sequence", - _core.OptionalType(_core.SequenceType(_core.TensorType(ir.DataType.INT16))), - ), - ] - ) - def test_type_is_hashable(self, _: str, type_: ir.TypeProtocol): - self.assertIsInstance(hash(type_), int) - self.assertIn(type_, {type_}) # type: ignore - # Assert that a different type object can still be matched - self.assertIn(copy.deepcopy(type_), {type_}) # type: ignore - - def test_type_is_comparable(self): - self.assertEqual( - _core.TensorType(ir.DataType.FLOAT), _core.TensorType(ir.DataType.FLOAT) - ) - self.assertNotEqual( - _core.TensorType(ir.DataType.FLOAT), _core.TensorType(ir.DataType.FLOAT16) - ) - - @parameterized.parameterized.expand( - [ - ("tensor", _core.TensorType(ir.DataType.FLOAT)), - ("sequence", _core.SequenceType(_core.TensorType(ir.DataType.BOOL))), - ("optional", _core.OptionalType(_core.TensorType(ir.DataType.FLOAT16))), - ( - "sequence_optional", - _core.SequenceType(_core.OptionalType(_core.TensorType(ir.DataType.INT8))), - ), - ( - "optional_sequence", - _core.OptionalType(_core.SequenceType(_core.TensorType(ir.DataType.INT16))), - ), - ] - ) - def test_composite_type_is_comparable(self, _: str, type_: ir.TypeProtocol): - self.assertEqual(type_, type_) - # Equal even if deep-copied - self.assertEqual(type_, copy.deepcopy(type_)) - - -class AttrTest(unittest.TestCase): - """Test the Attr class.""" - - def test_init(self): - attr = _core.Attr("test", ir.AttributeType.INT, 42, doc_string="test string") - self.assertEqual(attr.name, "test") - self.assertEqual(attr.value, 42) - self.assertEqual(attr.type, ir.AttributeType.INT) - self.assertEqual(attr.doc_string, "test string") - - def test_as_float(self): - attr = _core.Attr("test", ir.AttributeType.FLOAT, 42.0) - self.assertEqual(attr.as_float(), 42.0) - - attr_int_value = _core.Attr("test", ir.AttributeType.FLOAT, 42) - self.assertEqual(attr_int_value.as_float(), 42.0) - - def test_as_int(self): - attr = _core.Attr("test", ir.AttributeType.INT, 0) - self.assertEqual(attr.as_int(), 0) - - def test_as_string(self): - attr = _core.Attr("test", ir.AttributeType.STRING, "test string") - self.assertEqual(attr.as_string(), "test string") - - def test_as_tensor(self): - attr = _core.Attr("test", ir.AttributeType.TENSOR, ir.tensor([42.0])) - np.testing.assert_equal(attr.as_tensor().numpy(), np.array([42.0])) - - def test_as_graph(self): - attr = _core.Attr("test", ir.AttributeType.GRAPH, _core.Graph((), (), nodes=())) - self.assertIsInstance(attr.as_graph(), _core.Graph) - - def test_as_floats(self): - attr = _core.Attr("test", ir.AttributeType.FLOATS, [42.0]) - self.assertEqual(attr.as_floats(), [42.0]) - - def test_as_ints(self): - attr = _core.Attr("test", ir.AttributeType.INTS, [42]) - self.assertEqual(attr.as_ints(), [42]) - - def test_as_strings(self): - attr = _core.Attr("test", ir.AttributeType.STRINGS, ["test string", ""]) - self.assertEqual(attr.as_strings(), ["test string", ""]) - - def test_as_tensors(self): - attr = _core.Attr("test", ir.AttributeType.TENSORS, [ir.tensor([42.0])]) - np.testing.assert_equal(attr.as_tensors()[0].numpy(), np.array([42.0])) - - def test_as_graphs(self): - attr = _core.Attr("test", ir.AttributeType.GRAPHS, [_core.Graph((), (), nodes=())]) - self.assertIsInstance(attr.as_graphs()[0], _core.Graph) - - -class LazyTensorTest(unittest.TestCase): - def test_lazy_tensor_initialization(self): - def tensor_fn(): - return ir.tensor([1, 2, 3], dtype=ir.DataType.INT64) - - lazy_tensor = _core.LazyTensor( - tensor_fn, dtype=ir.DataType.INT64, shape=ir.Shape((3,)) - ) - self.assertEqual(lazy_tensor.dtype, ir.DataType.INT64) - self.assertEqual(lazy_tensor.shape, (3,)) - - def test_lazy_tensor_numpy(self): - def tensor_fn(): - return ir.tensor([1, 2, 3], dtype=ir.DataType.INT64) - - lazy_tensor = _core.LazyTensor( - tensor_fn, dtype=ir.DataType.INT64, shape=ir.Shape((3,)) - ) - np.testing.assert_array_equal(lazy_tensor.numpy(), np.array([1, 2, 3])) - - def test_lazy_tensor_tobytes(self): - def tensor_fn(): - return ir.tensor([1, 2, 3], dtype=ir.DataType.INT64) - - lazy_tensor = _core.LazyTensor( - tensor_fn, dtype=ir.DataType.INT64, shape=ir.Shape((3,)) - ) - self.assertEqual( - lazy_tensor.tobytes(), - b"\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00\x00\x00\x00\x00", - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/onnxscript/ir/_display.py b/onnxscript/ir/_display.py deleted file mode 100644 index 2fc62114c2..0000000000 --- a/onnxscript/ir/_display.py +++ /dev/null @@ -1,49 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -"""Internal utilities for displaying the intermediate representation of a model. - -NOTE: All third-party imports should be scoped and imported only when used to avoid -importing unnecessary dependencies. -""" -# pylint: disable=import-outside-toplevel - -from __future__ import annotations - -from typing import Any - - -def require_rich() -> Any: - """Raise an ImportError if rich is not installed.""" - try: - import rich - except ImportError: - return None - return rich - - -class PrettyPrintable: - def display(self, *, page: bool = False) -> None: - """Pretty print the object. - - Args: - page: Whether to page the output. - """ - rich = require_rich() - text = str(self) - - if rich is None: - print(text) - # Color print this message - print( - f"\n\n\u001b[36mTip: Install the rich library with 'pip install rich' to pretty print this {self.__class__.__name__}.\u001b[0m" - ) - return - - if page: - import rich.console - - console = rich.console.Console() - with console.pager(): - console.print(text) - else: - rich.print(text) diff --git a/onnxscript/ir/_display_test.py b/onnxscript/ir/_display_test.py deleted file mode 100644 index ee745b4844..0000000000 --- a/onnxscript/ir/_display_test.py +++ /dev/null @@ -1,22 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -"""Test display() methods in various classes.""" - -import contextlib -import unittest - -import numpy as np - -import onnxscript.ir as ir - - -class DisplayTest(unittest.TestCase): - def test_tensor_display_does_not_raise_on_nan_values(self): - array_with_nan = np.array([np.inf, -np.inf, np.nan, 5, -10], dtype=np.float32) - tensor = ir.Tensor(array_with_nan, dtype=ir.DataType.FLOAT) - with contextlib.redirect_stdout(None): - tensor.display() - - -if __name__ == "__main__": - unittest.main() diff --git a/onnxscript/ir/_enums.py b/onnxscript/ir/_enums.py deleted file mode 100644 index 9ecce9fed3..0000000000 --- a/onnxscript/ir/_enums.py +++ /dev/null @@ -1,242 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -"""ONNX IR enums that matches the ONNX spec.""" - -from __future__ import annotations - -import enum - -import ml_dtypes -import numpy as np - - -class AttributeType(enum.IntEnum): - """Enum for the types of ONNX attributes.""" - - UNDEFINED = 0 - FLOAT = 1 - INT = 2 - STRING = 3 - TENSOR = 4 - GRAPH = 5 - FLOATS = 6 - INTS = 7 - STRINGS = 8 - TENSORS = 9 - GRAPHS = 10 - SPARSE_TENSOR = 11 - SPARSE_TENSORS = 12 - TYPE_PROTO = 13 - TYPE_PROTOS = 14 - - def __repr__(self) -> str: - return self.name - - def __str__(self) -> str: - return self.__repr__() - - -class DataType(enum.IntEnum): - """Enum for the data types of ONNX tensors, defined in ``onnx.TensorProto``.""" - - # NOTE: Naming: It is tempting to use shorter and more modern names like f32, i64, - # but we should stick to the names used in the ONNX spec for consistency. - UNDEFINED = 0 - FLOAT = 1 - UINT8 = 2 - INT8 = 3 - UINT16 = 4 - INT16 = 5 - INT32 = 6 - INT64 = 7 - STRING = 8 - BOOL = 9 - FLOAT16 = 10 - DOUBLE = 11 - UINT32 = 12 - UINT64 = 13 - COMPLEX64 = 14 - COMPLEX128 = 15 - BFLOAT16 = 16 - FLOAT8E4M3FN = 17 - FLOAT8E4M3FNUZ = 18 - FLOAT8E5M2 = 19 - FLOAT8E5M2FNUZ = 20 - UINT4 = 21 - INT4 = 22 - FLOAT4E2M1 = 23 - - @classmethod - def from_numpy(cls, dtype: np.dtype) -> DataType: - """Returns the ONNX data type for the numpy dtype. - - Raises: - TypeError: If the data type is not supported by ONNX. - """ - if dtype in _NP_TYPE_TO_DATA_TYPE: - return cls(_NP_TYPE_TO_DATA_TYPE[dtype]) - - if np.issubdtype(dtype, np.str_): - return DataType.STRING - - # Special cases for handling custom dtypes defined in ONNX (as of onnx 1.18) - # Ref: https://github.com/onnx/onnx/blob/2d42b6a60a52e925e57c422593e88cc51890f58a/onnx/_custom_element_types.py - if hasattr(dtype, "names"): - if dtype.names == ("bfloat16",): - return DataType.BFLOAT16 - if dtype.names == ("e4m3fn",): - return DataType.FLOAT8E4M3FN - if dtype.names == ("e4m3fnuz",): - return DataType.FLOAT8E4M3FNUZ - if dtype.names == ("e5m2",): - return DataType.FLOAT8E5M2 - if dtype.names == ("e5m2fnuz",): - return DataType.FLOAT8E5M2FNUZ - if dtype.names == ("uint4",): - return DataType.UINT4 - if dtype.names == ("int4",): - return DataType.INT4 - if dtype.names == ("float4e2m1",): - return DataType.FLOAT4E2M1 - raise TypeError(f"Unsupported numpy data type: {dtype}") - - @classmethod - def from_short_name(cls, short_name: str) -> DataType: - """Returns the ONNX data type for the short name. - - Raises: - TypeError: If the short name is not available for the data type. - """ - if short_name not in _SHORT_NAME_TO_DATA_TYPE: - raise TypeError(f"Unknown short name: {short_name}") - return cls(_SHORT_NAME_TO_DATA_TYPE[short_name]) - - @property - def itemsize(self) -> float: - """Returns the size of the data type in bytes.""" - return _ITEMSIZE_MAP[self] - - def numpy(self) -> np.dtype: - """Returns the numpy dtype for the ONNX data type. - - Raises: - TypeError: If the data type is not supported by numpy. - """ - if self not in _DATA_TYPE_TO_NP_TYPE: - raise TypeError(f"Numpy does not support ONNX data type: {self}") - return _DATA_TYPE_TO_NP_TYPE[self] - - def short_name(self) -> str: - """Returns the short name of the data type. - - The short name is a string that is used to represent the data type in a more - compact form. For example, the short name for `DataType.FLOAT` is "f32". - To get the corresponding data type back, call ``from_short_name`` on a string. - - Naming reference: https://github.com/pytorch/pytorch/blob/4bead7b85ea4160243c74109e0ce9bb80686d016/torch/utils/_dtype_abbrs.py - - Raises: - TypeError: If the short name is not available for the data type. - """ - if self not in _DATA_TYPE_TO_SHORT_NAME: - raise TypeError(f"Short name not available for ONNX data type: {self}") - return _DATA_TYPE_TO_SHORT_NAME[self] - - def __repr__(self) -> str: - return self.name - - def __str__(self) -> str: - return self.__repr__() - - -_ITEMSIZE_MAP = { - DataType.FLOAT: 4, - DataType.UINT8: 1, - DataType.INT8: 1, - DataType.UINT16: 2, - DataType.INT16: 2, - DataType.INT32: 4, - DataType.INT64: 8, - DataType.STRING: 1, - DataType.BOOL: 1, - DataType.FLOAT16: 2, - DataType.DOUBLE: 8, - DataType.UINT32: 4, - DataType.UINT64: 8, - DataType.COMPLEX64: 8, - DataType.COMPLEX128: 16, - DataType.BFLOAT16: 2, - DataType.FLOAT8E4M3FN: 1, - DataType.FLOAT8E4M3FNUZ: 1, - DataType.FLOAT8E5M2: 1, - DataType.FLOAT8E5M2FNUZ: 1, - DataType.UINT4: 0.5, - DataType.INT4: 0.5, - DataType.FLOAT4E2M1: 0.5, -} - - -# We use ml_dtypes to support dtypes that are not in numpy. -_NP_TYPE_TO_DATA_TYPE = { - np.dtype("bool"): DataType.BOOL, - np.dtype("complex128"): DataType.COMPLEX128, - np.dtype("complex64"): DataType.COMPLEX64, - np.dtype("float16"): DataType.FLOAT16, - np.dtype("float32"): DataType.FLOAT, - np.dtype("float64"): DataType.DOUBLE, - np.dtype("int16"): DataType.INT16, - np.dtype("int32"): DataType.INT32, - np.dtype("int64"): DataType.INT64, - np.dtype("int8"): DataType.INT8, - np.dtype("object"): DataType.STRING, - np.dtype("uint16"): DataType.UINT16, - np.dtype("uint32"): DataType.UINT32, - np.dtype("uint64"): DataType.UINT64, - np.dtype("uint8"): DataType.UINT8, - np.dtype(ml_dtypes.bfloat16): DataType.BFLOAT16, - np.dtype(ml_dtypes.float8_e4m3fn): DataType.FLOAT8E4M3FN, - np.dtype(ml_dtypes.float8_e4m3fnuz): DataType.FLOAT8E4M3FNUZ, - np.dtype(ml_dtypes.float8_e5m2): DataType.FLOAT8E5M2, - np.dtype(ml_dtypes.float8_e5m2fnuz): DataType.FLOAT8E5M2FNUZ, - np.dtype(ml_dtypes.int4): DataType.INT4, - np.dtype(ml_dtypes.uint4): DataType.UINT4, -} - -# TODO(after min req for ml_dtypes>=0.5): Move this inside _NP_TYPE_TO_DATA_TYPE -_NP_TYPE_TO_DATA_TYPE.update( - {np.dtype(ml_dtypes.float4_e2m1fn): DataType.FLOAT4E2M1} - if hasattr(ml_dtypes, "float4_e2m1fn") - else {} -) - -# ONNX DataType to Numpy dtype. -_DATA_TYPE_TO_NP_TYPE = {v: k for k, v in _NP_TYPE_TO_DATA_TYPE.items()} - -_DATA_TYPE_TO_SHORT_NAME = { - DataType.UNDEFINED: "undefined", - DataType.BFLOAT16: "bf16", - DataType.DOUBLE: "f64", - DataType.FLOAT: "f32", - DataType.FLOAT16: "f16", - DataType.FLOAT8E4M3FN: "f8e4m3fn", - DataType.FLOAT8E5M2: "f8e5m2", - DataType.FLOAT8E4M3FNUZ: "f8e4m3fnuz", - DataType.FLOAT8E5M2FNUZ: "f8e5m2fnuz", - DataType.FLOAT4E2M1: "f4e2m1", - DataType.COMPLEX64: "c64", - DataType.COMPLEX128: "c128", - DataType.INT4: "i4", - DataType.INT8: "i8", - DataType.INT16: "i16", - DataType.INT32: "i32", - DataType.INT64: "i64", - DataType.BOOL: "b8", - DataType.UINT4: "u4", - DataType.UINT8: "u8", - DataType.UINT16: "u16", - DataType.UINT32: "u32", - DataType.UINT64: "u64", - DataType.STRING: "s", -} - -_SHORT_NAME_TO_DATA_TYPE = {v: k for k, v in _DATA_TYPE_TO_SHORT_NAME.items()} diff --git a/onnxscript/ir/_enums_test.py b/onnxscript/ir/_enums_test.py deleted file mode 100644 index 906bf7b572..0000000000 --- a/onnxscript/ir/_enums_test.py +++ /dev/null @@ -1,179 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -# pylint: disable=protected-access -import unittest - -import ml_dtypes -import numpy as np -import onnx -import onnx._custom_element_types -import parameterized - -from onnxscript.ir import _enums - - -class DataTypeTest(unittest.TestCase): - def test_enums_are_the_same_as_spec(self): - self.assertEqual(_enums.DataType.FLOAT, onnx.TensorProto.FLOAT) - self.assertEqual(_enums.DataType.UINT8, onnx.TensorProto.UINT8) - self.assertEqual(_enums.DataType.INT8, onnx.TensorProto.INT8) - self.assertEqual(_enums.DataType.UINT16, onnx.TensorProto.UINT16) - self.assertEqual(_enums.DataType.INT16, onnx.TensorProto.INT16) - self.assertEqual(_enums.DataType.INT32, onnx.TensorProto.INT32) - self.assertEqual(_enums.DataType.INT64, onnx.TensorProto.INT64) - self.assertEqual(_enums.DataType.STRING, onnx.TensorProto.STRING) - self.assertEqual(_enums.DataType.BOOL, onnx.TensorProto.BOOL) - self.assertEqual(_enums.DataType.FLOAT16, onnx.TensorProto.FLOAT16) - self.assertEqual(_enums.DataType.DOUBLE, onnx.TensorProto.DOUBLE) - self.assertEqual(_enums.DataType.UINT32, onnx.TensorProto.UINT32) - self.assertEqual(_enums.DataType.UINT64, onnx.TensorProto.UINT64) - self.assertEqual(_enums.DataType.COMPLEX64, onnx.TensorProto.COMPLEX64) - self.assertEqual(_enums.DataType.COMPLEX128, onnx.TensorProto.COMPLEX128) - self.assertEqual(_enums.DataType.BFLOAT16, onnx.TensorProto.BFLOAT16) - self.assertEqual(_enums.DataType.FLOAT8E4M3FN, onnx.TensorProto.FLOAT8E4M3FN) - self.assertEqual(_enums.DataType.FLOAT8E4M3FNUZ, onnx.TensorProto.FLOAT8E4M3FNUZ) - self.assertEqual(_enums.DataType.FLOAT8E5M2, onnx.TensorProto.FLOAT8E5M2) - self.assertEqual(_enums.DataType.FLOAT8E5M2FNUZ, onnx.TensorProto.FLOAT8E5M2FNUZ) - self.assertEqual(_enums.DataType.UINT4, onnx.TensorProto.UINT4) - self.assertEqual(_enums.DataType.INT4, onnx.TensorProto.INT4) - if hasattr(onnx.TensorProto, "FLOAT4E2M1"): - self.assertEqual(_enums.DataType.FLOAT4E2M1, onnx.TensorProto.FLOAT4E2M1) - self.assertEqual(_enums.DataType.UNDEFINED, onnx.TensorProto.UNDEFINED) - - @parameterized.parameterized.expand( - [ - ("string", np.array("some_string").dtype, _enums.DataType.STRING), - ("float64", np.dtype(np.float64), _enums.DataType.DOUBLE), - ("float32", np.dtype(np.float32), _enums.DataType.FLOAT), - ("float16", np.dtype(np.float16), _enums.DataType.FLOAT16), - ("int32", np.dtype(np.int32), _enums.DataType.INT32), - ("int16", np.dtype(np.int16), _enums.DataType.INT16), - ("int8", np.dtype(np.int8), _enums.DataType.INT8), - ("int64", np.dtype(np.int64), _enums.DataType.INT64), - ("uint8", np.dtype(np.uint8), _enums.DataType.UINT8), - ("uint16", np.dtype(np.uint16), _enums.DataType.UINT16), - ("uint32", np.dtype(np.uint32), _enums.DataType.UINT32), - ("uint64", np.dtype(np.uint64), _enums.DataType.UINT64), - ("bool", np.dtype(np.bool_), _enums.DataType.BOOL), - ("complex64", np.dtype(np.complex64), _enums.DataType.COMPLEX64), - ("complex128", np.dtype(np.complex128), _enums.DataType.COMPLEX128), - ("bfloat16", np.dtype(ml_dtypes.bfloat16), _enums.DataType.BFLOAT16), - ("float8e4m3fn", np.dtype(ml_dtypes.float8_e4m3fn), _enums.DataType.FLOAT8E4M3FN), - ( - "float8e4m3fnuz", - np.dtype(ml_dtypes.float8_e4m3fnuz), - _enums.DataType.FLOAT8E4M3FNUZ, - ), - ("float8e5m2", np.dtype(ml_dtypes.float8_e5m2), _enums.DataType.FLOAT8E5M2), - ( - "float8e5m2fnuz", - np.dtype(ml_dtypes.float8_e5m2fnuz), - _enums.DataType.FLOAT8E5M2FNUZ, - ), - ("uint4", np.dtype(ml_dtypes.uint4), _enums.DataType.UINT4), - ("int4", np.dtype(ml_dtypes.int4), _enums.DataType.INT4), - ("float4e2m1", np.dtype(ml_dtypes.float4_e2m1fn), _enums.DataType.FLOAT4E2M1), - ( - "onnx_ref_bfloat16", - onnx._custom_element_types.bfloat16, - _enums.DataType.BFLOAT16, - ), - ( - "onnx_ref_float8e4m3fn", - onnx._custom_element_types.float8e4m3fn, - _enums.DataType.FLOAT8E4M3FN, - ), - ( - "onnx_ref_float8e4m3fnuz", - onnx._custom_element_types.float8e4m3fnuz, - _enums.DataType.FLOAT8E4M3FNUZ, - ), - ( - "onnx_ref_float8e5m2", - onnx._custom_element_types.float8e5m2, - _enums.DataType.FLOAT8E5M2, - ), - ( - "onnx_ref_float8e5m2fnuz", - onnx._custom_element_types.float8e5m2fnuz, - _enums.DataType.FLOAT8E5M2FNUZ, - ), - ( - "onnx_ref_uint4", - onnx._custom_element_types.uint4, - _enums.DataType.UINT4, - ), - ("onnx_ref_int4", onnx._custom_element_types.int4, _enums.DataType.INT4), - ] - ) - def test_from_numpy_takes_np_dtype_and_returns_data_type( - self, _: str, np_dtype: np.dtype, onnx_type: _enums.DataType - ): - self.assertEqual(_enums.DataType.from_numpy(np_dtype), onnx_type) - - def test_numpy_returns_np_dtype(self): - self.assertEqual(_enums.DataType.DOUBLE.numpy(), np.dtype(np.float64)) - - def test_itemsize_returns_size_of_data_type_in_bytes(self): - self.assertEqual(_enums.DataType.DOUBLE.itemsize, 8) - self.assertEqual(_enums.DataType.INT4.itemsize, 0.5) - - def test_repr_and_str_return_name(self): - self.assertEqual(str(_enums.DataType.DOUBLE), "DOUBLE") - self.assertEqual(repr(_enums.DataType.DOUBLE), "DOUBLE") - - def test_short_name_conversion(self): - for dtype in _enums.DataType: - short_name = dtype.short_name() - self.assertEqual(_enums.DataType.from_short_name(short_name), dtype) - - def test_access_by_name(self): - self.assertEqual(_enums.DataType["FLOAT"], _enums.DataType.FLOAT) - self.assertEqual(_enums.DataType["UINT8"], _enums.DataType.UINT8) - self.assertEqual(_enums.DataType["INT8"], _enums.DataType.INT8) - self.assertEqual(_enums.DataType["UINT16"], _enums.DataType.UINT16) - self.assertEqual(_enums.DataType["INT16"], _enums.DataType.INT16) - self.assertEqual(_enums.DataType["INT32"], _enums.DataType.INT32) - self.assertEqual(_enums.DataType["INT64"], _enums.DataType.INT64) - self.assertEqual(_enums.DataType["STRING"], _enums.DataType.STRING) - self.assertEqual(_enums.DataType["BOOL"], _enums.DataType.BOOL) - self.assertEqual(_enums.DataType["FLOAT16"], _enums.DataType.FLOAT16) - self.assertEqual(_enums.DataType["DOUBLE"], _enums.DataType.DOUBLE) - self.assertEqual(_enums.DataType["UINT32"], _enums.DataType.UINT32) - self.assertEqual(_enums.DataType["UINT64"], _enums.DataType.UINT64) - self.assertEqual(_enums.DataType["COMPLEX64"], _enums.DataType.COMPLEX64) - self.assertEqual(_enums.DataType["COMPLEX128"], _enums.DataType.COMPLEX128) - self.assertEqual(_enums.DataType["BFLOAT16"], _enums.DataType.BFLOAT16) - self.assertEqual(_enums.DataType["FLOAT8E4M3FN"], _enums.DataType.FLOAT8E4M3FN) - self.assertEqual(_enums.DataType["FLOAT8E4M3FNUZ"], _enums.DataType.FLOAT8E4M3FNUZ) - self.assertEqual(_enums.DataType["FLOAT8E5M2"], _enums.DataType.FLOAT8E5M2) - self.assertEqual(_enums.DataType["FLOAT8E5M2FNUZ"], _enums.DataType.FLOAT8E5M2FNUZ) - self.assertEqual(_enums.DataType["UINT4"], _enums.DataType.UINT4) - self.assertEqual(_enums.DataType["INT4"], _enums.DataType.INT4) - self.assertEqual(_enums.DataType["FLOAT4E2M1"], _enums.DataType.FLOAT4E2M1) - self.assertEqual(_enums.DataType["UNDEFINED"], _enums.DataType.UNDEFINED) - - -class AttributeTypeTest(unittest.TestCase): - def test_enums_are_the_same_as_spec(self): - self.assertEqual(_enums.AttributeType.FLOAT, onnx.AttributeProto.FLOAT) - self.assertEqual(_enums.AttributeType.INT, onnx.AttributeProto.INT) - self.assertEqual(_enums.AttributeType.STRING, onnx.AttributeProto.STRING) - self.assertEqual(_enums.AttributeType.TENSOR, onnx.AttributeProto.TENSOR) - self.assertEqual(_enums.AttributeType.GRAPH, onnx.AttributeProto.GRAPH) - self.assertEqual(_enums.AttributeType.FLOATS, onnx.AttributeProto.FLOATS) - self.assertEqual(_enums.AttributeType.INTS, onnx.AttributeProto.INTS) - self.assertEqual(_enums.AttributeType.STRINGS, onnx.AttributeProto.STRINGS) - self.assertEqual(_enums.AttributeType.TENSORS, onnx.AttributeProto.TENSORS) - self.assertEqual(_enums.AttributeType.GRAPHS, onnx.AttributeProto.GRAPHS) - self.assertEqual(_enums.AttributeType.SPARSE_TENSOR, onnx.AttributeProto.SPARSE_TENSOR) - self.assertEqual( - _enums.AttributeType.SPARSE_TENSORS, onnx.AttributeProto.SPARSE_TENSORS - ) - self.assertEqual(_enums.AttributeType.TYPE_PROTO, onnx.AttributeProto.TYPE_PROTO) - self.assertEqual(_enums.AttributeType.TYPE_PROTOS, onnx.AttributeProto.TYPE_PROTOS) - self.assertEqual(_enums.AttributeType.UNDEFINED, onnx.AttributeProto.UNDEFINED) - - -if __name__ == "__main__": - unittest.main() diff --git a/onnxscript/ir/_graph_comparison.py b/onnxscript/ir/_graph_comparison.py deleted file mode 100644 index e13b8ba473..0000000000 --- a/onnxscript/ir/_graph_comparison.py +++ /dev/null @@ -1,23 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -"""Utilities for comparing IR graphs.""" - -from __future__ import annotations - -from onnxscript.ir import _core - -# NOTE(justinchuby): We need to ensure a graph has valid inputs and outputs -# NOTE(justinchuby): A graph may be specified with a set of inputs and outputs - - -def topologically_equal(graph1: _core.Graph, graph2: _core.Graph) -> bool: - """Return true if the two graphs are topologically equivalent, without considering initializers. - - Args: - graph1: The first graph to compare. - graph2: The second graph to compare. - - Returns: - True if the graphs are equal, False otherwise. - """ - raise NotImplementedError() diff --git a/onnxscript/ir/_graph_containers.py b/onnxscript/ir/_graph_containers.py deleted file mode 100644 index 620e73e86b..0000000000 --- a/onnxscript/ir/_graph_containers.py +++ /dev/null @@ -1,263 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -"""Tracked containers for graph.""" - -# pylint: disable=protected-access - -from __future__ import annotations - -__all__ = [ - "GraphInputs", - "GraphOutputs", -] - -import collections -from typing import TYPE_CHECKING, Iterable, SupportsIndex - -import onnxscript - -if TYPE_CHECKING: - from onnxscript.ir import _core - - -class _GraphIO(collections.UserList["_core.Value"]): - """The inputs and outputs of a Graph.""" - - def __init__(self, graph: _core.Graph, initlist=None): - self._graph = graph - # Use a ref counter to track the number of references to each value - # in the input/output list. This is used to determine when to unset the graph - # reference in the value. - # Even though a duplicated value is invalid in inputs and not recommended in outputs, - # it is still possible to have duplicated inputs/outputs in an ONNX graph so we - # need to properly handle this case and maintain the graph reference properly. - self._ref_counter: collections.Counter[_core.Value] = collections.Counter() - if initlist is not None: - initlist = tuple(initlist) # Create a copy in case initlist is a generator - for value in initlist: - self._set_graph(value) - super().__init__(initlist) - self._check_invariance() - - def _check_invariance(self) -> None: - """Check the invariance of the graph.""" - raise NotImplementedError - - def _set_graph(self, value: _core.Value) -> None: - """Set the graph for the value.""" - raise NotImplementedError - - def _maybe_unset_graph(self, value: _core.Value) -> None: - """Unset the graph for the value.""" - raise NotImplementedError - - def append(self, item: _core.Value) -> None: - """Add a new input to the graph.""" - # Perform checks first in _set_graph before modifying the data structure - self._set_graph(item) - super().append(item) - self._check_invariance() - - def extend(self, other) -> None: - """Extend the list of inputs or outputs.""" - other = tuple(other) - for item in other: - self._set_graph(item) - super().extend(other) - - def insert(self, i: int, item: _core.Value) -> None: - """Insert an input/output to the graph.""" - super().insert(i, item) - self._set_graph(item) - self._check_invariance() - - def pop(self, i: int = -1) -> _core.Value: - """Remove an input/output from the graph.""" - value = super().pop(i) - self._maybe_unset_graph(value) - self._check_invariance() - return value - - def remove(self, item: _core.Value) -> None: - """Remove an input/output from the graph.""" - super().remove(item) - self._maybe_unset_graph(item) - self._check_invariance() - - def clear(self) -> None: - """Clear the list.""" - for value in self.data: - self._maybe_unset_graph(value) - super().clear() - - def __setitem__(self, i, item) -> None: - """Replace an input/output to the node.""" - if isinstance(item, Iterable) and isinstance(i, slice): - # Modify a slice of the list - for value in self.data[i]: - self._maybe_unset_graph(value) - for value in item: - self._set_graph(value) - super().__setitem__(i, item) - self._check_invariance() - return - elif isinstance(i, SupportsIndex): - # Replace a single item - self._maybe_unset_graph(self.data[i]) - self._set_graph(item) - super().__setitem__(i, item) - self._check_invariance() - return - - raise TypeError(f"Invalid types for __setitem__: {type(i)} and {type(item)}") - - def __getitem__(self, i): - """Get an input/output from the graph.""" - return self.data[i] - - def _unimplemented(self, *_args, **_kwargs): - """Unimplemented method.""" - raise RuntimeError("Method is not supported") - - __add__ = _unimplemented - __radd__ = _unimplemented - __iadd__ = _unimplemented - __mul__ = _unimplemented - __rmul__ = _unimplemented - copy = _unimplemented - - -class GraphInputs(_GraphIO): - """The inputs of a Graph.""" - - def _check_invariance(self) -> None: - """Check the invariance of the graph.""" - if not onnxscript.DEBUG: - return - for value in self.data: - if value._graph is self._graph: - continue - raise ValueError( - f"Invariance error: Value '{value}' is not an input of the graph: {self._graph!r}" - ) - - def _set_graph(self, value: _core.Value) -> None: - """Set the graph for the value.""" - if value._graph is not None and value._graph is not self._graph: - raise ValueError( - f"Value '{value}' is already owned by a different graph. Please remove the value from the previous graph first" - ) - self._ref_counter[value] += 1 - value._is_graph_input = True - value._graph = self._graph - - def _maybe_unset_graph(self, value: _core.Value) -> None: - """Unset the graph for the value.""" - assert value._graph is self._graph, "Bug: value does not belong to the graph" - self._ref_counter[value] -= 1 - if self._ref_counter[value] > 0: - # The value is still used by another graph input - return - value._is_graph_input = False - if value._owned_by_graph(): - # Keep the graph reference if the value is still an input or an initializer - return - value._graph = None - - -class GraphOutputs(_GraphIO): - """The outputs of a Graph.""" - - def _check_invariance(self) -> None: - """Check the invariance of the graph.""" - if not onnxscript.DEBUG: - return - for value in self.data: - if value._graph is self._graph: - continue - raise ValueError( - f"Invariance error: Value '{value}' is not an output of the graph: {self._graph!r}" - ) - - def _set_graph(self, value: _core.Value) -> None: - """Set the graph for the value.""" - if value._graph is not None and value._graph is not self._graph: - raise ValueError( - f"Value '{value}' is already an output of a different graph. Please remove the value from the previous graph first" - ) - self._ref_counter[value] += 1 - value._is_graph_output = True - value._graph = self._graph - - def _maybe_unset_graph(self, value: _core.Value) -> None: - """Unset the graph for the value.""" - assert value._graph is self._graph, "Bug: value does not belong to the graph" - self._ref_counter[value] -= 1 - if self._ref_counter[value] > 0: - # The value is still used by another graph input - return - value._is_graph_output = False - if value._owned_by_graph(): - # Keep the graph reference if the value is still an input or an initializer - return - value._graph = None - - -class GraphInitializers(collections.UserDict[str, "_core.Value"]): - """The initializers of a Graph.""" - - def __init__(self, graph: _core.Graph, dict=None, /, **kwargs): - # Perform checks first in _set_graph before modifying the data structure with super().__init__() - data = {} - if dict is not None: - data.update(dict) - if kwargs: - data.update(kwargs) - self._graph = graph - for value in data.values(): - self._set_graph(value) - - super().__init__(data) - - def _set_graph(self, value: _core.Value) -> None: - """Set the graph for the value.""" - if value._graph is not None and value._graph is not self._graph: - raise ValueError( - f"Value '{value}' is already an initializer of a different graph. Please remove the value from the previous graph first" - ) - value._is_initializer = True - value._graph = self._graph - - def _maybe_unset_graph(self, value: _core.Value) -> None: - """Unset the graph for the value.""" - assert value._graph is self._graph, "Bug: value does not belong to the graph" - value._is_initializer = False - if value._owned_by_graph(): - # Keep the graph reference if the value is still an input or an initializer - return - value._graph = None - - def __setitem__(self, key: str, value: _core.Value) -> None: - """Set an initializer for the graph.""" - if key != value.name: - raise ValueError( - f"Key '{key}' does not match the name of the value '{value.name}'" - ) - if not isinstance(key, str): - raise TypeError(f"Key must be a string, not {type(key)}") - if key in self.data: - # If the key already exists, unset the old value - old_value = self.data[key] - self._maybe_unset_graph(old_value) - # Must call _set_graph before super().__setitem__ so that when there is an error, - # the dictionary is not modified - self._set_graph(value) - super().__setitem__(key, value) - - def __delitem__(self, key: str) -> None: - """Delete an initializer from the graph.""" - value = self.data[key] - # Must call _maybe_unset_graph before super().__delitem__ so that when there is an error, - # the dictionary is not modified - self._maybe_unset_graph(value) - super().__delitem__(key) diff --git a/onnxscript/ir/_io.py b/onnxscript/ir/_io.py deleted file mode 100644 index a83cfdbd9d..0000000000 --- a/onnxscript/ir/_io.py +++ /dev/null @@ -1,97 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -"""Load and save ONNX models.""" - -from __future__ import annotations - -__all__ = ["load", "save"] - -import os - -import onnx - -from onnxscript.ir import _core, serde -from onnxscript.ir import external_data as _external_data -from onnxscript.ir._polyfill import zip - - -def load(path: str | os.PathLike, format: str | None = None) -> _core.Model: - """Load an ONNX model from a file. - - Args: - path: The path to the ONNX file. - format: The format of the file (e.g. protobuf, textproto, json, etc.). - If None, the format is inferred from the file extension. - - Returns: - The loaded model. - """ - # Do not use ONNX to load external data because the IR handles external data - # by doing memory mapping directly. - proto = onnx.load(path, format=format, load_external_data=False) - model = serde.deserialize_model(proto) - base_dir = os.path.dirname(path) - # Set the base directory for external data to the directory of the ONNX file - # so that relative paths are resolved correctly. - _external_data.set_base_dir(model.graph, base_dir) - return model - - -def save( - model: _core.Model, - path: str | os.PathLike, - format: str | None = None, - external_data: str | os.PathLike | None = None, - size_threshold_bytes: int = 256, -) -> None: - """Save an ONNX model to a file. - - The model remains unchanged after the call. If any existing external tensor - references the provided ``external_data`` path, it will be invalidated - after the external data is overwritten. To obtain a valid model, use :func:`load` - to load the newly saved model, or provide a different external data path that - is not currently referenced by any tensors in the model. - - Args: - model: The model to save. - path: The path to save the model to. E.g. "model.onnx". - format: The format of the file (e.g. ``protobuf``, ``textproto``, ``json``, etc.). - If None, the format is inferred from the file extension. - external_data: The relative path to save external data to. When specified, - all initializers in the model will be converted to external data and - saved to the specified directory. If None, all tensors will be saved unmodified. - That is, if a tensor in the model is already external, it will be saved - with the same external information; if the tensor is not external, - it will be serialized in the ONNX Proto message. - size_threshold_bytes: Save to external data if the tensor size in bytes is larger than this threshold. - Effective only when ``external_data`` is set. - - Raises: - ValueError: If the external data path is an absolute path. - """ - if external_data is not None: - if os.path.isabs(external_data): - raise ValueError( - f"The external data path must be relative to the ONNX file path, not '{external_data}'." - ) - base_dir = os.path.dirname(path) - - # Store the original initializer values so they can be restored if modify_model=False - initializer_values = tuple(model.graph.initializers.values()) - tensors = [v.const_value for v in initializer_values] - - try: - model = _external_data.unload_from_model( - model, base_dir, external_data, size_threshold_bytes=size_threshold_bytes - ) - proto = serde.serialize_model(model) - onnx.save(proto, path, format=format) - - finally: - # Restore the original initializer values so the model is unchanged - for initializer, tensor in zip(initializer_values, tensors, strict=True): - initializer.const_value = tensor - - else: - proto = serde.serialize_model(model) - onnx.save(proto, path, format=format) diff --git a/onnxscript/ir/_io_test.py b/onnxscript/ir/_io_test.py deleted file mode 100644 index 6473827bc6..0000000000 --- a/onnxscript/ir/_io_test.py +++ /dev/null @@ -1,144 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -"""Unit tests for the _io module.""" - -import os -import tempfile -import unittest - -import numpy as np - -from onnxscript import ir -from onnxscript.ir import _io - - -def _create_initializer(tensor: ir.TensorProtocol) -> ir.Value: - return ir.Value( - name=tensor.name, - shape=tensor.shape, - type=ir.TensorType(tensor.dtype), - const_value=tensor, - ) - - -def _create_simple_model_with_initializers() -> ir.Model: - tensor_0 = ir.tensor([0.0], dtype=ir.DataType.FLOAT, name="initializer_0") - initializer = _create_initializer(tensor_0) - tensor_1 = ir.tensor([1.0], dtype=ir.DataType.FLOAT) - identity_node = ir.Node("", "Identity", inputs=(initializer,)) - identity_node.outputs[0].shape = ir.Shape([1]) - identity_node.outputs[0].dtype = ir.DataType.FLOAT - identity_node.outputs[0].name = "identity_0" - const_node = ir.Node( - "", - "Constant", - inputs=(), - outputs=( - ir.Value(name="const_0", shape=tensor_1.shape, type=ir.TensorType(tensor_1.dtype)), - ), - attributes=ir.convenience.convert_attributes(dict(value=tensor_1)), - ) - graph = ir.Graph( - inputs=[initializer], - outputs=[*identity_node.outputs, *const_node.outputs], - nodes=[identity_node, const_node], - initializers=[initializer], - name="test_graph", - ) - return ir.Model(graph, ir_version=10) - - -class IOFunctionsTest(unittest.TestCase): - def test_load(self): - model = _create_simple_model_with_initializers() - with tempfile.TemporaryDirectory() as tmpdir: - path = os.path.join(tmpdir, "model.onnx") - _io.save(model, path) - loaded_model = _io.load(path) - self.assertEqual(loaded_model.ir_version, model.ir_version) - self.assertEqual(loaded_model.graph.name, model.graph.name) - self.assertEqual(len(loaded_model.graph.initializers), 1) - self.assertEqual(len(loaded_model.graph), 2) - np.testing.assert_array_equal( - loaded_model.graph.initializers["initializer_0"].const_value.numpy(), - np.array([0.0]), - ) - np.testing.assert_array_equal( - loaded_model.graph.node(1).attributes["value"].as_tensor().numpy(), np.array([1.0]) - ) - self.assertEqual(loaded_model.graph.inputs[0].name, "initializer_0") - self.assertEqual(loaded_model.graph.outputs[0].name, "identity_0") - self.assertEqual(loaded_model.graph.outputs[1].name, "const_0") - - def test_save_with_external_data_does_not_modify_model(self): - model = _create_simple_model_with_initializers() - self.assertIsInstance(model.graph.initializers["initializer_0"].const_value, ir.Tensor) - # There may be clean up errors on Windows, so we ignore them - with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdir: - path = os.path.join(tmpdir, "model.onnx") - external_data_file = "model.data" - _io.save(model, path, external_data=external_data_file, size_threshold_bytes=0) - self.assertTrue(os.path.exists(path)) - external_data_path = os.path.join(tmpdir, external_data_file) - self.assertTrue(os.path.exists(external_data_path)) - loaded_model = _io.load(path) - - # The loaded model contains external data - initializer_tensor = loaded_model.graph.initializers["initializer_0"].const_value - self.assertIsInstance(initializer_tensor, ir.ExternalTensor) - # The attribute is not externalized - const_attr_tensor = loaded_model.graph.node(1).attributes["value"].as_tensor() - self.assertIsInstance(const_attr_tensor, ir.TensorProtoTensor) - np.testing.assert_array_equal(initializer_tensor.numpy(), np.array([0.0])) - np.testing.assert_array_equal(const_attr_tensor.numpy(), np.array([1.0])) - - # The original model is not changed and can be accessed even if the - # external data file is deleted - initializer_tensor = model.graph.initializers["initializer_0"].const_value - self.assertIsInstance(initializer_tensor, ir.Tensor) - const_attr_tensor = model.graph.node(1).attributes["value"].as_tensor() - self.assertIsInstance(const_attr_tensor, ir.Tensor) - np.testing.assert_array_equal(initializer_tensor.numpy(), np.array([0.0])) - np.testing.assert_array_equal(const_attr_tensor.numpy(), np.array([1.0])) - - def test_save_raise_when_external_data_is_not_relative_path(self): - model = _create_simple_model_with_initializers() - with tempfile.TemporaryDirectory() as tmpdir: - path = os.path.join(tmpdir, "model.onnx") - external_data_file = os.path.join(tmpdir, "model.data") - with self.assertRaises(ValueError): - _io.save(model, path, external_data=external_data_file) - - def test_save_with_external_data_invalidates_obsolete_external_tensors(self): - model = _create_simple_model_with_initializers() - self.assertIsInstance(model.graph.initializers["initializer_0"].const_value, ir.Tensor) - with tempfile.TemporaryDirectory() as tmpdir: - path = os.path.join(tmpdir, "model.onnx") - external_data_file = "model.data" - _io.save(model, path, external_data=external_data_file, size_threshold_bytes=0) - loaded_model = _io.load(path) - # Now if we load the model back, create a different initializer and save - # the model to the same external data file, the existing external tensor - # should be invalidated - tensor_2 = ir.tensor([2.0], dtype=ir.DataType.FLOAT, name="initializer_2") - initializer_2 = _create_initializer(tensor_2) - loaded_model.graph.initializers["initializer_2"] = initializer_2 - _io.save( - loaded_model, path, external_data=external_data_file, size_threshold_bytes=0 - ) - initializer_0_tensor = loaded_model.graph.initializers["initializer_0"].const_value - self.assertIsInstance(initializer_0_tensor, ir.ExternalTensor) - self.assertFalse(initializer_0_tensor.valid()) - with self.assertRaisesRegex(ValueError, "is invalidated"): - # The existing model has to be modified to use in memory tensors - # for the values to stay correct. Saving again should raise an error - _io.save( - loaded_model, - path, - external_data=external_data_file, - size_threshold_bytes=0, - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/onnxscript/ir/_linked_list.py b/onnxscript/ir/_linked_list.py deleted file mode 100644 index fd425c505b..0000000000 --- a/onnxscript/ir/_linked_list.py +++ /dev/null @@ -1,283 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -"""Mutable list for nodes in a graph with safe mutation properties.""" - -from __future__ import annotations - -from typing import Generic, Iterable, Iterator, Sequence, TypeVar, overload - -T = TypeVar("T") - - -class _LinkBox(Generic[T]): - """A link in a doubly linked list that has a reference to the actual object in the link. - - The :class:`_LinkBox` is a container for the actual object in the list. It is used to - maintain the links between the elements in the linked list. The actual object is stored in the - :attr:`value` attribute. - - By using a separate container for the actual object, we can safely remove the object from the - list without losing the links. This allows us to remove the object from the list during - iteration and place the object into a different list without breaking any chains. - - This is an internal class and should only be initialized by the :class:`DoublyLinkedSet`. - - Attributes: - prev: The previous box in the list. - next: The next box in the list. - erased: A flag to indicate if the box has been removed from the list. - owning_list: The :class:`DoublyLinkedSet` to which the box belongs. - value: The actual object in the list. - """ - - __slots__ = ("next", "owning_list", "prev", "value") - - def __init__(self, owner: DoublyLinkedSet[T], value: T | None) -> None: - """Create a new link box. - - Args: - owner: The linked list to which this box belongs. - value: The value to be stored in the link box. When the value is None, - the link box is considered erased (default). The root box of the list - should be created with a None value. - """ - self.prev: _LinkBox[T] = self - self.next: _LinkBox[T] = self - self.value: T | None = value - self.owning_list: DoublyLinkedSet[T] = owner - - @property - def erased(self) -> bool: - return self.value is None - - def erase(self) -> None: - """Remove the link from the list and detach the value from the box.""" - if self.value is None: - raise ValueError("_LinkBox is already erased") - # Update the links - prev, next_ = self.prev, self.next - prev.next, next_.prev = next_, prev - # Detach the value - self.value = None - - def __repr__(self) -> str: - return f"_LinkBox({self.value!r}, erased={self.erased}, prev={self.prev.value!r}, next={self.next.value!r})" - - -class DoublyLinkedSet(Sequence[T], Generic[T]): - """A doubly linked ordered set of nodes. - - The container can be viewed as a set as it does not allow duplicate values. The order of the - elements is maintained. One can typically treat it as a doubly linked list with list-like - methods implemented. - - Adding and removing elements from the set during iteration is safe. Moving elements - from one set to another is also safe. - - During the iteration: - - If new elements are inserted after the current node, the iterator will - iterate over them as well. - - If new elements are inserted before the current node, they will - not be iterated over in this iteration. - - If the current node is lifted and inserted in a different location, - iteration will start from the "next" node at the _original_ location. - - Time complexity: - Inserting and removing nodes from the set is O(1). Accessing nodes by index is O(n), - although accessing nodes at either end of the set is O(1). I.e. - ``linked_set[0]`` and ``linked_set[-1]`` are O(1). - - Values need to be hashable. ``None`` is not a valid value in the set. - """ - - __slots__ = ("_length", "_root", "_value_ids_to_boxes") - - def __init__(self, values: Iterable[T] | None = None) -> None: - # Using the root node simplifies the mutation implementation a lot - # The list is circular. The root node is the only node that is not a part of the list values - root_ = _LinkBox(self, None) - self._root: _LinkBox = root_ - self._length = 0 - self._value_ids_to_boxes: dict[int, _LinkBox] = {} - if values is not None: - self.extend(values) - - def __iter__(self) -> Iterator[T]: - """Iterate over the elements in the list. - - - If new elements are inserted after the current node, the iterator will - iterate over them as well. - - If new elements are inserted before the current node, they will - not be iterated over in this iteration. - - If the current node is lifted and inserted in a different location, - iteration will start from the "next" node at the _original_ location. - """ - box = self._root.next - while box is not self._root: - if box.owning_list is not self: - raise RuntimeError(f"Element {box!r} is not in the list") - if not box.erased: - assert box.value is not None - yield box.value - box = box.next - - def __reversed__(self) -> Iterator[T]: - """Iterate over the elements in the list in reverse order.""" - box = self._root.prev - while box is not self._root: - if not box.erased: - assert box.value is not None - yield box.value - box = box.prev - - def __len__(self) -> int: - assert self._length == len(self._value_ids_to_boxes), ( - "Bug in the implementation: length mismatch" - ) - return self._length - - @overload - def __getitem__(self, index: int) -> T: ... - @overload - def __getitem__(self, index: slice) -> Sequence[T]: ... - - def __getitem__(self, index): - """Get the node at the given index. - - Complexity is O(n). - """ - if isinstance(index, slice): - return tuple(self)[index] - if index >= self._length or index < -self._length: - raise IndexError( - f"Index out of range: {index} not in range [-{self._length}, {self._length})" - ) - if index < 0: - # Look up from the end of the list - iterator = reversed(self) - item = next(iterator) - for _ in range(-index - 1): - item = next(iterator) - else: - iterator = iter(self) # type: ignore[assignment] - item = next(iterator) - for _ in range(index): - item = next(iterator) - return item - - def _insert_one_after( - self, - box: _LinkBox[T], - new_value: T, - ) -> _LinkBox[T]: - """Insert a new value after the given box. - - All insertion methods should call this method to ensure that the list is updated correctly. - - Example:: - Before: A <-> B <-> C - ^v0 ^v1 ^v2 - Call: _insert_one_after(B, v3) - After: A <-> B <-> new_box <-> C - ^v0 ^v1 ^v3 ^v2 - - Args: - box: The box which the new value is to be inserted. - new_value: The new value to be inserted. - """ - if new_value is None: - raise TypeError(f"{self.__class__.__name__} does not support None values") - if box.value is new_value: - # Do nothing if the new value is the same as the old value - return box - if box.owning_list is not self: - raise ValueError(f"Value {box.value!r} is not in the list") - - if (new_value_id := id(new_value)) in self._value_ids_to_boxes: - # If the value is already in the list, remove it first - self.remove(new_value) - - # Create a new _LinkBox for the new value - new_box = _LinkBox(self, new_value) - # original_box <=> original_next - # becomes - # original_box <=> new_box <=> original_next - original_next = box.next - box.next = new_box - new_box.prev = box - new_box.next = original_next - original_next.prev = new_box - - # Be sure to update the length and mapping - self._length += 1 - self._value_ids_to_boxes[new_value_id] = new_box - - return new_box - - def _insert_many_after( - self, - box: _LinkBox[T], - new_values: Iterable[T], - ): - """Insert multiple new values after the given box.""" - insertion_point = box - for new_value in new_values: - insertion_point = self._insert_one_after(insertion_point, new_value) - - def remove(self, value: T) -> None: - """Remove a node from the list.""" - if (value_id := id(value)) not in self._value_ids_to_boxes: - raise ValueError(f"Value {value!r} is not in the list") - box = self._value_ids_to_boxes[value_id] - # Remove the link box and detach the value from the box - box.erase() - - # Be sure to update the length and mapping - self._length -= 1 - del self._value_ids_to_boxes[value_id] - - def append(self, value: T) -> None: - """Append a node to the list.""" - _ = self._insert_one_after(self._root.prev, value) - - def extend( - self, - values: Iterable[T], - ) -> None: - for value in values: - self.append(value) - - def insert_after( - self, - value: T, - new_values: Iterable[T], - ) -> None: - """Insert new nodes after the given node. - - Args: - value: The value after which the new values are to be inserted. - new_values: The new values to be inserted. - """ - if (value_id := id(value)) not in self._value_ids_to_boxes: - raise ValueError(f"Value {value!r} is not in the list") - insertion_point = self._value_ids_to_boxes[value_id] - return self._insert_many_after(insertion_point, new_values) - - def insert_before( - self, - value: T, - new_values: Iterable[T], - ) -> None: - """Insert new nodes before the given node. - - Args: - value: The value before which the new values are to be inserted. - new_values: The new values to be inserted. - """ - if (value_id := id(value)) not in self._value_ids_to_boxes: - raise ValueError(f"Value {value!r} is not in the list") - insertion_point = self._value_ids_to_boxes[value_id].prev - return self._insert_many_after(insertion_point, new_values) - - def __repr__(self) -> str: - return f"DoublyLinkedSet({list(self)})" diff --git a/onnxscript/ir/_linked_list_test.py b/onnxscript/ir/_linked_list_test.py deleted file mode 100644 index ead022bf2e..0000000000 --- a/onnxscript/ir/_linked_list_test.py +++ /dev/null @@ -1,387 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -"""Unit tests for the _linked_list module.""" - -from __future__ import annotations - -import unittest - -import parameterized - -from onnxscript.ir import _linked_list - - -class _TestElement: - def __init__(self, value): - self.value = value - - def __repr__(self) -> str: - return f"_TestElement({self.value})" - - -class DoublyLinkedSetTest(unittest.TestCase): - def test_empty_list(self): - linked_list = _linked_list.DoublyLinkedSet() - self.assertEqual(len(linked_list), 0) - self.assertEqual(list(linked_list), []) - self.assertEqual(list(reversed(linked_list)), []) - with self.assertRaises(IndexError): - _ = linked_list[0] - with self.assertRaises(IndexError): - _ = linked_list[-1] - - def test_append_single_element(self): - linked_list = _linked_list.DoublyLinkedSet() - elem = _TestElement(0) - linked_list.append(elem) - - self.assertEqual(len(linked_list), 1) - self.assertEqual(linked_list[0], elem) - self.assertEqual(linked_list[-1], elem) - self.assertEqual(list(linked_list), [elem]) - self.assertEqual(list(reversed(linked_list)), [elem]) - with self.assertRaises(IndexError): - _ = linked_list[1] - with self.assertRaises(IndexError): - _ = linked_list[-2] - - def test_append_multiple_elements(self): - linked_list = _linked_list.DoublyLinkedSet() - elems = [_TestElement(i) for i in range(3)] - for elem in elems: - linked_list.append(elem) - - self.assertEqual(len(linked_list), 3) - self.assertEqual(linked_list[0], elems[0]) - self.assertEqual(linked_list[1], elems[1]) - self.assertEqual(linked_list[2], elems[2]) - self.assertEqual(linked_list[-1], elems[2]) - self.assertEqual(linked_list[-2], elems[1]) - self.assertEqual(linked_list[-3], elems[0]) - self.assertEqual(list(linked_list), elems) - self.assertEqual(list(reversed(linked_list)), list(reversed(elems))) - - def test_extend(self): - elems = [_TestElement(i) for i in range(3)] - linked_list = _linked_list.DoublyLinkedSet(elems) - self.assertEqual(len(linked_list), 3) - self.assertEqual(linked_list[0], elems[0]) - self.assertEqual(linked_list[1], elems[1]) - self.assertEqual(linked_list[2], elems[2]) - self.assertEqual(linked_list[-1], elems[2]) - self.assertEqual(linked_list[-2], elems[1]) - self.assertEqual(linked_list[-3], elems[0]) - self.assertEqual(list(linked_list), elems) - self.assertEqual(list(reversed(linked_list)), list(reversed(elems))) - - @parameterized.parameterized.expand( - [ - ("single_element", [0], 0, [1], [0, 1]), - ("single_element_negative_index", [0], -1, [1], [0, 1]), - ("multiple_elements", [0], 0, [1, 2], [0, 1, 2]), - ("multiple_elements_negative_index", [0], -1, [1, 2], [0, 1, 2]), - ( - "multiple_original_elements_insert_at_start", - [0, 1, 2], - 0, - [42, 43], - [0, 42, 43, 1, 2], - ), - ( - "multiple_original_elements_insert_at_middle", - [0, 1, 2], - 1, - [42, 43], - [0, 1, 42, 43, 2], - ), - ( - "multiple_original_elements_insert_at_end", - [0, 1, 2], - 2, - [42, 43], - [0, 1, 2, 42, 43], - ), - ] - ) - def test_insert_after( - self, _: str, original: list[int], location: int, insertion: list[int], expected: list - ) -> None: - # Construct the original list - elems = [_TestElement(i) for i in original] - linked_list = _linked_list.DoublyLinkedSet(elems) - - # Create the new elements - new_elements = [_TestElement(i) for i in insertion] - linked_list.insert_after(elems[location], new_elements) - - # Check the list - self.assertEqual(len(linked_list), len(expected)) - self.assertEqual([elem.value for elem in linked_list], expected) - - @parameterized.parameterized.expand( - [ - ("single_element", [0], 0, [1], [1, 0]), - ("single_element_negative_index", [0], -1, [1], [1, 0]), - ("multiple_elements", [0], 0, [1, 3], [1, 3, 0]), - ("multiple_elements_negative_index", [0], -1, [1, 3], [1, 3, 0]), - ( - "multiple_original_elements_insert_at_start", - [0, 1, 2], - 0, - [42, 43], - [42, 43, 0, 1, 2], - ), - ( - "multiple_original_elements_insert_at_middle", - [0, 1, 2], - 1, - [42, 43], - [0, 42, 43, 1, 2], - ), - ( - "multiple_original_elements_insert_at_end", - [0, 1, 2], - 2, - [42, 43], - [0, 1, 42, 43, 2], - ), - ] - ) - def test_insert_before( - self, _: str, original: list[int], location: int, insertion: list[int], expected: list - ) -> None: - # Construct the original list - elems = [_TestElement(i) for i in original] - linked_list = _linked_list.DoublyLinkedSet(elems) - - # Create the new elements - new_elements = [_TestElement(i) for i in insertion] - linked_list.insert_before(elems[location], new_elements) - - # Check the list - self.assertEqual(len(linked_list), len(expected)) - self.assertEqual([elem.value for elem in linked_list], expected) - self.assertEqual([elem.value for elem in reversed(linked_list)], expected[::-1]) - - @parameterized.parameterized.expand( - [ - ("start", 0, [1, 2]), - ("middle", 1, [0, 2]), - ("end", 2, [0, 1]), - ("start_negative", -1, [0, 1]), - ("middle_negative", -2, [0, 2]), - ("end_negative", -3, [1, 2]), - ] - ) - def test_remove(self, _: str, index: int, expected: list[int]) -> None: - elems = [_TestElement(i) for i in range(3)] - linked_list = _linked_list.DoublyLinkedSet(elems) - - linked_list.remove(elems[index]) - - self.assertEqual(len(linked_list), 2) - self.assertEqual([elem.value for elem in linked_list], expected) - self.assertEqual([elem.value for elem in reversed(linked_list)], expected[::-1]) - - def test_remove_raises_when_element_not_found(self) -> None: - elems = [_TestElement(i) for i in range(3)] - linked_list = _linked_list.DoublyLinkedSet(elems) - - with self.assertRaises(ValueError): - linked_list.remove(_TestElement(3)) - - def test_remove_raises_when_element_is_already_removed(self) -> None: - linked_list = _linked_list.DoublyLinkedSet() - elem = _TestElement(0) - linked_list.append(elem) - linked_list.remove(elem) - - with self.assertRaises(ValueError): - linked_list.remove(elem) - - def test_append_self_does_nothing(self) -> None: - linked_list = _linked_list.DoublyLinkedSet() - elem = _TestElement(0) - linked_list.append(elem) - - linked_list.append(elem) - - self.assertEqual(len(linked_list), 1) - self.assertEqual(linked_list[0], elem) - self.assertEqual(list(linked_list), [elem]) - self.assertEqual(list(reversed(linked_list)), [elem]) - - def test_append_supports_appending_element_from_the_same_list(self) -> None: - elems = [_TestElement(i) for i in range(3)] - linked_list = _linked_list.DoublyLinkedSet(elems) - - linked_list.append(elems[1]) - - self.assertEqual(len(linked_list), 3) - self.assertEqual([elem.value for elem in linked_list], [0, 2, 1]) - self.assertEqual([elem.value for elem in reversed(linked_list)], [1, 2, 0]) - - def test_extend_supports_extending_elements_from_the_same_list(self) -> None: - elems = [_TestElement(i) for i in range(3)] - linked_list = _linked_list.DoublyLinkedSet(elems) - linked_list.extend(elems[::-1]) - - self.assertEqual(len(linked_list), 3) - self.assertEqual([elem.value for elem in linked_list], [2, 1, 0]) - self.assertEqual([elem.value for elem in reversed(linked_list)], [0, 1, 2]) - - def test_insert_after_supports_inserting_element_from_the_same_list(self) -> None: - elems = [_TestElement(i) for i in range(3)] - linked_list = _linked_list.DoublyLinkedSet(elems) - linked_list.insert_after(elems[0], [elems[2]]) - - self.assertEqual(len(linked_list), 3) - self.assertEqual([elem.value for elem in linked_list], [0, 2, 1]) - - def test_insert_before_supports_inserting_element_from_the_same_list(self) -> None: - elems = [_TestElement(i) for i in range(3)] - linked_list = _linked_list.DoublyLinkedSet(elems) - linked_list.insert_before(elems[0], [elems[2]]) - - self.assertEqual(len(linked_list), 3) - self.assertEqual([elem.value for elem in linked_list], [2, 0, 1]) - - def test_iterator_supports_mutation_during_iteration_current_element(self) -> None: - elems = [_TestElement(i) for i in range(3)] - linked_list = _linked_list.DoublyLinkedSet(elems) - for elem in linked_list: - if elem.value == 1: - linked_list.remove(elem) - - self.assertEqual(len(linked_list), 2) - self.assertEqual([elem.value for elem in linked_list], [0, 2]) - self.assertEqual([elem.value for elem in reversed(linked_list)], [2, 0]) - - def test_iterator_supports_mutation_during_iteration_previous_element(self) -> None: - elems = [_TestElement(i) for i in range(3)] - linked_list = _linked_list.DoublyLinkedSet(elems) - for elem in linked_list: - if elem.value == 1: - linked_list.remove(elem) - linked_list.remove(elems[0]) - - self.assertEqual(len(linked_list), 1) - self.assertEqual([elem.value for elem in linked_list], [2]) - self.assertEqual([elem.value for elem in reversed(linked_list)], [2]) - - def test_iterator_supports_mutation_during_iteration_next_element(self) -> None: - elems = [_TestElement(i) for i in range(3)] - linked_list = _linked_list.DoublyLinkedSet(elems) - for elem in linked_list: - if elem.value == 1: - linked_list.remove(elems[2]) - linked_list.remove(elem) - - self.assertEqual(len(linked_list), 1) - self.assertEqual([elem.value for elem in linked_list], [0]) - self.assertEqual([elem.value for elem in reversed(linked_list)], [0]) - - def test_iterator_supports_mutation_in_nested_iteration_right_of_iterator(self) -> None: - elems = [_TestElement(i) for i in range(3)] - linked_list = _linked_list.DoublyLinkedSet(elems) - iter1_visited = [] - iter2_visited = [] - for elem in linked_list: - iter1_visited.append(elem.value) - for elem2 in linked_list: - iter2_visited.append(elem2.value) - if elem2.value == 1: - linked_list.remove(elem2) - - self.assertEqual(len(linked_list), 2) - self.assertEqual(iter1_visited, [0, 2]) - self.assertEqual(iter2_visited, [0, 1, 2, 0, 2]) - self.assertEqual([elem.value for elem in linked_list], [0, 2]) - self.assertEqual([elem.value for elem in reversed(linked_list)], [2, 0]) - - def test_iterator_supports_mutation_in_nested_iteration_when_iter_is_self(self) -> None: - elems = [_TestElement(i) for i in range(3)] - linked_list = _linked_list.DoublyLinkedSet(elems) - iter1_visited = [] - iter2_visited = [] - for elem in linked_list: - iter1_visited.append(elem.value) - for elem2 in linked_list: - iter2_visited.append(elem2.value) - if elem2.value == 0: # Remove the element the current iterator points to - linked_list.remove(elem2) - - self.assertEqual(len(linked_list), 2) - self.assertEqual(iter1_visited, [0, 1, 2]) - self.assertEqual(iter2_visited, [0, 1, 2, 1, 2, 1, 2]) - self.assertEqual([elem.value for elem in linked_list], [1, 2]) - self.assertEqual([elem.value for elem in reversed(linked_list)], [2, 1]) - - def test_iterator_supports_mutation_in_nested_iteration_left_of_iterator(self) -> None: - elems = [_TestElement(i) for i in range(3)] - linked_list = _linked_list.DoublyLinkedSet(elems) - iter1_visited = [] - iter2_visited = [] - for elem in linked_list: - iter1_visited.append(elem.value) - for elem2 in linked_list: - iter2_visited.append(elem2.value) - if ( - elem.value == 1 and elem2.value == 0 - ): # Remove the element before the current iterator points to - linked_list.remove(elems[0]) - - self.assertEqual(len(linked_list), 2) - self.assertEqual(iter1_visited, [0, 1, 2]) - self.assertEqual(iter2_visited, [0, 1, 2, 0, 1, 2, 1, 2]) - self.assertEqual([elem.value for elem in linked_list], [1, 2]) - self.assertEqual([elem.value for elem in reversed(linked_list)], [2, 1]) - - def test_insert_after_supports_element_from_different_list_during_iteration(self) -> None: - elems = [_TestElement(i) for i in range(3)] - linked_list = _linked_list.DoublyLinkedSet(elems) - other_linked_list = _linked_list.DoublyLinkedSet() - other_elem = _TestElement(42) - other_linked_list.append(other_elem) - - for elem in linked_list: - if elem.value == 1: - linked_list.insert_after(elem, [other_elem]) - - self.assertEqual(len(linked_list), 4) - self.assertEqual([elem.value for elem in linked_list], [0, 1, 42, 2]) - self.assertEqual([elem.value for elem in reversed(linked_list)], [2, 42, 1, 0]) - # Other list remains unchanged - self.assertEqual(len(other_linked_list), 1) - self.assertEqual([elem.value for elem in other_linked_list], [42]) - - def test_insert_after_supports_taking_elements_from_another_doubly_linked_list( - self, - ) -> None: - elems = [_TestElement(i) for i in range(3)] - linked_list = _linked_list.DoublyLinkedSet(elems) - other_linked_list = _linked_list.DoublyLinkedSet() - other_elem = _TestElement(42) - other_linked_list.append(other_elem) - - linked_list.insert_after(elems[1], other_linked_list) - - self.assertEqual(len(linked_list), 4) - self.assertEqual([elem.value for elem in linked_list], [0, 1, 42, 2]) - self.assertEqual([elem.value for elem in reversed(linked_list)], [2, 42, 1, 0]) - # Other list remains unchanged - self.assertEqual(len(other_linked_list), 1) - self.assertEqual([elem.value for elem in other_linked_list], [42]) - - @parameterized.parameterized.expand( - [(s, t, p) for s in [-2, 0, 2, 3] for t in [2, -1, -2] for p in [-3, -1, 1, 2]] - ) - def test_get_item_slice(self, start, stop, step): - elems = [_TestElement(i) for i in range(5)] - linked_list = _linked_list.DoublyLinkedSet(elems) - self.assertEqual(len(linked_list), 5) - self.assertEqual(list(linked_list[start:stop:step]), elems[start:stop:step]) - - -if __name__ == "__main__": - unittest.main() diff --git a/onnxscript/ir/_metadata.py b/onnxscript/ir/_metadata.py deleted file mode 100644 index 77db7cc410..0000000000 --- a/onnxscript/ir/_metadata.py +++ /dev/null @@ -1,44 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -"""Class for storing metadata about the IR objects.""" - -from __future__ import annotations - -import collections -from typing import Any, Mapping - - -class MetadataStore(collections.UserDict): - """Class for storing metadata about the IR objects. - - Metadata is stored as key-value pairs. The keys are strings and the values - can be any Python object. - - The metadata store also supports marking keys as invalid. This is useful - when a pass wants to mark a key that needs to be recomputed. - """ - - def __init__(self, data: Mapping[str, Any] | None = None, /) -> None: - super().__init__(data) - self._invalid_keys: set[str] = set() - - def __setitem__(self, key: str, item: Any) -> None: - self.data[key] = item - self._invalid_keys.discard(key) - - def invalidate(self, key: str) -> None: - self._invalid_keys.add(key) - - def is_valid(self, key: str) -> bool: - """Returns whether the value is valid. - - Note that default values (None) are not necessarily invalid. For example, - a shape that is unknown (None) may be still valid if shape inference has - determined that the shape is unknown. - - Whether a value is valid is solely determined by the user that sets the value. - """ - return key not in self._invalid_keys - - def __repr__(self) -> str: - return f"{self.__class__.__name__}({self.data!r}, invalid_keys={self._invalid_keys!r})" diff --git a/onnxscript/ir/_name_authority.py b/onnxscript/ir/_name_authority.py deleted file mode 100644 index ab12be532d..0000000000 --- a/onnxscript/ir/_name_authority.py +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -"""Auxiliary class for managing names in the IR.""" - -from __future__ import annotations - -from onnxscript.ir import _core - - -class NameAuthority: - """Class for giving names to values and nodes in the IR. - - The names are generated in the format ``val_{value_counter}`` for values and - ``node_{op_type}_{node_counter}`` for nodes. The counter is incremented each time - a new value or node is named. - - This class keeps tracks of the names it has generated and existing names - in the graph to prevent producing duplicated names. - - .. note:: - Once a name is tracked, it will not be made available even if the node/value - is removed from the graph. It is possible to improve this behavior by keeping - track of the names that are no longer used, but it is not implemented yet. - - However, if a value/node is already named when added to the graph, - the name authority will not change its name. - It is the responsibility of the user to ensure that the names are unique - (typically by running a name-fixing pass on the graph). - - TODO(justichuby): Describe the pass when we have a reference implementation. - """ - - def __init__(self): - self._value_counter = 0 - self._node_counter = 0 - self._value_names: set[str] = set() - self._node_names: set[str] = set() - - def _unique_value_name(self) -> str: - """Generate a unique name for a value.""" - while True: - name = f"val_{self._value_counter}" - self._value_counter += 1 - if name not in self._value_names: - return name - - def _unique_node_name(self, op_type: str) -> str: - """Generate a unique name for a node.""" - while True: - name = f"node_{op_type}_{self._node_counter}" - self._node_counter += 1 - if name not in self._node_names: - return name - - def register_or_name_value(self, value: _core.Value) -> None: - # TODO(justinchuby): Record names of the initializers and graph inputs - if value.name is None: - value.name = self._unique_value_name() - # If the name is already specified, we do not change it because keeping - # track of the used names can be costly when nodes can be removed from the graph: - # How do we know if a name is no longer used? We cannot reserve unused names - # because users may want to use them. - self._value_names.add(value.name) - - def register_or_name_node(self, node: _core.Node) -> None: - if node.name is None: - node.name = self._unique_node_name(node.op_type) - # If the name is already specified, we do not change it because keeping - # track of the used names can be costly when nodes can be removed from the graph: - # How do we know if a name is no longer used? We cannot reserve unused names - # because users may want to use them. - self._node_names.add(node.name) diff --git a/onnxscript/ir/_name_authority_test.py b/onnxscript/ir/_name_authority_test.py deleted file mode 100644 index 1a0fed80cb..0000000000 --- a/onnxscript/ir/_name_authority_test.py +++ /dev/null @@ -1,24 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -import unittest - -from onnxscript import ir -from onnxscript.ir import _name_authority - - -class NameAuthorityTest(unittest.TestCase): - def test_register_or_name_value(self): - name_authority = _name_authority.NameAuthority() - value = ir.Value() - name_authority.register_or_name_value(value) - self.assertEqual(value.name, "val_0") - - def test_register_or_name_node(self): - name_authority = _name_authority.NameAuthority() - node = ir.Node("", "Test", []) - name_authority.register_or_name_node(node) - self.assertEqual(node.name, "node_Test_0") - - -if __name__ == "__main__": - unittest.main() diff --git a/onnxscript/ir/_polyfill.py b/onnxscript/ir/_polyfill.py deleted file mode 100644 index fb6008db37..0000000000 --- a/onnxscript/ir/_polyfill.py +++ /dev/null @@ -1,25 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -"""Polyfill for Python builtin functions.""" - -import sys -from typing import Any, Sequence - -if sys.version_info >= (3, 10): - zip = zip # pylint: disable=self-assigning-variable -else: - # zip(..., strict=True) was added in Python 3.10 - # TODO: Remove this polyfill when we drop support for Python 3.9 - _python_zip = zip - - def zip(a: Sequence[Any], b: Sequence[Any], strict: bool = False): - """Polyfill for Python's zip function. - - This is a special version which only supports two Sequence inputs. - - Raises: - ValueError: If the iterables have different lengths and strict is True. - """ - if len(a) != len(b) and strict: - raise ValueError("zip() argument lengths must be equal") - return _python_zip(a, b) diff --git a/onnxscript/ir/_protocols.py b/onnxscript/ir/_protocols.py deleted file mode 100644 index fbc2c7c054..0000000000 --- a/onnxscript/ir/_protocols.py +++ /dev/null @@ -1,610 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -"""Protocols for the ONNX IR. - -This file defines the interfaces for tools to interact with the IR. The interfaces -are designed such that tools leveraging the IR can be decoupled from the IR -implementation. This allows for the implementation to evolve independently of the -tools. -""" - -# 👀 -# NOTE: Why are we using protocols, instead of abstract base classes? -# -# Protocols are more flexible than abstract base classes. Users can define their -# own classes that implement the protocols without having to inherit from a -# specific base class. For example, a user can define a custom tensor class that -# implements the TensorProtocol without explicitly inheriting, and the IR can -# work with that class without any changes. -# -# `isinstance` checks can be slower with protocols. Avoid using `isinstance` -# checks when you can. Always check for concrete classes first. -# -# NOTE: Why are we using protocols, instead of using concrete classes directly? -# -# Protocols define the interface that is typically more stable. If you find yourself -# updating the protocols, pause 🛑, and carefully make sure it is absolutely needed -# and will improve the design. If you are adding new methods, consider if the method -# should be part of the protocol or if it should be a higher level convenience function -# defined outside the protocol. - -from __future__ import annotations - -import typing -from typing import ( - Any, - Collection, - Iterable, - Iterator, - Mapping, - MutableMapping, - MutableSequence, - OrderedDict, - Protocol, - Sequence, - Tuple, -) - -from onnxscript.ir import _enums - -if typing.TYPE_CHECKING: - import numpy as np - from typing_extensions import TypeAlias - -# An identifier that will uniquely identify an operator. E.g (domain, op_type, overload) -OperatorIdentifier: TypeAlias = Tuple[str, str, str] - - -@typing.runtime_checkable -class ArrayCompatible(Protocol): - """Protocol for array-like objects. - - An example of an array-like object is a numpy ndarray or a PyTorch Tensor. - Read more at https://numpy.org/devdocs/user/basics.interoperability.html - """ - - def __array__(self, dtype: Any) -> np.ndarray: ... - - -@typing.runtime_checkable -class DLPackCompatible(Protocol): - """Protocol for objects that can support dlpack. - - Computation backends can call __dlpack__ to obtain the underlying data in a - tensor without copying the data. This allows use to use tensorflow tensors etc. - without copying the data. - """ - - def __dlpack__(self, *, stream: Any = ...) -> Any: - """Return PyCapsule.""" - ... - - def __dlpack_device__(self) -> Any: - """Return the device.""" - ... - - -@typing.runtime_checkable -class TensorProtocol(ArrayCompatible, DLPackCompatible, Protocol): - """Concrete tensor backed by data. - - The protocol does not specify how the data is stored. That data is exposed - through the :attr:`raw` attribute for examination, but accessing :attr:`raw` - is typically not needed. - - To use the tensor as a numpy array, call :meth:`numpy`. To convert the tensor - to a byte string for serialization, call :meth:`tobytes`. - - It is recommended to check the size of the tensor first before accessing the - underlying data, because accessing the data may be expensive and incur IO - overhead. - - Attributes: - name: The name of the tensor. - shape: The shape of the tensor. - dtype: The data type of the elements of the tensor. It is an :class:`ir.DataType` enum. - doc_string: Documentation string. - raw: The raw data behind this tensor. It can be anything. - size: The number of elements in the tensor. - nbytes: The number of bytes in the tensor. - metadata_props: Metadata that will be serialized to the ONNX file. - meta: Metadata store for graph transform passes. - """ - - name: str | None - shape: ShapeProtocol - dtype: _enums.DataType - doc_string: str | None - raw: Any - metadata_props: MutableMapping[str, str] - meta: MutableMapping[str, Any] - - @property - def size(self) -> int: ... - - @property - def nbytes(self) -> int: ... - - def numpy(self) -> np.ndarray: - """Return the tensor as a numpy array.""" - ... - - def __array__(self, dtype: Any = None) -> np.ndarray: - """Return the tensor as a numpy array, compatible with np.array.""" - ... - - def __dlpack__(self, *, stream: Any = ...) -> Any: - """Return PyCapsule.""" - ... - - def __dlpack_device__(self) -> Any: - """Return the device.""" - ... - - def tobytes(self) -> bytes: - """Return the tensor as a byte string conformed to the ONNX specification, in little endian.""" - ... - - -@typing.runtime_checkable -class ValueProtocol(Protocol): - """Protocol for values. - - A value is a named entity that can be used to represent an input or output of a graph, - a function, or a node. The information it stores generalizes over ``ValueInfoProto`` - in the ONNX specification. - - A :class:`Value` is always not owned or owned by exactly one node. When the value is not - owned, it must be an input of a graph or a function. ``producer`` and ``index`` - are ``None``. - - When the value is owned by a node, it is an output of the node. - The node that produces the value can be accessed with :meth:`producer`. - The index of the output of the node that produces the value can be accessed with - :meth:`index`. - - To find all the nodes that use this value as an input, call :meth:`uses`. - - To check if the value is an output of a graph, call :meth:`is_graph_output`. - - Attributes: - name: The name of the value. A value is always named when it is part of a graph. - shape: The shape of the value. - type: The type of the value. - metadata_props: Metadata that will be serialized to the ONNX file. - meta: Metadata store for graph transform passes. - doc_string: Documentation string. - const_value: The constant tensor is the value constant. - """ - - name: str - shape: ShapeProtocol | None - type: TypeProtocol | None - metadata_props: MutableMapping[str, str] - meta: MutableMapping[str, Any] - doc_string: str | None - const_value: TensorProtocol | None - - def producer(self) -> NodeProtocol | None: - """The node that produces this value.""" - ... - - def index(self) -> int | None: - """The index of the output of the node that produces this value.""" - ... - - def uses(self) -> Collection[tuple[NodeProtocol, int]]: - """The set of (node, input_index) with node being those that use this value as an input.""" - ... - - def is_graph_output(self) -> bool: - """Whether this value is an output of a graph.""" - ... - - -@typing.runtime_checkable -class NodeProtocol(Protocol): - """Protocol for nodes. - - A node represents an invocation of an operation on the :class:`Value` s in - the computational graph. - - A node can be optionally named. A name should typically be assigned when the - node is added to a graph. - - :attr:`domain`, :attr:`op_type`, and :attr:`overload` together uniquely identify - the operator, and are always strings. For ONNX operators, :attr:`domain` and :attr:`overload` - are both empty strings. - - :attr:`inputs` and :attr:`outputs` are the input and output values of the node. - - :attr:`attributes` are the attributes of the node. The attributes are stored in an - ordered dictionary to preserve the order of the attributes. This is a deviation from - the current ONNX spec where attributes are unordered, but it is helpful for tools - that rely on the order of the attributes, e.g. those converting to and from Python - function keyword arguments. - - :attr:`version` is unique to the IR and is not specified in the ONNX spec. This - allows the IR to represent a graph with mixed opset versions. Deserializers - should decide how to reconcile the different versions within the graph. A typical - graph will have a single version, declared in the :class:`Graph` object and - the nodes will have ``None`` as the version. - - Attributes: - domain: The domain of the operator. E.g. ``""`` for ONNX operators. - op_type: The operator name. - overload: The overload name when the node is invoking a function. - inputs: Input values. - outputs: Output values. - attributes: The attributes of the operator. - version: The version of the operator. - doc_string: Documentation string. - metadata_props: Metadata that will be serialized to the ONNX file. - meta: Metadata store for graph transform passes. - """ - - name: str | None - domain: str - op_type: str - overload: str - inputs: Sequence[ValueProtocol] - outputs: Sequence[ValueProtocol] - attributes: OrderedDict[str, AttributeProtocol | ReferenceAttributeProtocol] - version: int | None - doc_string: str | None - metadata_props: MutableMapping[str, str] - meta: MutableMapping[str, Any] - - def replace_input_with(self, index: int, value: ValueProtocol | None) -> None: - """Set the input at the given index to the given value, replacing the original value.""" - ... - - -@typing.runtime_checkable -class GraphProtocol(Protocol): - """Protocol for graphs. - - Graph represents a computation graph. In addition to the ONNX specification - specified fields, it also contains a mapping of :attr:`opset_imports`. This - allows different subgraphs to import different opsets. It is the responsibility - of the deserializer to reconcile the different opsets. - - The nodes are not guaranteed to be topologically sorted. But the - iteration order should be deterministic across different runs. It is the - responsibility of the user to maintain a topological order of the nodes. - - Note that there is not a ``node`` attribute in the Graph. The Graph can be - seen as a Sequence of nodes and should be used as such. For example, to obtain - all nodes as a list, call ``list(graph)``. - - .. :note:: - ``quantization_annotation`` is deserialized into the Value's ``meta`` field - under the ``quant_parameter_tensor_names`` key. Values that are stored - under this key will be serialized as quantization annotations. - - Attributes: - name: The name of the graph. - inputs: The input values of the graph. - outputs: The output values of the graph. - initializers: The initializers in the graph. - doc_string: Documentation string. - opset_imports: Opsets imported by the graph. - metadata_props: Metadata that will be serialized to the ONNX file. - meta: Metadata store for graph transform passes. - """ - - name: str | None - inputs: MutableSequence[ValueProtocol] - outputs: MutableSequence[ValueProtocol] - initializers: MutableMapping[str, ValueProtocol] - doc_string: str - opset_imports: MutableMapping[str, int] - metadata_props: MutableMapping[str, str] - meta: MutableMapping[str, Any] - - def __getitem__(self, index: int) -> NodeProtocol: ... - def __len__(self) -> int: ... - def __iter__(self) -> Iterator[NodeProtocol]: ... - def __reversed__(self) -> Iterator[NodeProtocol]: ... - - # Mutation methods - def append(self, node: NodeProtocol, /) -> None: - """Append a node to the graph.""" - ... - - def extend(self, nodes: Iterable[NodeProtocol], /) -> None: - """Extend the graph with the given nodes.""" - ... - - def remove(self, node: NodeProtocol, /) -> None: - """Remove a node from the graph.""" - ... - - def insert_after( - self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol] | NodeProtocol, / - ) -> None: - """Insert new nodes after the given node.""" - ... - - def insert_before( - self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol] | NodeProtocol, / - ) -> None: - """Insert new nodes before the given node.""" - ... - - def sort(self) -> None: - """Topologically sort the nodes in the graph.""" - ... - - -@typing.runtime_checkable -class GraphViewProtocol(Protocol): - """Protocol for a read-only view on a graph. - - The GraphView is useful for analysis of a subgraph. It can be initialized - with a subset of nodes from a :class:`Graph`. Creating GraphView does not - change the ownership of the nodes, and so it is possible to create multiple - GraphViews that contain the same nodes. - - Attributes: - name: The name of the graph. - inputs: The input values of the graph. - outputs: The output values of the graph. - initializers: The initializers in the graph. - doc_string: Documentation string. - opset_imports: Opsets imported by the graph. - metadata_props: Metadata that will be serialized to the ONNX file. - meta: Metadata store for graph transform passes. - """ - - name: str | None - inputs: Sequence[ValueProtocol] - outputs: Sequence[ValueProtocol] - initializers: Mapping[str, ValueProtocol] - doc_string: str - opset_imports: Mapping[str, int] - metadata_props: MutableMapping[str, str] - meta: MutableMapping[str, Any] - - def __getitem__(self, index: int) -> NodeProtocol: ... - def __len__(self) -> int: ... - def __iter__(self) -> Iterator[NodeProtocol]: ... - def __reversed__(self) -> Iterator[NodeProtocol]: ... - - -@typing.runtime_checkable -class ModelProtocol(Protocol): - """Protocol for models. - - A model is a container for a graph and metadata. It is the top-level object - that represents an ONNX model. - - Attributes: - graph: The graph of the model. - ir_version: The version of the IR. - producer_name: The name of the producer. - producer_version: The version of the producer. - domain: The domain of the model. - model_version: The version of the model. - doc_string: Documentation string. - functions: The functions defined in the model. - metadata_props: Metadata that will be serialized to the ONNX file. - meta: Metadata store for graph transform passes. - """ - - graph: GraphProtocol - ir_version: int - producer_name: str | None - producer_version: str | None - domain: str | None - model_version: int | None - doc_string: str | None - functions: MutableMapping[str, FunctionProtocol] - # TODO(justinchuby): Add training_info - opset_imports: MutableMapping[str, int] - metadata_props: MutableMapping[str, str] - meta: MutableMapping[str, Any] - - -@typing.runtime_checkable -class AttributeProtocol(Protocol): - """Protocol for ONNX attributes. - - Attributes: - name: The name of the attribute. - type: The type of the attribute. - value: The value of the attribute. - doc_string: Documentation string. - """ - - name: str - type: _enums.AttributeType - value: Any - doc_string: str | None - - -@typing.runtime_checkable -class ReferenceAttributeProtocol(Protocol): - """Protocol for a reference attribute. - - A reference attribute can only appear inside the definition body of a function. - - Attributes: - name: The name of the attribute. - ref_attr_name: The name of the attribute definition this attribute refers to. - type: The type of the attribute. - doc_string: Documentation string. - """ - - name: str - ref_attr_name: str - type: _enums.AttributeType - doc_string: str | None - - -@typing.runtime_checkable -class SparseTensorProtocol(Protocol): - values: TensorProtocol - indices: TensorProtocol - dims: Sequence[int] - - -@typing.runtime_checkable -class SymbolicDimProtocol(Protocol): - """Value of a single symbolic/dynamic dimension in a shape. - - Attributes: - value: The value of the dimension. - """ - - value: str | None # TODO(justinchuby): Maybe support sympy - - -@typing.runtime_checkable -class ShapeProtocol(Protocol): - """Protocol for ONNX shapes. - - A shape is a sequence of dimensions. - - Attributes: - dims: The dimensions of the shape. - """ - - dims: Sequence[int | SymbolicDimProtocol] - - def __len__(self) -> int: ... - def __iter__(self) -> Iterator[int | SymbolicDimProtocol]: ... - @typing.overload - def __getitem__(self, index: int) -> int | SymbolicDimProtocol: ... - @typing.overload - def __getitem__(self, index: slice) -> tuple[int | SymbolicDimProtocol, ...]: ... - def __setitem__( - self, index: int, value: int | SymbolicDimProtocol | str | None - ) -> None: ... - def __eq__(self, other: object) -> bool: ... - def __ne__(self, value: object) -> bool: ... - def get_denotation(self, index: int) -> str | None: ... - def set_denotation(self, index: int, denotation: str | None) -> None: ... - def numpy(self) -> Sequence[int]: ... - def rank(self) -> int: ... - - -@typing.runtime_checkable -class TypeProtocol(Protocol): - """Protocol for ONNX tensors, Sequence tensors, Optional tensors and Sparse tensors. - - These three types of tensors share the same attribute "elem_type" so they are - merged in the same interface. Unlike the ONNX TensorProto, shapes are not included - in the type and should be stored in the :class:`Value`. - - Attributes: - denotation: An optional denotation can be used to denote the whole - type with a standard semantic description as to what is - stored inside. - Refer to https://github.com/onnx/onnx/blob/main/docs/TypeDenotation.md#type-denotation-definition - for pre-defined type denotations. - elem_type: The type of its elements for nested types like Sequence[Optional] tensors. - Or the DataType if the type is not nested. - dtype: The data type of the tensor or the nested tensor. - """ - - denotation: str | None - elem_type: TypeProtocol | _enums.DataType - dtype: _enums.DataType - - def __eq__(self, value: object, /) -> bool: ... - - -@typing.runtime_checkable -class MapTypeProtocol(Protocol): - """Protocol for ONNX map types. - - TODO: This protocol is not yet implemented in the ONNX IR. - """ - - key_type: typing.Literal[ - _enums.DataType.STRING, - _enums.DataType.INT64, - _enums.DataType.INT32, - _enums.DataType.INT16, - _enums.DataType.INT8, - _enums.DataType.UINT64, - _enums.DataType.UINT32, - _enums.DataType.UINT16, - _enums.DataType.UINT8, - ] - value_type: _enums.DataType - - -@typing.runtime_checkable -class FunctionProtocol(Protocol): - """Protocol for ONNX functions. - - Like a graph, a function can have nodes that are not topologically sorted. It is - the responsibility of the user to maintain a topological order of the nodes. - - Note that there is not a ``node`` attribute in the Function. The Function can be - seen as a Sequence of nodes and should be used as such. For example, to obtain - all nodes as a list, call ``list(function)``. - - Attributes: - name: The function name. - domain: The domain this function is defined in. - overload: The overload name when the function is overloaded. - inputs: The input values of the function. - attributes: The attributes this function defines. - outputs: The output values of the function. - opset_imports: Opsets imported by the function. - doc_string: Documentation string. - metadata_props: Metadata that will be serialized to the ONNX file. - meta: Metadata store for graph transform passes. - """ - - name: str - domain: str - overload: str - inputs: Sequence[ValueProtocol] - attributes: OrderedDict[str, AttributeProtocol] - outputs: Sequence[ValueProtocol] - doc_string: str - opset_imports: MutableMapping[str, int] - metadata_props: MutableMapping[str, str] - meta: MutableMapping[str, Any] - - def __getitem__(self, index: int) -> NodeProtocol: ... - def __len__(self) -> int: ... - def __iter__(self) -> Iterator[NodeProtocol]: ... - def __reversed__(self) -> Iterator[NodeProtocol]: ... - def identifier(self) -> OperatorIdentifier: - """Return the unique identifier of the function.""" - ... - - # Mutation methods - # End Block - def append(self, node: NodeProtocol, /) -> None: - """Append a node to the function.""" - ... - - def extend(self, nodes: Iterable[NodeProtocol], /) -> None: - """Extend the function with the given nodes.""" - ... - - def remove(self, node: NodeProtocol, /) -> None: - """Remove a node from the function.""" - ... - - def insert_after( - self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol] | NodeProtocol, / - ) -> None: - """Insert new nodes after the given node.""" - ... - - def insert_before( - self, node: NodeProtocol, new_nodes: Iterator[NodeProtocol] | NodeProtocol, / - ) -> None: - """Insert new nodes before the given node.""" - ... - - def sort(self) -> None: - """Topologically sort the nodes in the function.""" - ... diff --git a/onnxscript/ir/_tape.py b/onnxscript/ir/_tape.py deleted file mode 100644 index fbcfcb428a..0000000000 --- a/onnxscript/ir/_tape.py +++ /dev/null @@ -1,213 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -"""Convenience methods for constructing the IR.""" - -from __future__ import annotations - -from typing import ( - Any, - Mapping, - Optional, - Sequence, - Tuple, -) - -from onnxscript import ir -from onnxscript.ir import _convenience - -# A type representing the domains/versions used in creating nodes in IR. -UsedOpsets = set[Tuple[str, Optional[int]]] - - -class Tape: - """Tape class. - - A tape is a recorder that collects nodes and initializers that are created so - that they can be used for creating a graph. - - Example:: - - from onnxscript import ir - - tape = ir.tape.Tape() - a = tape.initializer(ir.tensor([1, 2, 3], name="a")) - b: ir.Value = ... - c: ir.Value = ... - x = tape.op("Add", [a, b], attributes={"alpha": 1.0}) - y = tape.op("Mul", [x, c], attributes={"beta": 2.0}) - model = ir.Model( - graph := ir.Graph( - inputs=[b, c], - outputs=[y], - nodes=tape.nodes, - initializers=tape.initializers - opset_imports={"": 20}, - ), - ir_version=10, - ) - - Attributes: - graph_like: The graph to append the new nodes and initializers to. When - it is None, the nodes and initializers are creating without owned by a graph. - Initializers will not be added to functions because it is not supported by ONNX. - """ - - def __init__(self, graph_like: ir.Graph | ir.Function | None = None) -> None: - self._nodes: list[ir.Node] = [] - self._initializers: list[ir.Value] = [] - self._used_opsets: UsedOpsets = set() - self.graph_like = graph_like - - def __repr__(self) -> str: - return f"Tape(nodes={self._nodes}, initializers={self._initializers})" - - @property - def nodes(self) -> Sequence[ir.Node]: - return tuple(self._nodes) - - @property - def initializers(self) -> Sequence[ir.Value]: - return tuple(self._initializers) - - @property - def used_opsets(self) -> UsedOpsets: - return self._used_opsets - - def op( - self, - op_type: str, - inputs: Sequence[ir.Value | None], - attributes: Mapping[str, _convenience.SupportedAttrTypes] | None = None, - *, - domain: str = "", - overload: str = "", - version: int | None = None, - graph: ir.Graph | None = None, - name: str | None = None, - doc_string: str | None = None, - metadata_props: dict[str, str] | None = None, - output: ir.Value | None = None, - ) -> ir.Value: - if attributes is None: - attrs: Sequence[ir.Attr | ir.RefAttr] = () - else: - attrs = _convenience.convert_attributes(attributes) - output_kwargs: dict[str, Any] - if output is None: - output_kwargs = dict(num_outputs=1) - else: - output_kwargs = dict(outputs=[output]) - node = ir.Node( - domain, - op_type, - inputs, - attributes=attrs, - **output_kwargs, - overload=overload, - version=version, - graph=graph or self.graph_like, - name=name, - doc_string=doc_string, - metadata_props=metadata_props, - ) - self._nodes.append(node) - self._used_opsets.add((domain, version)) - - return node.outputs[0] - - def op_multi_out( - self, - op_type: str, - inputs: Sequence[ir.Value | None], - attributes: Mapping[str, _convenience.SupportedAttrTypes] | None = None, - *, - num_outputs: int | None = None, - outputs: Sequence[ir.Value] | None = None, - domain: str = "", - overload: str = "", - version: int | None = None, - graph: ir.Graph | None = None, - name: str | None = None, - doc_string: str | None = None, - metadata_props: dict[str, str] | None = None, - ) -> Sequence[ir.Value]: - if num_outputs is None and outputs is None: - raise ValueError("Either num_outputs or outputs must be provided.") - if num_outputs is not None and outputs is not None: - raise ValueError("Both num_outputs and outputs cannot be provided simultaneously.") - output_kwargs: dict[str, Any] - if outputs is None: - output_kwargs = dict(num_outputs=num_outputs) - else: - output_kwargs = dict(outputs=outputs) - if attributes is None: - attrs: Sequence[ir.Attr | ir.RefAttr] = () - else: - attrs = _convenience.convert_attributes(attributes) - node = ir.Node( - domain, - op_type, - inputs, - attributes=attrs, - **output_kwargs, - overload=overload, - version=version, - graph=graph or self.graph_like, - name=name, - doc_string=doc_string, - metadata_props=metadata_props, - ) - self._nodes.append(node) - self._used_opsets.add((domain, version)) - - return node.outputs - - def initializer(self, tensor: ir.TensorProtocol, name: str | None = None) -> ir.Value: - name = name or tensor.name - if name is None: - raise ValueError("Name must be provided for initializer.") - shape = ir.Shape((d if isinstance(d, int) else d.value) for d in tensor.shape.dims) - value = ir.Value( - name=name, shape=shape, type=ir.TensorType(tensor.dtype), const_value=tensor - ) - self._initializers.append(value) - if isinstance(self.graph_like, ir.Graph): - self.graph_like.register_initializer(value) - return value - - -class Builder(Tape): - """An extension of the tape that provides a more convenient API for constructing the IR.""" - - def __getattr__(self, op_type: str) -> Any: - return lambda *args, **kwargs: self._make_node(op_type, args, kwargs) - - def _make_node(self, op_type: str, inputs: Sequence[ir.Value], kwargs: dict[str, Any]): - domain = kwargs.pop("_domain", "") - version = kwargs.pop("_version", None) - outputs = kwargs.pop("_outputs", 1) - if isinstance(outputs, Sequence): - num_outputs = len(outputs) - else: - assert isinstance(outputs, int) - num_outputs = outputs - - if num_outputs == 1: - value = super().op( - op_type, inputs=inputs, attributes=kwargs, domain=domain, version=version - ) - if isinstance(outputs, Sequence): - value.name = outputs[0] - return value - values = super().op_multi_out( - op_type, - inputs=inputs, - attributes=kwargs, - domain=domain, - version=version, - num_outputs=num_outputs, - ) - if isinstance(outputs, Sequence): - for value, name in zip(values, outputs): - value.name = name - return values diff --git a/onnxscript/ir/_tape_test.py b/onnxscript/ir/_tape_test.py deleted file mode 100644 index 46cbcc23fe..0000000000 --- a/onnxscript/ir/_tape_test.py +++ /dev/null @@ -1,76 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -from __future__ import annotations - -import unittest - -from onnxscript import ir - - -class TestTape(unittest.TestCase): - def test_op(self): - # Create a simple ONNX model with shape inference - # Define the model - inputs = [ - ir.Value( - name="input_a", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((1, 2)) - ), - ir.Value( - name="input_b", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((1, 2)) - ), - ] - - tape = ir.tape.Tape() - - _ = tape.op("Add", inputs=inputs) - - self.assertEqual([n.op_type for n in tape.nodes], ["Add"]) - - def test_initializers(self): - inputs = [ - ir.Value( - name="input_a", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((1, 2)) - ), - ir.Value( - name="input_b", - type=ir.TensorType(ir.DataType.FLOAT), - shape=ir.Shape((2, 1)), - const_value=ir.tensor([[42]] * 2, dtype=ir.DataType.FLOAT), - ), - ] - - tape = ir.tape.Tape() - - # Shape and type are not explicitly set for the initializer but it should still work - initializer = tape.initializer( - ir.tensor([[2, 3]], dtype=ir.DataType.FLOAT), name="initializer" - ) - val_add = tape.op("Add", inputs=inputs) - _ = tape.op("Mul", inputs=[val_add, initializer]) - - self.assertEqual([n.op_type for n in tape.nodes], ["Add", "Mul"]) - self.assertEqual(tape.initializers, (initializer,)) - - def test_op_multi_out(self): - inputs = [ - ir.Value( - name="input_a", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((1, 2)) - ), - ir.Value( - name="input_b", - type=ir.TensorType(ir.DataType.FLOAT), - shape=ir.Shape((2, 1)), - const_value=ir.tensor([[42]] * 2, dtype=ir.DataType.FLOAT), - ), - ] - - tape = ir.tape.Tape() - - out1, out2, out3 = tape.op_multi_out("SomeOp", inputs=inputs, num_outputs=3) # pylint: disable=unbalanced-tuple-unpacking - _ = tape.op("SomeOtherOp", inputs=[out1, out2, out3]) - - self.assertEqual([n.op_type for n in tape.nodes], ["SomeOp", "SomeOtherOp"]) - - -if __name__ == "__main__": - unittest.main() diff --git a/onnxscript/ir/_type_casting.py b/onnxscript/ir/_type_casting.py deleted file mode 100644 index 20bab69037..0000000000 --- a/onnxscript/ir/_type_casting.py +++ /dev/null @@ -1,106 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -"""Numpy utilities for non-native type operation.""" -# TODO(justinchuby): Upstream the logic to onnx - -from __future__ import annotations - -import typing -from typing import Sequence - -import ml_dtypes -import numpy as np - -if typing.TYPE_CHECKING: - import numpy.typing as npt - - -def pack_int4(array: np.ndarray) -> npt.NDArray[np.uint8]: - """Convert a numpy array to flatten, packed int4/uint4. Elements must be in the correct range.""" - # Create a 1D copy - array_flat = array.ravel().view(np.uint8).copy() - size = array.size - odd_sized = size % 2 == 1 - if odd_sized: - array_flat.resize([size + 1], refcheck=False) - array_flat &= 0x0F - array_flat[1::2] <<= 4 - return array_flat[0::2] | array_flat[1::2] # type: ignore[return-type] - - -def _unpack_uint4_as_uint8( - data: npt.NDArray[np.uint8], dims: Sequence[int] -) -> npt.NDArray[np.uint8]: - """Convert a packed uint4 array to unpacked uint4 array represented as uint8. - - Args: - data: A numpy array. - dims: The dimensions are used to reshape the unpacked buffer. - - Returns: - A numpy array of int8/uint8 reshaped to dims. - """ - result = np.empty([data.size * 2], dtype=data.dtype) - array_low = data & np.uint8(0x0F) - array_high = data & np.uint8(0xF0) - array_high >>= np.uint8(4) - result[0::2] = array_low - result[1::2] = array_high - if result.size == np.prod(dims) + 1: - # handle single-element padding due to odd number of elements - result = result[:-1] - result.resize(dims, refcheck=False) - return result - - -def unpack_uint4( - data: npt.NDArray[np.uint8], dims: Sequence[int] -) -> npt.NDArray[ml_dtypes.uint4]: - """Convert a packed uint4 array to unpacked uint4 array represented as uint8. - - Args: - data: A numpy array. - dims: The dimensions are used to reshape the unpacked buffer. - - Returns: - A numpy array of int8/uint8 reshaped to dims. - """ - return _unpack_uint4_as_uint8(data, dims).view(ml_dtypes.uint4) - - -def _extend_int4_sign_bits(x: npt.NDArray[np.uint8]) -> npt.NDArray[np.int8]: - """Extend 4-bit signed integer to 8-bit signed integer.""" - return np.where((x >> 3) == 0, x, x | 0xF0).astype(np.int8) - - -def unpack_int4( - data: npt.NDArray[np.uint8], dims: Sequence[int] -) -> npt.NDArray[ml_dtypes.int4]: - """Convert a packed (signed) int4 array to unpacked int4 array represented as int8. - - The sign bit is extended to the most significant bit of the int8. - - Args: - data: A numpy array. - dims: The dimensions are used to reshape the unpacked buffer. - - Returns: - A numpy array of int8 reshaped to dims. - """ - unpacked = _unpack_uint4_as_uint8(data, dims) - return _extend_int4_sign_bits(unpacked).view(ml_dtypes.int4) - - -def unpack_float4e2m1( - data: npt.NDArray[np.uint8], dims: Sequence[int] -) -> npt.NDArray[ml_dtypes.float4_e2m1fn]: - """Convert a packed float4e2m1 array to unpacked float4e2m1 array. - - Args: - data: A numpy array. - dims: The dimensions are used to reshape the unpacked buffer. - - Returns: - A numpy array of float32 reshaped to dims. - """ - return _unpack_uint4_as_uint8(data, dims).view(ml_dtypes.float4_e2m1fn) diff --git a/onnxscript/ir/_type_casting_test.py b/onnxscript/ir/_type_casting_test.py deleted file mode 100644 index abe4923eea..0000000000 --- a/onnxscript/ir/_type_casting_test.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -import unittest - -import numpy as np -import parameterized - -from onnxscript.ir import _type_casting - - -class TypeCastingTest(unittest.TestCase): - @parameterized.parameterized.expand( - [ - ("signed", np.int8), - ("unsigned", np.uint8), - ] - ) - def test_pack_int4_even_sized_array(self, _: str, dtype): - array = np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=dtype) - expected = np.array([0x21, 0x43, 0x65, 0x87], dtype=np.uint8) - actual = _type_casting.pack_int4(array) - np.testing.assert_array_equal(actual, expected) - - @parameterized.parameterized.expand( - [ - ("signed", np.int8), - ("unsigned", np.uint8), - ] - ) - def test_pack_int4_odd_sized_array(self, _: str, dtype): - array = np.array([1, 2, 3, 4, 5], dtype=dtype) - expected = np.array([0x21, 0x43, 0x5], dtype=np.uint8) - actual = _type_casting.pack_int4(array) - np.testing.assert_array_equal(actual, expected) - - @parameterized.parameterized.expand( - [ - ("signed", np.int8), - ("unsigned", np.uint8), - ] - ) - def test_pack_int4_returns_flatten_array(self, _: str, dtype): - array = np.array([[[1, 2, 3, 4, 5]]], dtype=dtype) - expected = np.array([0x21, 0x43, 0x5], dtype=np.uint8) - actual = _type_casting.pack_int4(array) - np.testing.assert_array_equal(actual, expected) - - -if __name__ == "__main__": - unittest.main(verbosity=2) diff --git a/onnxscript/ir/convenience.py b/onnxscript/ir/convenience.py deleted file mode 100644 index 480ff603b0..0000000000 --- a/onnxscript/ir/convenience.py +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -"""Convenience methods for constructing and manipulating the IR.""" - -from __future__ import annotations - -__all__ = [ - "convert_attribute", - "convert_attributes", - "replace_all_uses_with", - "replace_nodes_and_values", - "create_value_mapping", -] - -from onnxscript.ir._convenience import ( - convert_attribute, - convert_attributes, - create_value_mapping, - replace_all_uses_with, - replace_nodes_and_values, -) - -# NOTE: Do not implement any other functions in this module. -# implement them in the _convenience module and import them here instead. - - -def __set_module() -> None: - """Set the module of all functions in this module to this public module.""" - global_dict = globals() - for name in __all__: - global_dict[name].__module__ = __name__ - - -__set_module() diff --git a/onnxscript/ir/external_data.py b/onnxscript/ir/external_data.py deleted file mode 100644 index 4ca9ca5036..0000000000 --- a/onnxscript/ir/external_data.py +++ /dev/null @@ -1,396 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -"""External data related utilities.""" - -from __future__ import annotations - -__all__ = [ - "set_base_dir", - "unload_from_model", - "load_to_model", - "convert_tensors_to_external", - "convert_tensors_from_external", -] - -import dataclasses -import logging -import os -from typing import Iterator, Sequence - -from onnxscript.ir import _core, _enums, _protocols -from onnxscript.ir import traversal as _traversal -from onnxscript.ir._polyfill import zip - -# Note: If needed in future, add these as parameters to the function calls -# align_offset: Offset will always be page aligned and alloction granularity aligned for mmap support. This is done by padding previous tensor data with zeros keeping same length. Tensor data will be aligned if > align_threshold -_ALIGN_OFFSET = True -# align_threshold: Alignment threshold for size of data. Having a low threshold will waste file space for small initializers. Only when tensor's data is > the page_align_threshold it will be force aligned. -_ALIGN_THRESHOLD = 1048576 # 1MB -# allocation_granularity: The allocation Granularity for mmap() support. Typically 64KB for Windows & 4KB for other OSes. -_ALLOCATION_GRANULARITY = 65536 # 64KB - - -logger = logging.getLogger(__name__) - - -@dataclasses.dataclass -class _ExternalDataInfo: - """ - A class that stores information about a tensor that is to be stored as external data. - - Attributes: - name: The name of the tensor that is to be stored as external data. - offset: The offset is used to determine where exactly in the file the external data is written to. - length: Stores the size of the tensor. - """ - - name: str | None - offset: int - length: int - - -def _all_tensors( - graph: _core.Graph | _core.GraphView, include_attributes: bool = False -) -> Iterator[_protocols.TensorProtocol]: - """Iterate over all tensors in the graph. - - Args: - graph: The graph to traverse tensors on. - include_attributes: Whether to include tensors in attributes. - - Yields: - Tensors in the graph. - """ - # Yield all tensors in initializers - for value in graph.initializers.values(): - if value.const_value is not None: - yield value.const_value - if not include_attributes: - return - # Look at constant attributes in nodes - for node in _traversal.RecursiveGraphIterator(graph): - for attr in node.attributes.values(): - if isinstance(attr, _core.RefAttr): - continue - if attr.type == _enums.AttributeType.TENSOR and attr.value is not None: - yield attr.value - elif attr.type == _enums.AttributeType.TENSORS and attr.value is not None: - yield from attr.value - - -def set_base_dir(graph: _core.Graph | _core.GraphView, base_dir: str | os.PathLike) -> None: - """Set the base directory for external data in a graph. - - Args: - graph: The graph to traverse tensors on. - base_dir: The base directory. This is the directory where the ONNX file is. - """ - for tensor in _all_tensors(graph, include_attributes=True): - if isinstance(tensor, _core.ExternalTensor): - tensor.base_dir = base_dir - - -def _external_tensor_to_memory_tensor( - tensor: _protocols.TensorProtocol, -) -> _protocols.TensorProtocol: - """Convert an external tensor to an in memory tensor. - - Args: - tensor: An external tensor to load. - base_dir: Path of base directory. - relative_path: Path to which external data is to be stored, relative to the ONNX file. - - Returns: - An ir.Tensor object with the data loaded into memory. - """ - if not isinstance(tensor, _core.ExternalTensor): - raise TypeError(f"Expected ExternalTensor, got {type(tensor)}") - # Copy the data as the .numpy() call references data from a file whose data is eventually modified - tensor_data = tensor.numpy().copy() - tensor.release() - return _core.Tensor(tensor_data, name=tensor.name, dtype=tensor.dtype) - - -def _compute_new_offset( - current_offset: int, - tensor_size: int, - align_offset: bool = _ALIGN_OFFSET, - align_threshold: int = _ALIGN_THRESHOLD, - allocation_granularity: int = _ALLOCATION_GRANULARITY, -) -> int: - """Compute the offset to align the tensor data based on the current offset. - - Args: - current_offset: Current location in the file at which tensor data will be written to. - tensor_size: Size of the tensor data to be written to file. - align_offset: Offset will always be page aligned and alloction granularity aligned for mmap support. This is done by padding previous tensor data with zeros keeping same length. Tensor data will be aligned if > align_threshold - align_threshold: Alignment threshold for size of data. Having a low threshold will waste file space for small initializers. Only when tensor's data is > the page_align_threshold it will be force aligned. - allocation_granularity: The allocation Granularity for mmap() support. Typically 64KB for Windows & 4KB for other OSes. - - Returns: - The updated offset value. - """ - if align_offset and tensor_size > align_threshold: - alignment_factor = max(4096, allocation_granularity) - # Align to the next page or alloc granularity - return (current_offset + alignment_factor - 1) // alignment_factor * alignment_factor - return current_offset - - -def _compute_external_data_info( - tensor: _protocols.TensorProtocol, - current_offset: int, -) -> _ExternalDataInfo: - """Capture information about a tensor that is to be stored as external data.""" - tensor_size = tensor.nbytes - # Calculate updated offset and align tensors - current_offset = _compute_new_offset(current_offset, tensor_size) - # Store offset and tensor size as ExternalDataInfo - external_data_info = _ExternalDataInfo( - tensor.name, - current_offset, - tensor_size, - ) - return external_data_info - - -def _write_external_data( - tensors: Sequence[_protocols.TensorProtocol], - external_data_infos: Sequence[_ExternalDataInfo], - file_path: str | os.PathLike, -) -> None: - """Write tensor data to an external file according to information stored in ExternalDataInfo objects. - - Args: - tensors: Tensors to be written as external data. - external_data_infos: External data information stored for each tensor to be written as external data. - file_path: Location to which external data is to be stored. - """ - assert len(tensors) == len(external_data_infos), ( - "Number of tensors and external data infos should match" - ) - with open(file_path, "wb") as data_file: - for tensor, tensor_info in zip(tensors, external_data_infos, strict=True): - current_offset = tensor_info.offset - assert tensor is not None - raw_data = tensor.tobytes() - if isinstance(tensor, _core.ExternalTensor): - tensor.release() - # Pad file to required offset if needed - file_size = data_file.tell() - if current_offset > file_size: - data_file.write(b"\0" * (current_offset - file_size)) - data_file.write(raw_data) - - -def _create_external_tensor( - tensor: _protocols.TensorProtocol, - external_data_info: _ExternalDataInfo, - base_dir: str | os.PathLike, - relative_path: str | os.PathLike, -) -> _core.ExternalTensor: - """Create external tensors from external data information. - - Args: - tensor: Tensor to be converted to external tensor. - external_data_info: External data information stored for the tensor to be written as external data. - base_dir: Path of base directory. - relative_path: Path to which external data is to be stored, relative to the ONNX file. - - Returns: - External tensor created from the information. - """ - return _core.ExternalTensor( - os.path.normpath(relative_path), - external_data_info.offset, - external_data_info.length, - tensor.dtype, # type: ignore[arg-type] - shape=tensor.shape, # type: ignore[arg-type] - name=tensor.name, # type: ignore[arg-type] - base_dir=os.path.normpath(base_dir), - ) - - -def convert_tensors_from_external( - tensors: Sequence[_protocols.TensorProtocol], -) -> list[_protocols.TensorProtocol]: - """Convert a sequence of external tensors to in-memory tensors. - - Args: - tensors: External tensors to be converted to in-memory tensors. - - Returns: - A list of in-memory tensors derived from a list of external tensors. - """ - return [_external_tensor_to_memory_tensor(tensor) for tensor in tensors] - - -def convert_tensors_to_external( - tensors: Sequence[_protocols.TensorProtocol], - base_dir: str | os.PathLike, - relative_path: str | os.PathLike, -) -> list[_core.ExternalTensor]: - """Convert a sequence of any TensorProtocol tensors to external tensors. - - Existing external tensors are loaded to memory if they are referring to the - same file path as the destination path. - - Args: - tensors: Tensors to be converted to external tensors. They can be external tensors themselves. - base_dir: Path of base directory. - relative_path: Path to which external data is to be stored, relative to the ONNX file. - - Returns: - A list of external tensors derived from a list of input tensors. The order - should match the input tensor order. - """ - path = os.path.join(base_dir, relative_path) - - # Check if output path exists. Load pre-existing external data if it does. - if os.path.exists(path): - # Check if any tensor provided is using the destination file - new_tensors = [] - for tensor in tensors: - if ( - isinstance(tensor, _core.ExternalTensor) - and os.path.exists(tensor.path) - and os.path.samefile(path, tensor.path) - ): - # FIXME(shubhambhokare1): If there is a non-initializer tensor that - # is referring to this file, that tensor is now invalid. - # This is a special case we are ok not handling right now. - new_tensors.append(_external_tensor_to_memory_tensor(tensor)) - # Mark the original external tensor as invalid because it is now pointing - # to a file that is going to be overwritten. - tensor.invalidate() - logger.warning( - "External tensor %s is referring to the same file as the destination path. " - "It has been invalidated because the data file is changed. To avoid this, " - "save the external data to a different path or load the newly saved model back " - "with ir.load().", - tensor, - ) - else: - new_tensors.append(tensor) - tensors = new_tensors - - external_data_infos: list[_ExternalDataInfo] = [] - # Sort all tensors based on tensor sizes, in order to avoid unnecessary alignment. - # All the smaller tensors are written earlier and alignment is performed for the larger tensors. - sorted_indices = sorted(range(len(tensors)), key=lambda i: tensors[i].nbytes) - sorted_tensors = [tensors[i] for i in sorted_indices] - - # Compute external data information for each tensor and write to disk - current_offset = 0 - for tensor in sorted_tensors: - external_info = _compute_external_data_info(tensor, current_offset) - external_data_infos.append(external_info) - current_offset = external_info.offset + external_info.length - _write_external_data(sorted_tensors, external_data_infos, path) - - # Create external tensor objects - external_tensors: list[_core.ExternalTensor] = [ - _create_external_tensor(tensor, external_info, base_dir, relative_path) - for tensor, external_info in zip(sorted_tensors, external_data_infos, strict=True) - ] - - # Sort external_tensors based on original key order. So that it can match the input tensor order - external_tensors = [ - external_tensors[i] - for i in sorted(range(len(external_tensors)), key=lambda i: sorted_indices[i]) - ] - - return external_tensors - - -def load_to_model(model: _core.Model) -> _core.Model: - """Convert all external model initializers to memory tensors in-place. - - Args: - model: Model to process. - """ - # TODO(justinchuby): Load attributes and initializers in subgraphs - values_to_convert = [] - for value in model.graph.initializers.values(): - if value.const_value is None: - # Filter out the uninitialized initializer values - continue - if isinstance(value.const_value, _core.ExternalTensor): - values_to_convert.append(value) - loaded_tensors = convert_tensors_from_external( - [v.const_value for v in values_to_convert] # type: ignore[misc] - ) - for value, tensor in zip(values_to_convert, loaded_tensors, strict=True): - value.const_value = tensor - - # Return the model because we may change the implementation to an out of place one - # to keep the input unchanged - return model - - -def unload_from_model( - model: _core.Model, - base_dir: str | os.PathLike, - relative_path: str | os.PathLike, - *, - size_threshold_bytes: int = 0, -) -> _core.Model: - """Convert all initializers equal or above size_threshold_bytes to external tensors in-place and save data to a single data file. - - It should only replace the initializers in the model with external tensors - and not make any other modifications to the model. - - If any existing external tensor - references the provided ``external_data`` path, it will be invalidated - after the external data is overwritten. To obtain a valid model, use :func:`load` - to load the newly saved model, or provide a different external data path that - is not currently referenced by any tensors in the model. - - Args: - model: Model to process. - base_dir: Path the directory where the ONNX model file is. - relative_path: Path to which external data is to be stored, relative to the ONNX file. - E.g. "model.data" - size_threshold_bytes: Save to external data if the tensor size in bytes is larger than this threshold. - - Returns: - An ir.Model with all initializer data equal or above ``size_threshold_bytes`` - converted to external tensors. - """ - # In-memory or external tensors, if equal to or above the threshold, should be converted to or re-saved as external tensors - initializers_to_become_external = [] - # Existing external tensors, if below the threshold, should be loaded to memory - initializers_to_load_to_memory = [] - for value in model.graph.initializers.values(): - if value.const_value is None: - # Filter out the uninitialized initializer values - continue - if value.const_value.nbytes > size_threshold_bytes: - initializers_to_become_external.append(value) - elif isinstance(value.const_value, _core.ExternalTensor): - initializers_to_load_to_memory.append(value) - - # Load to memory first, then convert to external tensors, because - # the existing external tensors may be overwritten by the new external data - memory_tensors = convert_tensors_from_external( - [v.const_value for v in initializers_to_load_to_memory] # type: ignore[misc] - ) - external_tensors = convert_tensors_to_external( - [v.const_value for v in initializers_to_become_external], # type: ignore[misc] - base_dir=base_dir, - relative_path=relative_path, - ) - - # Replace the initializer values with external tensors and save the model - for value, external_tensor in zip( - initializers_to_become_external, external_tensors, strict=True - ): - value.const_value = external_tensor - for value, memory_tensor in zip( - initializers_to_load_to_memory, memory_tensors, strict=True - ): - value.const_value = memory_tensor - - # Return the model because we may change the implementation to an out of place one - # to keep the input unchanged - return model diff --git a/onnxscript/ir/external_data_test.py b/onnxscript/ir/external_data_test.py deleted file mode 100644 index 11de6285c9..0000000000 --- a/onnxscript/ir/external_data_test.py +++ /dev/null @@ -1,502 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -import os -import sys -import tempfile -import typing -import unittest - -import numpy as np -import onnx -import onnx.external_data_helper - -from onnxscript import ir -from onnxscript.ir import external_data - - -class ExternalDataTest(unittest.TestCase): - def test_set_base_dir_sets_base_dir_for_all_external_tensors(self): - attr_tensor = onnx.helper.make_tensor( - name="test_constant", - data_type=onnx.TensorProto.FLOAT, - dims=[1], - vals=b"\x01\x00\x00\x00", - raw=True, - ) - graph = onnx.helper.make_graph( - nodes=[ - onnx.helper.make_node( - "Constant", - [], - ["test"], - value=attr_tensor, - ) - ], - name="test", - inputs=[], - outputs=[], - initializer=[ - onnx.helper.make_tensor( - name="test_tensor", - data_type=onnx.TensorProto.FLOAT, - dims=[1], - vals=b"\x01\x00\x00\x00", - raw=True, - ), - ], - ) - model_proto = onnx.helper.make_model(graph) - onnx.external_data_helper.convert_model_to_external_data( - model_proto, location="tempdir", size_threshold=0, convert_attribute=True - ) - model = ir.serde.deserialize_model(model_proto) - expected_dir = "something_else" - external_data.set_base_dir(model.graph, expected_dir) - - initializer_tensor = model.graph.initializers["test_tensor"].const_value - assert isinstance(initializer_tensor, ir.ExternalTensor) - self.assertEqual(initializer_tensor.base_dir, expected_dir) - attr_tensor = model.graph.node(0).attributes["value"].value - self.assertEqual(attr_tensor.base_dir, expected_dir) - - -class OffsetCalcTest(unittest.TestCase): - """Test the offset calculation for the external tensor class.""" - - def test_align_offset_false(self): - # Tensor size > Align Threshold - current_offset = 20000 - tensor_size = 1048 - new_offset = external_data._compute_new_offset( # pylint: disable=protected-access - current_offset, tensor_size, align_offset=False - ) - self.assertEqual(current_offset, new_offset) - - def test_align_with_small_align_threshold(self): - # Tensor size < Align Threshold - current_offset = 20000 - tensor_size = 1048 - new_offset = external_data._compute_new_offset( # pylint: disable=protected-access - current_offset, - tensor_size, - align_threshold=1000, - ) - self.assertNotEqual(current_offset, new_offset) - - def test_align_with_large_align_threshold(self): - # Tensor size > Align Threshold - current_offset = 20000 - tensor_size = 1048 - new_offset = external_data._compute_new_offset( # pylint: disable=protected-access - current_offset, - tensor_size, - ) - self.assertEqual(current_offset, new_offset) - - def test_allocation_granularity_diff(self): - # Tensor size > Align Threshold - current_offset = 20000 - tensor_size = 1048577 - new_offset_1 = external_data._compute_new_offset( # pylint: disable=protected-access - current_offset, - tensor_size, - allocation_granularity=4000, - ) - new_offset_2 = external_data._compute_new_offset( # pylint: disable=protected-access - current_offset, - tensor_size, - ) - self.assertNotEqual(current_offset, new_offset_1) - self.assertNotEqual(current_offset, new_offset_2) - self.assertNotEqual(new_offset_1, new_offset_2) - - -class OffloadExternalTensorTest(unittest.TestCase): - """Test the memory mapped external tensor class.""" - - def setUp(self): - # File paths - if sys.version_info[:2] >= (3, 10): - self.temp_dir = tempfile.TemporaryDirectory(ignore_cleanup_errors=True) # pylint: disable=consider-using-with - else: - self.temp_dir = tempfile.TemporaryDirectory() # pylint: disable=consider-using-with - self.external_data_name = "external_tensors.bin" - self.base_path = self.temp_dir.name - self.ext_data_1 = "external_data_1.bin" - self.ext_data_2 = "external_data_2.bin" - # Data for the tensors - self.data = np.random.rand(2, 42).astype(np.float32) - self.data_other = np.random.rand(2, 42).astype(np.float32) - self.data_float16 = np.random.rand(2, 42).astype(np.float16) - self.data_ext1_1 = np.random.rand(1, 42).astype(np.float32) - self.data_ext1_2 = np.random.rand(4, 42).astype(np.float16) - self.data_ext2_1 = np.random.rand(5, 42).astype(np.float16) - self.custom_data = np.random.rand(3, 42).astype(np.float32) - # Model Assignments - self.model = self._simple_model() - self.model_with_external_data_same_path = self._model_with_external_data_same_path() - self.model_with_external_data_diff_path = self._model_with_external_data_diff_path() - self.model_with_custom_tensor_class = self._model_with_custom_tensor_class() - self.model_with_mixed_external_data = self._model_with_mixed_external_data() - - def tearDown(self) -> None: - # Handle exceptions for windows and python versions < 3.10 - try: - self.temp_dir.cleanup() - except PermissionError as e: - print(f"PermissionError: {e}") - except FileNotFoundError as e: - print(f"FileNotFoundError: {e}") - except Exception as e: # pylint: disable=broad-exception-caught - print(f"An unexpected error occurred: {e}") - - def _simple_model(self) -> ir.Model: - tensor1 = ir.Tensor( - self.data, - dtype=ir.DataType.FLOAT, - shape=ir.Shape(self.data.shape), - name="tensor1", - ) - tensor2 = ir.Tensor( - self.data_float16, - dtype=ir.DataType.FLOAT16, - shape=ir.Shape(self.data_float16.shape), - name="tensor2", - ) - node_0 = ir.Node( - "", - "Op_0", - inputs=[ir.Input("input_0"), ir.Input("input_1")], - num_outputs=2, - name="node_0", - ) - node_1 = ir.Node( - "", - "Op_1", - inputs=[node_0.outputs[0]], - num_outputs=1, - name="node_1", - ) - graph = ir.Graph( - inputs=node_0.inputs, # type: ignore - outputs=[node_1.outputs[0]], - initializers=[ - ir.Value(name="tensor1", const_value=tensor1), - ir.Value(name="tensor2", const_value=tensor2), - ], - # Unsorted nodes - nodes=[node_1, node_0], - name="test_graph", - ) - model = ir.Model(graph, ir_version=8) - return model - - def _setup_custom_tensor_class(self, name, value): - class CustomTensorType(ir.TensorProtocol): - def __init__( - self, - value: np.ndarray, - ): - self.name = name - self._raw = value - if isinstance(value, np.ndarray): - self._dtype = ir._enums.DataType.from_numpy(value.dtype) - self._shape = ir.Shape(getattr(value, "shape"), frozen=True) # noqa: B009 - - @property - def dtype(self) -> ir._enums.DataType: - """The data type of the tensor. Immutable.""" - return self._dtype - - @property - def shape(self) -> ir.Shape: - """The shape of the tensor. Immutable.""" - return self._shape - - @property - def nbytes(self) -> int: - return len(self.tobytes()) - - def __array__(self, dtype: typing.Any = None) -> np.ndarray: - if isinstance(self._raw, np.ndarray): - return self._raw - else: - return TypeError - - def numpy(self) -> np.ndarray: - return self._raw - - def tobytes(self) -> bytes: - if isinstance(self._raw, np.ndarray): - return self._raw.tobytes() - else: - return TypeError - - return CustomTensorType(value) - - def _model_with_external_data_same_path(self) -> ir.Model: - model = self._simple_model() - raw_data = self.data_other.tobytes() - # Save the data to disk - file_path = os.path.join(self.base_path, self.external_data_name) - with open(file_path, "wb") as f: - f.write(raw_data) - tensor_same_file = ir.ExternalTensor( - location=self.external_data_name, - offset=0, - length=len(raw_data), - dtype=ir.DataType.FLOAT, - name="tensor_same_file", - shape=ir.Shape(self.data_other.shape), - base_dir=self.base_path, - ) - model.graph.initializers["tensor_same_file"] = ir.Value( - name="tensor_same_file", const_value=tensor_same_file - ) - return model - - def _model_with_external_data_diff_path(self) -> ir.Model: - model = self._simple_model() - # File 1 - file_path_1 = os.path.join(self.base_path, self.ext_data_1) - with open(file_path_1, "wb") as f: - f.write(self.data_ext1_1.tobytes()) - f.write(self.data_ext1_2.tobytes()) - tensor_ext1_1 = ir.ExternalTensor( - location=self.ext_data_1, - offset=0, - length=len(self.data_ext1_1.tobytes()), - dtype=ir.DataType.FLOAT, - name="tensor_ext1_1", - shape=ir.Shape(self.data_ext1_1.shape), - base_dir=self.base_path, - ) - tensor_ext1_2 = ir.ExternalTensor( - location=self.ext_data_1, - offset=len(self.data_ext1_1.tobytes()), - length=len(self.data_ext1_2.tobytes()), - dtype=ir.DataType.FLOAT16, - name="tensor_ext1_2", - shape=ir.Shape(self.data_ext1_2.shape), - base_dir=self.base_path, - ) - # File 2 - file_path_2 = os.path.join(self.base_path, self.ext_data_2) - with open(file_path_2, "wb") as f: - f.write(self.data_ext2_1.tobytes()) - tensor_ext2_1 = ir.ExternalTensor( - location=self.ext_data_2, - offset=0, - length=len(self.data_ext2_1.tobytes()), - dtype=ir.DataType.FLOAT16, - name="tensor_ext2_1", - shape=ir.Shape(self.data_ext2_1.shape), - base_dir=self.base_path, - ) - model.graph.initializers["tensor_ext1_1"] = ir.Value( - name="tensor_ext1_1", const_value=tensor_ext1_1 - ) - model.graph.initializers["tensor_ext1_2"] = ir.Value( - name="tensor_ext1_2", const_value=tensor_ext1_2 - ) - model.graph.initializers["tensor_ext2_1"] = ir.Value( - name="tensor_ext2_1", const_value=tensor_ext2_1 - ) - return model - - def _model_with_custom_tensor_class(self) -> ir.Model: - model = self._simple_model() - custom_tensor = self._setup_custom_tensor_class("custom_tensor", self.custom_data) - model.graph.initializers["custom_tensor"] = ir.Value( - name="custom_tensor", const_value=custom_tensor - ) - return model - - def _model_with_mixed_external_data(self) -> ir.Model: - model = self._simple_model() - model_same_path = self.model_with_external_data_same_path - model_diff_path = self.model_with_external_data_diff_path - model_custom_tensor = self.model_with_custom_tensor_class - model.graph.initializers["tensor_same_file"] = ir.Value( - name="tensor_same_file", - const_value=model_same_path.graph.initializers["tensor_same_file"].const_value, - ) - model.graph.initializers["tensor_ext1_1"] = ir.Value( - name="tensor_ext1_1", - const_value=model_diff_path.graph.initializers["tensor_ext1_1"].const_value, - ) - model.graph.initializers["tensor_ext1_2"] = ir.Value( - name="tensor_ext1_2", - const_value=model_diff_path.graph.initializers["tensor_ext1_2"].const_value, - ) - model.graph.initializers["tensor_ext2_1"] = ir.Value( - name="tensor_ext2_1", - const_value=model_diff_path.graph.initializers["tensor_ext2_1"].const_value, - ) - model.graph.initializers["custom_tensor"] = ir.Value( - name="custom_tensor", - const_value=model_custom_tensor.graph.initializers["custom_tensor"].const_value, - ) - return model - - def test_external_data_simple(self): - model_with_external_data = external_data.unload_from_model( - self.model, self.base_path, self.external_data_name - ) - external_tensor = model_with_external_data.graph.initializers["tensor1"].const_value - external_tensor2 = model_with_external_data.graph.initializers["tensor2"].const_value - - self.assertEqual(external_tensor.numpy().tobytes(), self.data.tobytes()) - self.assertEqual(external_tensor2.numpy().tobytes(), self.data_float16.tobytes()) - # Ensure repeated reads are consistent - self.assertEqual(external_tensor.numpy().tobytes(), self.data.tobytes()) - self.assertEqual(external_tensor2.numpy().tobytes(), self.data_float16.tobytes()) - - def test_same_path_external_data(self): - model_with_external_data = external_data.unload_from_model( - self.model_with_external_data_same_path, - self.base_path, - self.external_data_name, - ) - external_tensor = model_with_external_data.graph.initializers["tensor1"].const_value - external_tensor2 = model_with_external_data.graph.initializers["tensor2"].const_value - external_tensor3 = model_with_external_data.graph.initializers[ - "tensor_same_file" - ].const_value - - self.assertEqual(external_tensor.numpy().tobytes(), self.data.tobytes()) - self.assertEqual(external_tensor2.numpy().tobytes(), self.data_float16.tobytes()) - self.assertEqual(external_tensor3.numpy().tobytes(), self.data_other.tobytes()) - # Ensure repeated reads are consistent - self.assertEqual(external_tensor.numpy().tobytes(), self.data.tobytes()) - self.assertEqual(external_tensor2.numpy().tobytes(), self.data_float16.tobytes()) - self.assertEqual(external_tensor3.numpy().tobytes(), self.data_other.tobytes()) - - def test_external_data_diff_paths(self): - model_with_external_data = external_data.unload_from_model( - self.model_with_external_data_diff_path, - self.base_path, - self.external_data_name, - ) - external_tensor = model_with_external_data.graph.initializers["tensor1"].const_value - external_tensor2 = model_with_external_data.graph.initializers["tensor2"].const_value - external_tensor3 = model_with_external_data.graph.initializers[ - "tensor_ext1_1" - ].const_value - external_tensor4 = model_with_external_data.graph.initializers[ - "tensor_ext1_2" - ].const_value - external_tensor5 = model_with_external_data.graph.initializers[ - "tensor_ext2_1" - ].const_value - - self.assertEqual(external_tensor.numpy().tobytes(), self.data.tobytes()) - self.assertEqual(external_tensor2.numpy().tobytes(), self.data_float16.tobytes()) - self.assertEqual(external_tensor3.numpy().tobytes(), self.data_ext1_1.tobytes()) - self.assertEqual(external_tensor4.numpy().tobytes(), self.data_ext1_2.tobytes()) - self.assertEqual(external_tensor5.numpy().tobytes(), self.data_ext2_1.tobytes()) - # Ensure repeated reads are consistent - self.assertEqual(external_tensor.numpy().tobytes(), self.data.tobytes()) - self.assertEqual(external_tensor2.numpy().tobytes(), self.data_float16.tobytes()) - self.assertEqual(external_tensor3.numpy().tobytes(), self.data_ext1_1.tobytes()) - self.assertEqual(external_tensor4.numpy().tobytes(), self.data_ext1_2.tobytes()) - self.assertEqual(external_tensor5.numpy().tobytes(), self.data_ext2_1.tobytes()) - - def test_custom_tensor_in_initializers(self): - model_with_external_data = external_data.unload_from_model( - self.model_with_custom_tensor_class, - self.base_path, - self.external_data_name, - ) - external_tensor = model_with_external_data.graph.initializers["tensor1"].const_value - external_tensor2 = model_with_external_data.graph.initializers["tensor2"].const_value - external_tensor3 = model_with_external_data.graph.initializers[ - "custom_tensor" - ].const_value - - self.assertEqual(external_tensor.numpy().tobytes(), self.data.tobytes()) - self.assertEqual(external_tensor2.numpy().tobytes(), self.data_float16.tobytes()) - self.assertEqual(external_tensor3.numpy().tobytes(), self.custom_data.tobytes()) - # Ensure repeated reads are consistent - self.assertEqual(external_tensor.numpy().tobytes(), self.data.tobytes()) - self.assertEqual(external_tensor2.numpy().tobytes(), self.data_float16.tobytes()) - self.assertEqual(external_tensor3.numpy().tobytes(), self.custom_data.tobytes()) - - def test_mixed_external_data(self): - model_with_external_data = external_data.unload_from_model( - self.model_with_mixed_external_data, self.base_path, self.external_data_name - ) - external_tensor = model_with_external_data.graph.initializers["tensor1"].const_value - external_tensor2 = model_with_external_data.graph.initializers["tensor2"].const_value - external_tensor3 = model_with_external_data.graph.initializers[ - "tensor_same_file" - ].const_value - external_tensor4 = model_with_external_data.graph.initializers[ - "custom_tensor" - ].const_value - external_tensor5 = model_with_external_data.graph.initializers[ - "tensor_ext1_1" - ].const_value - external_tensor6 = model_with_external_data.graph.initializers[ - "tensor_ext1_2" - ].const_value - external_tensor7 = model_with_external_data.graph.initializers[ - "tensor_ext2_1" - ].const_value - - self.assertEqual(external_tensor.numpy().tobytes(), self.data.tobytes()) - self.assertEqual(external_tensor2.numpy().tobytes(), self.data_float16.tobytes()) - self.assertEqual(external_tensor3.numpy().tobytes(), self.data_other.tobytes()) - self.assertEqual(external_tensor4.numpy().tobytes(), self.custom_data.tobytes()) - self.assertEqual(external_tensor5.numpy().tobytes(), self.data_ext1_1.tobytes()) - self.assertEqual(external_tensor6.numpy().tobytes(), self.data_ext1_2.tobytes()) - self.assertEqual(external_tensor7.numpy().tobytes(), self.data_ext2_1.tobytes()) - # Ensure repeated reads are consistent - self.assertEqual(external_tensor.numpy().tobytes(), self.data.tobytes()) - self.assertEqual(external_tensor2.numpy().tobytes(), self.data_float16.tobytes()) - self.assertEqual(external_tensor3.numpy().tobytes(), self.data_other.tobytes()) - self.assertEqual(external_tensor4.numpy().tobytes(), self.custom_data.tobytes()) - self.assertEqual(external_tensor5.numpy().tobytes(), self.data_ext1_1.tobytes()) - self.assertEqual(external_tensor6.numpy().tobytes(), self.data_ext1_2.tobytes()) - self.assertEqual(external_tensor7.numpy().tobytes(), self.data_ext2_1.tobytes()) - - def test_external_data_sorted(self): - model_with_external_data = external_data.unload_from_model( - self.model_with_mixed_external_data, - self.base_path, - self.external_data_name, - ) - file_path = os.path.join(self.base_path, self.external_data_name) - expected_tensor_order = [ - model_with_external_data.graph.initializers["tensor2"].const_value.tobytes(), - model_with_external_data.graph.initializers["tensor_ext1_1"].const_value.tobytes(), - model_with_external_data.graph.initializers["tensor1"].const_value.tobytes(), - model_with_external_data.graph.initializers[ - "tensor_same_file" - ].const_value.tobytes(), - model_with_external_data.graph.initializers["tensor_ext1_2"].const_value.tobytes(), - model_with_external_data.graph.initializers["tensor_ext2_1"].const_value.tobytes(), - model_with_external_data.graph.initializers["custom_tensor"].const_value.tobytes(), - ] - sorted_tensor_order = [ - self.data_float16.tobytes(), - self.data_ext1_1.tobytes(), - self.data.tobytes(), - self.data_other.tobytes(), - self.data_ext1_2.tobytes(), - self.data_ext2_1.tobytes(), - self.custom_data.tobytes(), - ] - with open(file_path, "r+b") as data_file: - current_offset = 0 - for i, tensor_bytes in enumerate(sorted_tensor_order): - data_file.seek(current_offset) - tensor_length = len(tensor_bytes) - tensor_data = data_file.read(tensor_length) - current_offset += tensor_length - self.assertEqual(tensor_data, tensor_bytes) - self.assertEqual(tensor_data, expected_tensor_order[i]) - - -if __name__ == "__main__": - unittest.main() diff --git a/onnxscript/ir/passes/__init__.py b/onnxscript/ir/passes/__init__.py index 8a18c1b72f..5310f1740a 100644 --- a/onnxscript/ir/passes/__init__.py +++ b/onnxscript/ir/passes/__init__.py @@ -15,7 +15,7 @@ "PassError", ] -from onnxscript.ir.passes._pass_infra import ( +from onnx_ir.passes import ( FunctionalPass, InPlacePass, InvariantError, @@ -27,13 +27,3 @@ PreconditionError, Sequential, ) - - -def __set_module() -> None: - """Set the module of all functions in this module to this public module.""" - global_dict = globals() - for name in __all__: - global_dict[name].__module__ = __name__ - - -__set_module() diff --git a/onnxscript/ir/passes/_pass_infra.py b/onnxscript/ir/passes/_pass_infra.py deleted file mode 100644 index 18e5c8715b..0000000000 --- a/onnxscript/ir/passes/_pass_infra.py +++ /dev/null @@ -1,289 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -# -# This module implements some APIs described in -# https://pytorch.org/executorch/stable/compiler-custom-compiler-passes.html -# for the ONNX IR. -# The classes {PassResult and PassManager} are derived from -# https://github.com/pytorch/pytorch/blob/1e47c7b11b312b47a621efd547f5c90081f0d9cb/torch/fx/passes/infra/pass_base.py#L12 -# and -# https://github.com/pytorch/pytorch/blob/1e47c7b11b312b47a621efd547f5c90081f0d9cb/torch/fx/passes/infra/pass_manager.py#L147 -# The original code is licensed under the PyTorch License https://github.com/pytorch/pytorch/blob/main/LICENSE - -"""Passes infrastructure for the IR.""" - -from __future__ import annotations - -import dataclasses -import logging -from typing import Literal, Sequence, final - -__all__ = [ - "PassBase", - "Sequential", - "InPlacePass", - "FunctionalPass", - "PassManager", - "PassResult", - # Errors - "InvariantError", - "PreconditionError", - "PostconditionError", - "PassError", -] - -import abc - -from onnxscript import ir - -logger = logging.getLogger(__name__) - - -class InvariantError(Exception): - """Raised when an invariant is violated.""" - - -class PreconditionError(InvariantError): - """Raised when a precondition is violated.""" - - -class PostconditionError(InvariantError): - """Raised when a postcondition is violated.""" - - -class PassError(RuntimeError): - """Raised when an error occurs during a pass.""" - - -@dataclasses.dataclass -class PassResult: - """Result of a pass. - - Attributes: - model: The transformed model. - modified: Whether the resulting model is different from the input model. - """ - - model: ir.Model - modified: bool - - -class PassBase(abc.ABC): - """Base class for all passes. - - - ``in_place`` and ``changes_input`` properties and what they mean: - - +------------+------------------+----------------------------+ - | | changes_inputs | not changes_inputs | - +------------+------------------+----------------------------+ - | in_place | in place | Side-effect-only pass | - +------------+------------------+----------------------------+ - | not | destructive | functional | - | in_place | | | - +------------+------------------+----------------------------+ - """ - - @property - @abc.abstractmethod - def in_place(self) -> bool: - """Whether the pass modifies the model in place and returns it. - - If True, the pass will return the same model object that was passed in. - If False, the pass will return a new model object. - """ - raise NotImplementedError - - @property - @abc.abstractmethod - def changes_input(self) -> bool: - """Whether the pass modifies input model.""" - raise NotImplementedError - - @property - def destructive(self) -> bool: - """Whether the pass will destroy the input model when ``in_place=False``. - - A pass is destructive if it is not in place and it modifies the input model. - """ - return not self.in_place and self.changes_input - - def __call__(self, model_or_result: ir.Model | PassResult, /) -> PassResult: - if isinstance(model_or_result, PassResult): - model = model_or_result.model - else: - model = model_or_result - # Check preconditions - try: - self.requires(model) - except PreconditionError: - raise - except Exception as e: - raise PreconditionError( - f"Pre-condition for pass '{self.__class__.__name__}' failed" - ) from e - - result = self.call(model) - - # Check postconditions - try: - self.ensures(model) - except PostconditionError: - raise - except Exception as e: - raise PostconditionError( - f"Post-condition for pass '{self.__class__.__name__}' failed" - ) from e - - if not isinstance(result, PassResult): - raise TypeError( - f"The result of the pass '{self.__class__.__name__}' should be type PassResult. " - "Please create one with ir.passes.PassResult()." - ) - - # Checks that the declared in-place property is respected - if self.in_place and result.model is not model: - raise PassError( - f"The pass '{self.__class__.__name__}' is declared in-place, " - "but the model returned is *not* the same object as the input model. " - "Pass developer: Pass should return the same model object or the in_place property should return False." - ) - if not self.in_place and result.model is model: - raise PassError( - f"The pass '{self.__class__.__name__}' is declared not in-place, " - "but the model returned *is* the same object as the input model. " - "Pass developer: Pass should return a new model object or the in_place property should return True." - ) - return result - - @abc.abstractmethod - def call(self, model: ir.Model) -> PassResult: - """The main entry point for the pass.""" - ... - - def requires(self, model: ir.Model) -> None: - """Pre-conditions for the pass. - - This is optional to implement, will be called before call() if run by a pass manager. - """ - del model # Unused - - def ensures(self, model: ir.Model) -> None: - """Post-conditions for the pass. - - This is optional to implement, will be called after call() if run by a pass manager. - """ - del model # Unused - - -class InPlacePass(PassBase): - """A pass that modifies the input model in place and returns it.""" - - @property - @final - def in_place(self) -> Literal[True]: - """An in-place pass is in place.""" - return True - - @property - @final - def changes_input(self) -> Literal[True]: - """An in-place pass changes the input model.""" - return True - - -class FunctionalPass(PassBase): - """A pass that returns a new model but does not modify the input model.""" - - @property - @final - def in_place(self) -> Literal[False]: - """A functional pass is not in place.""" - return False - - @property - @final - def changes_input(self) -> Literal[False]: - """A functional pass does not change the input model.""" - return False - - -class Sequential(PassBase): - """Run a sequence of passes in order.""" - - def __init__(self, *passes: PassBase): - if not passes: - raise ValueError("Sequential must take at least one pass") - self.passes = passes - self._in_place = all(pass_.in_place for pass_ in passes) - # The reason changes_inputs is decided by the first pass is that if the first pass is either in-place, - # or if it is not designed to be in-place but somehow changes the input (destructive), - # this pass sequence will change inputs. - self._changes_input = self.passes[0].changes_input or self.passes[0].in_place - - @property - def in_place(self) -> bool: - return self._in_place - - @property - def changes_input(self) -> bool: - return self._changes_input - - def call(self, model: ir.Model) -> PassResult: - modified = False - for i, pass_ in enumerate(self.passes): - logger.debug("Running the %s-th pass '%s'", i, pass_) - try: - pass_result = pass_(model) - except Exception as e: - prev_pass_names = [str(p) for p in self.passes[:i]] - raise PassError( - f"An error occurred when running the '{pass_}' pass after the " - f"following passes: {prev_pass_names}" - ) from e - - model = pass_result.model - modified = modified or pass_result.modified - - return PassResult(model, modified) - - -class PassManager(Sequential): - """Pass manager for the IR. - - The PassManager is a Pass that runs a sequence of passes on a model. - - Attributes: - passes: The passes to run. - steps: The number of times to run the passes. - early_stop: Whether to stop running the passes if the graph stops changing. - """ - - def __init__( - self, - passes: Sequence[PassBase], - steps: int = 1, - early_stop: bool = True, - ): - # TODO(justinchuby): Implement constraints - super().__init__(*passes) - self.steps = steps - self.early_stop = early_stop - - def call(self, model: ir.Model) -> PassResult: - """Run the set of passes `steps` number of times or until the graph stops changing.""" - overall_modified = False - for step in range(self.steps): - try: - # Call the call method of Sequential - step_result = super().call(model) - except Exception as e: - raise PassError(f"An error occurred at step {step}") from e - model = step_result.model - modified = step_result.modified - overall_modified = overall_modified or modified - # If the graph no longer changes, then we can stop running these passes - if not modified and self.early_stop: - logger.info("PassManager: No more graph changes detected after step %s", step) - break - return PassResult(model, overall_modified) diff --git a/onnxscript/ir/passes/_pass_infra_test.py b/onnxscript/ir/passes/_pass_infra_test.py deleted file mode 100644 index 7f916baebf..0000000000 --- a/onnxscript/ir/passes/_pass_infra_test.py +++ /dev/null @@ -1,39 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -from __future__ import annotations - -import unittest - -from onnxscript import ir -from onnxscript.ir.passes import _pass_infra - - -class PassBaseTest(unittest.TestCase): - def test_pass_results_can_be_used_as_pass_input(self): - class TestPass(_pass_infra.PassBase): - @property - def in_place(self) -> bool: - return True - - @property - def changes_input(self) -> bool: - return False - - def call(self, model: ir.Model) -> _pass_infra.PassResult: - # This is a no-op pass - return _pass_infra.PassResult(model=model, modified=False) - - pass_ = TestPass() - model = ir.Model(graph=ir.Graph([], [], nodes=[]), ir_version=10) - result = pass_(model) - self.assertIsInstance(result, _pass_infra.PassResult) - # pass can take the result of another pass as input - result_1 = pass_(result) - # It can also take the model as input - result_2 = pass_(result.model) - self.assertIs(result_1.model, result_2.model) - - -if __name__ == "__main__": - unittest.main() diff --git a/onnxscript/ir/passes/common/__init__.py b/onnxscript/ir/passes/common/__init__.py index d1b4f176a2..34931c924f 100644 --- a/onnxscript/ir/passes/common/__init__.py +++ b/onnxscript/ir/passes/common/__init__.py @@ -16,21 +16,17 @@ "TopologicalSortPass", ] -from onnxscript.ir.passes.common.clear_metadata_and_docstring import ( - ClearMetadataAndDocStringPass, -) -from onnxscript.ir.passes.common.constant_manipulation import ( +from onnx_ir.passes.common import ( AddInitializersToInputsPass, + CheckerPass, + ClearMetadataAndDocStringPass, + InlinePass, LiftConstantsToInitializersPass, LiftSubgraphInitializersToMainGraphPass, RemoveInitializersFromInputsPass, -) -from onnxscript.ir.passes.common.inliner import InlinePass -from onnxscript.ir.passes.common.onnx_checker import CheckerPass -from onnxscript.ir.passes.common.shape_inference import ShapeInferencePass -from onnxscript.ir.passes.common.topological_sort import TopologicalSortPass -from onnxscript.ir.passes.common.unused_removal import ( RemoveUnusedFunctionsPass, RemoveUnusedNodesPass, RemoveUnusedOpsetsPass, + ShapeInferencePass, + TopologicalSortPass, ) diff --git a/onnxscript/ir/passes/common/_c_api_utils.py b/onnxscript/ir/passes/common/_c_api_utils.py deleted file mode 100644 index bb2715c75c..0000000000 --- a/onnxscript/ir/passes/common/_c_api_utils.py +++ /dev/null @@ -1,77 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -"""Utilities for interfacing with onnx C APIs.""" - -from __future__ import annotations - -import logging -from typing import TYPE_CHECKING, Callable, TypeVar - -from onnxscript import ir - -if TYPE_CHECKING: - import onnx - - -logger = logging.getLogger(__name__) -# Temporarily remove initializers larger than this size to keep model size down -# for the onnx.shape_inference call because it needs to serialize the model -_BIG_TENSOR_SIZE_LIMIT = 1000 # 1KB -_R = TypeVar("_R") - - -def call_onnx_api(func: Callable[[onnx.ModelProto], _R], model: ir.Model) -> _R: - """Call an ONNX C API function by temporarily removing initializers. - - This is necessary because the ONNX C API does not support large models - with initializers that have large tensor values. The input model is left - unchanged no matter the call succeeds or not. - - Args: - func: Partially applied function that takes a model proto and returns anything. - model: The IR model to pass to the API function. - - Returns: - The resulting ModelProto that contains the result of the API call. - """ - - # Store the original initializer values so they can be restored - initializer_values = tuple(model.graph.initializers.values()) - tensors = {v.name: v.const_value for v in initializer_values} - original_inputs_len = len(model.graph.inputs) - - # Turn the initializers into inputs and clear the initializers - # to limit the model size - for initializer in initializer_values: - # Make sure the initializer has its shape/type set - assert initializer.const_value is not None - if initializer.shape is None: - initializer.shape = initializer.const_value.shape # type: ignore[assignment] - if initializer.dtype is None: - initializer.dtype = initializer.const_value.dtype - if initializer not in model.graph.inputs: - model.graph.inputs.append(initializer) - if initializer.const_value.nbytes > _BIG_TENSOR_SIZE_LIMIT: - # Temporarily remove the initializer value to reduce model size - # for onnx.shape_inference - initializer.const_value = None - assert initializer.name is not None - model.graph.initializers.pop(initializer.name) - - proto = ir.serde.serialize_model(model) - - try: - # Call the ONNX C API function - result = func(proto) - finally: - # Restore the original initializer values so the model is unchanged - for initializer in initializer_values: - initializer.const_value = tensors[initializer.name] - model.graph.register_initializer(initializer) - - # Restore the original inputs - inputs = model.graph.inputs[:original_inputs_len] - model.graph.inputs.clear() - model.graph.inputs.extend(inputs) - - return result diff --git a/onnxscript/ir/passes/common/clear_metadata_and_docstring.py b/onnxscript/ir/passes/common/clear_metadata_and_docstring.py deleted file mode 100644 index 0c1fa48cb0..0000000000 --- a/onnxscript/ir/passes/common/clear_metadata_and_docstring.py +++ /dev/null @@ -1,60 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -"""Clear all metadata and docstring from the model, graphs, nodes, and functions.""" - -from __future__ import annotations - -__all__ = [ - "ClearMetadataAndDocStringPass", -] - -import logging - -from onnxscript import ir - -logger = logging.getLogger(__name__) - - -class ClearMetadataAndDocStringPass(ir.passes.InPlacePass): - """Clear all metadata and docstring from the model, graphs, nodes, and functions.""" - - def call(self, model: ir.Model) -> ir.passes.PassResult: - # 0. TODO: Should we clean model metadata and docstring? - - # 1. Clean up the graph and the belonged nodes metadata properties - modified = self._clear_graph_or_function_metadata_and_docstring(model.graph) - - # 2. Clean up all of the functions metadata properties - for function in model.functions.values(): - modified = ( - self._clear_graph_or_function_metadata_and_docstring(function) or modified - ) - return ir.passes.PassResult(model, modified=modified) - - def _clear_graph_or_function_metadata_and_docstring( - self, - graph_or_function: ir.Graph | ir.Function, - ) -> bool: - """Clear metadata and docstring from the graph or function.""" - checked_graphs_or_functions: set[ir.Graph | ir.Function] = set() - modified = False - # Clean up all of the nodes metadata properties - for node in ir.traversal.RecursiveGraphIterator(graph_or_function): - if node.metadata_props: - modified = True - logger.debug("Removed metadata from %s nodes", node.name) - node.metadata_props.clear() - node.doc_string = None - - # Clean up the owning graph/function metadata properties - # and doc_string if the graph/function is not already checked - assert node.graph is not None - if node.graph not in checked_graphs_or_functions and ( - node.graph.metadata_props or node.graph.doc_string - ): - modified = True - logger.debug("Removed metadata from %s graph/function", node.graph.name) - node.graph.metadata_props.clear() - node.graph.doc_string = None - checked_graphs_or_functions.add(node.graph) - return modified diff --git a/onnxscript/ir/passes/common/clear_metadata_and_docstring_test.py b/onnxscript/ir/passes/common/clear_metadata_and_docstring_test.py deleted file mode 100644 index 7707a87ff6..0000000000 --- a/onnxscript/ir/passes/common/clear_metadata_and_docstring_test.py +++ /dev/null @@ -1,107 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -from __future__ import annotations - -import unittest - -import numpy as np - -from onnxscript import ir -from onnxscript.ir.passes.common import clear_metadata_and_docstring - - -class TestClearMetadataAndDocStringPass(unittest.TestCase): - def test_pass_with_clear_metadata_and_docstring(self): - # Create a model (node, graph, function) with metadata and docstring - inputs = [ - ir.Value( - name="input_a", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3)) - ), - ir.Value( - name="input_b", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3)) - ), - ] - add_node = ir.node( - "Add", - inputs=inputs, - num_outputs=1, - metadata_props={"add_key": "add_value"}, - doc_string="This is an Add node", - ) - mul_node = ir.node( - "Mul", - inputs=[add_node.outputs[0], inputs[1]], - num_outputs=1, - metadata_props={"mul_key": "mul_value"}, - doc_string="This is a Mul node", - ) - func_inputs = [ - ir.Value( - name="input_a", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3)) - ), - ir.Value( - name="input_b", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3)) - ), - ] - function = ir.Function( - graph=ir.Graph( - name="my_function", - inputs=func_inputs, - outputs=mul_node.outputs, - nodes=[add_node, mul_node], - opset_imports={"": 20}, - doc_string="This is a function docstring", - metadata_props={"function_key": "function_value"}, - ), - name="my_function", - domain="my_domain", - attributes=[], - ) - # Create a model with the graph and function - constant_tensor = ir.tensor(np.random.rand(2, 3).astype(ir.DataType.FLOAT.numpy())) - const_node = ir.node( - "Constant", - inputs=[], - attributes={"value": constant_tensor}, - num_outputs=1, - metadata_props={"const_key": "const_value"}, - doc_string="This is a Constant node", - ) - sub_node = ir.node( - "Sub", - inputs=[function.outputs[0], const_node.outputs[0]], - num_outputs=1, - metadata_props={"sub_key": "sub_value"}, - doc_string="This is a Sub node", - ) - model = ir.Model( - graph=ir.Graph( - inputs=inputs, - outputs=sub_node.outputs, - nodes=[const_node, sub_node], - opset_imports={"": 20}, - doc_string="This is a graph docstring", - metadata_props={"graph_key": "graph_value"}, - ), - ir_version=10, - functions=[function], - ) - # Create a pass to clear metadata and docstring - clear_pass = clear_metadata_and_docstring.ClearMetadataAndDocStringPass() - # Apply the pass - result = clear_pass(model) - # Check that the pass was applied - self.assertTrue(result.modified) - # Check that the metadata and docstring were cleared - self.assertEqual(model.graph.doc_string, None) - self.assertEqual(model.graph.metadata_props, {}) - for node in model.graph: - self.assertEqual(node.metadata_props, {}) - self.assertEqual(node.doc_string, None) - # Check that the function docstring and metadata were cleared - self.assertEqual(function.doc_string, None) - self.assertEqual(function.metadata_props, {}) - - -if __name__ == "__main__": - unittest.main() diff --git a/onnxscript/ir/passes/common/constant_manipulation.py b/onnxscript/ir/passes/common/constant_manipulation.py deleted file mode 100644 index b76c3c0802..0000000000 --- a/onnxscript/ir/passes/common/constant_manipulation.py +++ /dev/null @@ -1,215 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -"""Lift constants to initializers.""" - -from __future__ import annotations - -__all__ = [ - "AddInitializersToInputsPass", - "LiftConstantsToInitializersPass", - "LiftSubgraphInitializersToMainGraphPass", - "RemoveInitializersFromInputsPass", -] - -import logging - -import numpy as np - -from onnxscript import ir - -logger = logging.getLogger(__name__) - - -class LiftConstantsToInitializersPass(ir.passes.InPlacePass): - """Lift constants to initializers. - - Attributes: - lift_all_constants: Whether to lift all Constant nodes, including those that does not contain a tensor attribute (e.g. with value_ints etc.) - Default to False, where only Constants with the ``value`` attribute are lifted. - size_limit: The minimum size of the tensor to be lifted. If the tensor contains - number of elements less than size_limit, it will not be lifted. Default is 16. - """ - - def __init__(self, lift_all_constants: bool = False, size_limit: int = 16): - super().__init__() - self.lift_all_constants = lift_all_constants - self.size_limit = size_limit - - def call(self, model: ir.Model) -> ir.passes.PassResult: - count = 0 - for node in ir.traversal.RecursiveGraphIterator(model.graph): - assert node.graph is not None - if node.op_type != "Constant" or node.domain not in ("", "onnx.ai"): - continue - if node.outputs[0].is_graph_output(): - logger.debug( - "Constant node '%s' is used as output, so it can't be lifted.", node.name - ) - continue - constant_node_attribute = set(node.attributes.keys()) - if len(constant_node_attribute) != 1: - logger.debug( - "Invalid constant node '%s' has more than one attribute", node.name - ) - continue - - attr_name, attr_value = next(iter(node.attributes.items())) - initializer_name = node.outputs[0].name - assert initializer_name is not None - assert isinstance(attr_value, ir.Attr) - tensor = self._constant_node_attribute_to_tensor( - node, attr_name, attr_value, initializer_name - ) - if tensor is None: - # The reason of None is logged in _constant_node_attribute_to_tensor - continue - # Register an initializer with the tensor value - initializer = ir.Value( - name=initializer_name, - shape=tensor.shape, # type: ignore[arg-type] - type=ir.TensorType(tensor.dtype), - const_value=tensor, - ) - assert node.graph is not None - node.graph.register_initializer(initializer) - # Replace the constant node with the initializer - ir.convenience.replace_all_uses_with(node.outputs[0], initializer) - node.graph.remove(node, safe=True) - count += 1 - logger.debug( - "Converted constant node '%s' to initializer '%s'", node.name, initializer_name - ) - if count: - logger.debug("Lifted %s constants to initializers", count) - return ir.passes.PassResult(model, modified=bool(count)) - - def _constant_node_attribute_to_tensor( - self, node, attr_name: str, attr_value: ir.Attr, initializer_name: str - ) -> ir.TensorProtocol | None: - """Convert constant node attribute to tensor.""" - if not self.lift_all_constants and attr_name != "value": - logger.debug( - "Constant node '%s' has non-tensor attribute '%s'", node.name, attr_name - ) - return None - - tensor: ir.TensorProtocol - if attr_name == "value": - tensor = attr_value.as_tensor() - elif attr_name == "value_int": - tensor = ir.tensor( - attr_value.as_int(), dtype=ir.DataType.INT64, name=initializer_name - ) - elif attr_name == "value_ints": - tensor = ir.tensor( - attr_value.as_ints(), dtype=ir.DataType.INT64, name=initializer_name - ) - elif attr_name == "value_float": - tensor = ir.tensor( - attr_value.as_float(), dtype=ir.DataType.FLOAT, name=initializer_name - ) - elif attr_name == "value_floats": - tensor = ir.tensor( - attr_value.as_floats(), dtype=ir.DataType.FLOAT, name=initializer_name - ) - elif attr_name in ("value_string", "value_strings"): - tensor = ir.StringTensor( - np.array(attr_value.value, dtype=np.bytes_), name=initializer_name - ) - else: - raise ValueError( - f"Unsupported constant node '{node.name}' attribute '{attr_name}'" - ) - - if tensor.size < self.size_limit: - logger.debug( - "Tensor from node '%s' has less than %s elements", - node.name, - self.size_limit, - ) - return None - return tensor - - -class LiftSubgraphInitializersToMainGraphPass(ir.passes.InPlacePass): - """Lift subgraph initializers to main graph. - - This pass lifts the initializers of a subgraph to the main graph. - It is used to ensure that the initializers are available in the main graph - for further processing or optimization. - """ - - def call(self, model: ir.Model) -> ir.passes.PassResult: - count = 0 - registered_initializer_names: dict[str, int] = {} - for graph in model.graphs(): - if graph is model.graph: - continue - for name in tuple(graph.initializers): - initializer = graph.initializers[name] - if initializer.is_graph_input(): - # Skip the ones that are also graph inputs - logger.debug( - "Initializer '%s' is also a graph input, so it can't be lifted", - initializer.name, - ) - continue - # Remove the initializer from the subgraph - graph.initializers.pop(name) - # To avoid name conflicts, we need to rename the initializer - # to a unique name in the main graph - if name in registered_initializer_names: - name_count = registered_initializer_names[name] - initializer.name = f"{name}_{name_count}" - registered_initializer_names[name] = name_count + 1 - else: - assert initializer.name is not None - registered_initializer_names[initializer.name] = 1 - model.graph.register_initializer(initializer) - count += 1 - logger.debug( - "Lifted initializer '%s' from subgraph '%s' to main graph", - initializer.name, - graph.name, - ) - return ir.passes.PassResult(model, modified=bool(count)) - - -class RemoveInitializersFromInputsPass(ir.passes.InPlacePass): - """Remove initializers from inputs. - - This pass finds all graph inputs that have a const_value and removes them from the graph.inputs list. - """ - - def call(self, model: ir.Model) -> ir.passes.PassResult: - count = 0 - for graph in model.graphs(): - initializers = set(graph.initializers.values()) - new_inputs = [] - for input_value in graph.inputs: - if input_value in initializers: - count += 1 - else: - new_inputs.append(input_value) - graph.inputs.clear() - graph.inputs.extend(new_inputs) - logger.info("Removed %s initializers from graph inputs", count) - return ir.passes.PassResult(model, modified=bool(count)) - - -class AddInitializersToInputsPass(ir.passes.InPlacePass): - """Add initializers to inputs. - - This pass finds all initializers and adds them to the graph.inputs list if they are not already present. - """ - - def call(self, model: ir.Model) -> ir.passes.PassResult: - count = 0 - for graph in model.graphs(): - inputs_set = set(graph.inputs) - for initializer in graph.initializers.values(): - if initializer not in inputs_set: - graph.inputs.append(initializer) - count += 1 - logger.info("Added %s initializers to graph inputs", count) - return ir.passes.PassResult(model, modified=bool(count)) diff --git a/onnxscript/ir/passes/common/constant_manipulation_test.py b/onnxscript/ir/passes/common/constant_manipulation_test.py deleted file mode 100644 index d02933136b..0000000000 --- a/onnxscript/ir/passes/common/constant_manipulation_test.py +++ /dev/null @@ -1,516 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -from __future__ import annotations - -import unittest - -import numpy as np -import parameterized - -from onnxscript import ir -from onnxscript.ir.passes.common import constant_manipulation - - -class TestLiftConstantsToInitializersPass(unittest.TestCase): - @parameterized.parameterized.expand( - [ - (ir.DataType.FLOAT, True), - (ir.DataType.FLOAT, False), - (ir.DataType.INT64, True), - (ir.DataType.INT64, False), - ] - ) - def test_pass_with_lifting_float_and_int_constants_to_initializers( - self, ir_dtype: ir.DataType, lift_all_constants: bool - ): - inputs = [ - ir.Value(name="input_a", type=ir.TensorType(ir_dtype), shape=ir.Shape((2, 3))), - ir.Value( - name="input_b", - type=ir.TensorType(ir_dtype), - shape=ir.Shape((2, 3)), - ), - ] - - constant_tensor = ir.tensor(np.random.rand(2, 3).astype(ir_dtype.numpy())) - const_node = ir.node( - "Constant", inputs=[], attributes={"value": constant_tensor}, num_outputs=1 - ) - add_node = ir.node("Add", inputs=[inputs[0], const_node.outputs[0]]) - mul_node = ir.node("Mul", inputs=[add_node.outputs[0], inputs[1]]) - - model = ir.Model( - graph=ir.Graph( - inputs=inputs, - outputs=mul_node.outputs, - nodes=[const_node, add_node, mul_node], - opset_imports={"": 20}, - ), - ir_version=10, - ) - - # Check that the initializer is not in the graph yet - self.assertEqual(len(model.graph.initializers), 0) - # And 1 constant node - self.assertEqual(len([node for node in model.graph if node.op_type == "Constant"]), 1) - - # Perform lift constants to initializers - result = constant_manipulation.LiftConstantsToInitializersPass( - lift_all_constants=lift_all_constants, size_limit=0 - )(model) - self.assertTrue(result.modified) - # Check that the constant node is lifted to an initializer - self.assertEqual(len(result.model.graph.initializers), 1) - # Check the value - self.assertEqual( - result.model.graph.initializers[ - "val_0" - ].const_value, # name created by name_authority - constant_tensor, - ) - # And 0 constant node - self.assertEqual( - len([node for node in result.model.graph if node.op_type == "Constant"]), 0 - ) - - @parameterized.parameterized.expand( - [ - (True,), - (False,), - ] - ) - def test_pass_with_lifting_constants_to_initializers_within_subgraph( - self, lift_all_constants: bool - ): - input_value = ir.Value( - name="input", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3)) - ) - - then_constant_tensor = ir.tensor(np.random.rand(2, 3).astype(np.float32)) - then_const_node = ir.node( - "Constant", inputs=[], attributes={"value": then_constant_tensor}, num_outputs=1 - ) - # then branch adds the constant to the input - # else branch multiplies the input by the constant - add_node = ir.node("Add", inputs=[input_value, then_const_node.outputs[0]]) - then_graph = ir.Graph( - inputs=[], - outputs=[add_node.outputs[0]], - nodes=[then_const_node, add_node], - opset_imports={"": 20}, - ) - else_constant_tensor = ir.tensor(np.random.rand(2, 3).astype(np.float32)) - else_const_node = ir.node( - "Constant", inputs=[], attributes={"value": else_constant_tensor}, num_outputs=1 - ) - mul_node = ir.node("Mul", inputs=[input_value, else_const_node.outputs[0]]) - else_graph = ir.Graph( - inputs=[], - outputs=[mul_node.outputs[0]], - nodes=[else_const_node, mul_node], - opset_imports={"": 20}, - ) - # Create a conditional node that uses the then and else graphs - cond_node = ir.node( - "If", - inputs=[input_value], - attributes={"then_branch": then_graph, "else_branch": else_graph}, - num_outputs=1, - ) - # Construct the model - main_graph = ir.Graph( - inputs=[input_value], - outputs=cond_node.outputs, - nodes=[cond_node], - opset_imports={"": 20}, - ) - main_graph.sort() - model = ir.Model( - graph=main_graph, - ir_version=10, - ) - result = constant_manipulation.LiftConstantsToInitializersPass( - lift_all_constants=lift_all_constants, size_limit=0 - )(model) - self.assertTrue(result.modified) - # Check that the constant node is lifted to the subgraph initializers - for node in ir.traversal.RecursiveGraphIterator(result.model.graph): - if node.op_type == "Constant": - raise AssertionError( - f"Constant node '{node.name}' was not lifted to initializers" - ) - self.assertEqual(len(else_graph.initializers), 1) - self.assertEqual(len(then_graph.initializers), 1) - self.assertIs(else_graph.initializers["val_0"].const_value, else_constant_tensor) - self.assertIs(then_graph.initializers["val_0"].const_value, then_constant_tensor) - - @parameterized.parameterized.expand( - [ - (1.0, "value_float", np.float32, True), - (1.0, "value_float", np.float32, False), - (1, "value_int", np.int64, True), - (1, "value_int", np.int64, False), - ("hello world!", "value_string", np.bytes_, True), - ("hello world!", "value_string", np.bytes_, False), - ([1.0, 2.0, 3.0], "value_floats", np.float32, True), - ([1.0, 2.0, 3.0], "value_floats", np.float32, False), - ([1, 2, 3], "value_ints", np.int64, True), - ([1, 2, 3], "value_ints", np.int64, False), - (["hello world!", "thank you."], "value_strings", np.bytes_, True), - (["hello world!", "thank you."], "value_strings", np.bytes_, False), - ] - ) - def test_pass_with_lifting_constants_to_initializers_with_floats_ints_strings( - self, - value: float | int | str | list[float] | list[int] | list[str], - constant_attribute: str, - np_dtype: type[np.dtype], - lift_all_constants: bool, - ): - input_value = ir.Value( - name="input", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3)) - ) - - constant_value = value - const_node = ir.node( - "Constant", - inputs=[], - attributes={constant_attribute: constant_value}, - num_outputs=1, - ) - identity_node_constant = ir.node( - "Identity", inputs=[const_node.outputs[0]], num_outputs=1 - ) - identity_node_input = ir.node("Identity", inputs=[input_value], num_outputs=1) - - model = ir.Model( - graph=ir.Graph( - inputs=[input_value], - outputs=[identity_node_input.outputs[0], identity_node_constant.outputs[0]], - nodes=[identity_node_input, const_node, identity_node_constant], - opset_imports={"": 20}, - ), - ir_version=10, - ) - - # Check that the initializer is not in the graph yet - self.assertEqual(len(model.graph.initializers), 0) - # And 1 constant node - self.assertEqual(len([node for node in model.graph if node.op_type == "Constant"]), 1) - - # Perform lift constants to initializers - result = constant_manipulation.LiftConstantsToInitializersPass( - lift_all_constants=lift_all_constants, size_limit=0 - )(model) - if lift_all_constants: - self.assertTrue(result.modified) - # Check that the constant node is lifted to an initializer - self.assertEqual(len(result.model.graph.initializers), 1) - np.testing.assert_array_equal( - result.model.graph.initializers["val_1"].const_value.numpy(), - np.array(constant_value, dtype=np_dtype), - ) - else: - self.assertFalse(result.modified) - # Check that the constant node is not lifted to an initializer - self.assertEqual(len(result.model.graph.initializers), 0) - - def test_not_lifting_constants_to_initializers_when_it_is_output(self): - input_value = ir.Value( - name="input", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3)) - ) - identity_node_input = ir.node("Identity", inputs=[input_value], num_outputs=1) - - constant_value = ir.tensor(np.random.rand(2, 3).astype(np.float32)) - const_node = ir.node( - "Constant", - inputs=[], - attributes={"value": constant_value}, - num_outputs=1, - ) - - model = ir.Model( - graph=ir.Graph( - inputs=[input_value], - outputs=[identity_node_input.outputs[0], const_node.outputs[0]], - nodes=[identity_node_input, const_node], - opset_imports={"": 20}, - ), - ir_version=10, - ) - - result = constant_manipulation.LiftConstantsToInitializersPass()(model) - self.assertFalse(result.modified) - # Check that the constant node is not lifted to an initializer - self.assertEqual(len(result.model.graph.initializers), 0) - - -class TestLiftSubgraphInitializersToMainGraphPass(unittest.TestCase): - @parameterized.parameterized.expand( - [ - ("then_initializer", "else_initializer"), - ("initializer", "initializer"), - ] - ) - def test_pass_with_lifting_constants_to_initializers_within_subgraph( - self, then_initializer_name, else_initializer_name - ): - input_value = ir.Value( - name="input", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3)) - ) - - then_initializer_tensor = ir.tensor(np.random.rand(2, 3).astype(np.float32)) - then_initializer_value = ir.Value( - name=then_initializer_name, - shape=then_initializer_tensor.shape, - type=ir.TensorType(ir.DataType.FLOAT), - const_value=then_initializer_tensor, - ) - - # then branch adds the constant to the input - # else branch multiplies the input by the constant - add_node = ir.node("Add", inputs=[input_value, then_initializer_value]) - then_graph = ir.Graph( - inputs=[], - outputs=[add_node.outputs[0]], - nodes=[add_node], - opset_imports={"": 20}, - initializers=[then_initializer_value], - ) - else_initializer_tensor = ir.tensor(np.random.rand(2, 3).astype(np.float32)) - else_initializer_value = ir.Value( - name=else_initializer_name, - shape=else_initializer_tensor.shape, - type=ir.TensorType(ir.DataType.FLOAT), - const_value=else_initializer_tensor, - ) - mul_node = ir.node("Mul", inputs=[input_value, else_initializer_value]) - else_graph = ir.Graph( - inputs=[], - outputs=[mul_node.outputs[0]], - nodes=[mul_node], - opset_imports={"": 20}, - initializers=[else_initializer_value], - ) - # Create a conditional node that uses the then and else graphs - cond_node = ir.node( - "If", - inputs=[input_value], - attributes={"then_branch": then_graph, "else_branch": else_graph}, - num_outputs=1, - ) - # Construct the model - main_graph = ir.Graph( - inputs=[input_value], - outputs=cond_node.outputs, - nodes=[cond_node], - opset_imports={"": 20}, - ) - main_graph.sort() - model = ir.Model( - graph=main_graph, - ir_version=10, - ) - result = constant_manipulation.LiftSubgraphInitializersToMainGraphPass()(model) - self.assertTrue(result.modified) - - self.assertEqual(len(else_graph.initializers), 0) - self.assertEqual(len(then_graph.initializers), 0) - self.assertEqual(len(main_graph.initializers), 2) - for value, tensor in zip( - main_graph.initializers.values(), - [then_initializer_tensor, else_initializer_tensor], - ): - self.assertIs(value.const_value, tensor) - - @parameterized.parameterized.expand( - [ - ("then_initializer", "else_initializer"), - ("initializer", "initializer"), - ] - ) - def test_pass_does_not_lift_initialized_inputs_in_subgraph( - self, then_initializer_name, else_initializer_name - ): - input_value = ir.Value( - name="input", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3)) - ) - - then_initializer_tensor = ir.tensor(np.random.rand(2, 3).astype(np.float32)) - then_initializer_value = ir.Value( - name=then_initializer_name, - shape=then_initializer_tensor.shape, - type=ir.TensorType(ir.DataType.FLOAT), - const_value=then_initializer_tensor, - ) - - # then branch adds the constant to the input - # else branch multiplies the input by the constant - add_node = ir.node("Add", inputs=[input_value, then_initializer_value]) - then_graph = ir.Graph( - # The initializer is also an input. We don't lift it to the main graph - # to preserve the graph signature - inputs=[then_initializer_value], - outputs=[add_node.outputs[0]], - nodes=[add_node], - opset_imports={"": 20}, - initializers=[then_initializer_value], - ) - else_initializer_tensor = ir.tensor(np.random.rand(2, 3).astype(np.float32)) - else_initializer_value = ir.Value( - name=else_initializer_name, - shape=else_initializer_tensor.shape, - type=ir.TensorType(ir.DataType.FLOAT), - const_value=else_initializer_tensor, - ) - mul_node = ir.node("Mul", inputs=[input_value, else_initializer_value]) - else_graph = ir.Graph( - inputs=[], - outputs=[mul_node.outputs[0]], - nodes=[mul_node], - opset_imports={"": 20}, - initializers=[else_initializer_value], - ) - # Create a conditional node that uses the then and else graphs - cond_node = ir.node( - "If", - inputs=[input_value], - attributes={"then_branch": then_graph, "else_branch": else_graph}, - num_outputs=1, - ) - # Construct the model - main_graph = ir.Graph( - inputs=[input_value], - outputs=cond_node.outputs, - nodes=[cond_node], - opset_imports={"": 20}, - ) - main_graph.sort() - model = ir.Model( - graph=main_graph, - ir_version=10, - ) - result = constant_manipulation.LiftSubgraphInitializersToMainGraphPass()(model) - self.assertTrue(result.modified) - - self.assertEqual(len(else_graph.initializers), 0) - self.assertEqual(len(then_graph.initializers), 1) - self.assertEqual(len(main_graph.initializers), 1) - for value, tensor in zip(main_graph.initializers.values(), [else_initializer_tensor]): - self.assertIs(value.const_value, tensor) - - -class TestRemoveInitializersFromInputsPass(unittest.TestCase): - def test_remove_initializers_from_inputs(self): - input_value = ir.Value( - name="input", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3)) - ) - initializer_value = ir.Value( - name="initializer", - type=ir.TensorType(ir.DataType.FLOAT), - shape=ir.Shape((2, 3)), - const_value=ir.tensor(np.random.rand(2, 3).astype(np.float32)), - ) - identity_node = ir.node("Identity", inputs=[input_value], num_outputs=1) - - model = ir.Model( - graph=ir.Graph( - inputs=[input_value, initializer_value], - outputs=identity_node.outputs, - nodes=[identity_node], - initializers=[initializer_value], - opset_imports={"": 20}, - ), - ir_version=10, - ) - - # Check that the initializer is in the graph inputs - self.assertIn(initializer_value, model.graph.inputs) - - # Perform remove initializers from inputs - result = constant_manipulation.RemoveInitializersFromInputsPass()(model) - self.assertTrue(result.modified) - # Check that the initializer is removed from the graph inputs - self.assertNotIn(initializer_value, result.model.graph.inputs) - - def test_remove_initializers_from_inputs_with_no_initializers(self): - input_value = ir.Value( - name="input", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3)) - ) - identity_node = ir.node("Identity", inputs=[input_value], num_outputs=1) - - model = ir.Model( - graph=ir.Graph( - inputs=[input_value], - outputs=identity_node.outputs, - nodes=[identity_node], - opset_imports={"": 20}, - ), - ir_version=10, - ) - - # Perform remove initializers from inputs - result = constant_manipulation.RemoveInitializersFromInputsPass()(model) - self.assertFalse(result.modified) - # Check that the graph inputs remain unchanged - self.assertEqual(result.model.graph.inputs, [input_value]) - - -class TestAddInitializersToInputsPass(unittest.TestCase): - def test_add_initializers_to_inputs(self): - input_value = ir.Value( - name="input", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3)) - ) - initializer_value = ir.Value( - name="initializer", - type=ir.TensorType(ir.DataType.FLOAT), - shape=ir.Shape((2, 3)), - const_value=ir.tensor(np.random.rand(2, 3).astype(np.float32)), - ) - identity_node = ir.node("Identity", inputs=[input_value], num_outputs=1) - - model = ir.Model( - graph=ir.Graph( - inputs=[input_value], - outputs=identity_node.outputs, - nodes=[identity_node], - initializers=[initializer_value], - opset_imports={"": 20}, - ), - ir_version=10, - ) - - # Check that the initializer is not in the graph inputs - self.assertNotIn(initializer_value, model.graph.inputs) - - # Perform add initializers to inputs - result = constant_manipulation.AddInitializersToInputsPass()(model) - self.assertTrue(result.modified) - # Check that the initializer is added to the graph inputs - self.assertIn(initializer_value, result.model.graph.inputs) - - def test_add_initializers_to_inputs_with_no_initializers(self): - input_value = ir.Value( - name="input", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3)) - ) - identity_node = ir.node("Identity", inputs=[input_value], num_outputs=1) - - model = ir.Model( - graph=ir.Graph( - inputs=[input_value], - outputs=identity_node.outputs, - nodes=[identity_node], - opset_imports={"": 20}, - ), - ir_version=10, - ) - - # Perform add initializers to inputs - result = constant_manipulation.AddInitializersToInputsPass()(model) - self.assertFalse(result.modified) - # Check that the graph inputs remain unchanged - self.assertEqual(result.model.graph.inputs, [input_value]) - - -if __name__ == "__main__": - unittest.main() diff --git a/onnxscript/ir/passes/common/inliner.py b/onnxscript/ir/passes/common/inliner.py deleted file mode 100644 index 3a4f97a8a7..0000000000 --- a/onnxscript/ir/passes/common/inliner.py +++ /dev/null @@ -1,331 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -"""Implementation of an inliner for onnxscript.ir""" - -from __future__ import annotations - -import dataclasses - -__all__ = ["InlinePass", "InlinePassResult"] - -from collections import defaultdict -from typing import Iterable, List, Sequence, Tuple - -import onnxscript.ir.convenience as _ir_convenience -from onnxscript import ir - -# A replacement for a node specifies a list of nodes that replaces the original node, -# and a list of values that replaces the original node's outputs. - -NodeReplacement = Tuple[Sequence[ir.Node], Sequence[ir.Value]] - -# A call stack is a list of identifiers of call sites, where the first element is the -# outermost call site, and the last element is the innermost call site. This is used -# primarily for generating unique names for values in the inlined functions. -CallSiteId = str -CallStack = List[CallSiteId] - - -def _make_unique_name(name: str, callstack: CallStack, used_names: set[str]) -> str: # pylint: disable=unused-argument - """Generate a unique name from a name, calling-context, and set of used names. - - If there is a name clash, we add a numeric suffix to the name to make - it unique. We use the same strategy to make node names unique. - - TODO: We can use the callstack in generating a name for a value X in a function - that is inlined into a graph. This is not yet implemented. Using the full callstack - leads to very long and hard to read names. Some investigation is needed to find - a good naming strategy that will produce useful names for debugging. - """ - candidate = name - i = 1 - while candidate in used_names: - i += 1 - candidate = f"{name}_{i}" - used_names.add(candidate) - return candidate - - -class _CopyReplace: - """Utilities for creating a copy of IR objects with substitutions for attributes/input values.""" - - def __init__( - self, - inliner: InlinePass, - attr_map: dict[str, ir.Attr | ir.RefAttr], - value_map: dict[ir.Value, ir.Value | None], - metadata_props: dict[str, str], - call_stack: CallStack, - ) -> None: - self._inliner = inliner - self._value_map = value_map - self._attr_map = attr_map - self._metadata_props = metadata_props - self._call_stack = call_stack - - def clone_value(self, value: ir.Value) -> ir.Value | None: - if value in self._value_map: - return self._value_map[value] - # If the value is not in the value map, it must be a graph input. - assert value.producer() is None, f"Value {value} has no entry in the value map" - new_value = ir.Value( - name=value.name, - type=value.type, - shape=value.shape, - doc_string=value.doc_string, - const_value=value.const_value, - ) - self._value_map[value] = new_value - return new_value - - def clone_optional_value(self, value: ir.Value | None) -> ir.Value | None: - if value is None: - return None - return self.clone_value(value) - - def clone_attr(self, key: str, attr: ir.Attr | ir.RefAttr) -> ir.Attr | ir.RefAttr | None: - if isinstance(attr, ir.Attr): - if attr.type == ir.AttributeType.GRAPH: - graph = self.clone_graph(attr.as_graph()) - return ir.Attr(key, ir.AttributeType.GRAPH, graph, doc_string=attr.doc_string) - elif attr.type == ir.AttributeType.GRAPHS: - graphs = [self.clone_graph(graph) for graph in attr.as_graphs()] - return ir.Attr( - key, ir.AttributeType.GRAPHS, graphs, doc_string=attr.doc_string - ) - return attr - assert isinstance(attr, ir.RefAttr) - ref_attr_name = attr.ref_attr_name - if ref_attr_name in self._attr_map: - ref_attr = self._attr_map[ref_attr_name] - if isinstance(ref_attr, ir.Attr): - return ir.Attr( - key, ref_attr.type, ref_attr.value, doc_string=ref_attr.doc_string - ) - assert isinstance(ref_attr, ir.RefAttr) - return ir.RefAttr( - key, ref_attr.ref_attr_name, ref_attr.type, doc_string=ref_attr.doc_string - ) - # Note that if a function has an attribute-parameter X, and a call (node) to the function - # has no attribute X, all references to X in nodes inside the function body will be - # removed. This is just the ONNX representation of optional-attributes. - return None - - def clone_node(self, node: ir.Node) -> ir.Node: - new_inputs = [self.clone_optional_value(input) for input in node.inputs] - new_attributes = [ - new_value - for key, value in node.attributes.items() - if (new_value := self.clone_attr(key, value)) is not None - ] - new_name = node.name - if new_name is not None: - new_name = _make_unique_name( - new_name, self._call_stack, self._inliner.used_node_names - ) - - new_metadata = {**self._metadata_props, **node.metadata_props} - # TODO: For now, node metadata overrides callnode metadata if there is a conflict. - # Do we need to preserve both? - - new_node = ir.Node( - node.domain, - node.op_type, - new_inputs, - new_attributes, - overload=node.overload, - num_outputs=len(node.outputs), - graph=None, - name=new_name, - doc_string=node.doc_string, # type: ignore - metadata_props=new_metadata, - ) - new_outputs = new_node.outputs - for i, output in enumerate(node.outputs): - self._value_map[output] = new_outputs[i] - old_name = output.name if output.name is not None else f"output_{i}" - new_outputs[i].name = _make_unique_name( - old_name, self._call_stack, self._inliner.used_value_names - ) - - self._inliner.node_context[new_node] = self._call_stack - - return new_node - - def clone_graph(self, graph: ir.Graph) -> ir.Graph: - input_values = [self.clone_value(v) for v in graph.inputs] - nodes = [self.clone_node(node) for node in graph] - initializers = [self.clone_value(init) for init in graph.initializers.values()] - output_values = [ - self.clone_value(v) for v in graph.outputs - ] # Looks up already cloned values - - return ir.Graph( - input_values, # type: ignore - output_values, # type: ignore - nodes=nodes, - initializers=initializers, # type: ignore - doc_string=graph.doc_string, - opset_imports=graph.opset_imports, - name=graph.name, - metadata_props=graph.metadata_props, - ) - - -def _abbreviate( - function_ids: Iterable[ir.OperatorIdentifier], -) -> dict[ir.OperatorIdentifier, str]: - """Create a short unambiguous abbreviation for all function ids.""" - - def id_abbreviation(id: ir.OperatorIdentifier) -> str: - """Create a short unambiguous abbreviation for a function id.""" - domain, name, overload = id - # Omit the domain, if it remains unambiguous after omitting it. - if any(x[0] != domain and x[1] == name and x[2] == overload for x in function_ids): - short_domain = domain + "_" - else: - short_domain = "" - if overload != "": - return short_domain + name + "_" + overload - return short_domain + name - - return {id: id_abbreviation(id) for id in function_ids} - - -@dataclasses.dataclass -class InlinePassResult(ir.passes.PassResult): - id_count: dict[ir.OperatorIdentifier, int] - - -class InlinePass(ir.passes.InPlacePass): - """Inline model local functions to the main graph and clear function definitions.""" - - def __init__(self) -> None: - super().__init__() - self._functions: dict[ir.OperatorIdentifier, ir.Function] = {} - self._function_id_abbreviations: dict[ir.OperatorIdentifier, str] = {} - self._opset_imports: dict[str, int] = {} - self.used_value_names: set[str] = set() - self.used_node_names: set[str] = set() - self.node_context: dict[ir.Node, CallStack] = {} - - def _reset(self, model: ir.Model) -> None: - self._functions = model.functions - self._function_id_abbreviations = _abbreviate(self._functions.keys()) - self._opset_imports = model.opset_imports - self.used_value_names = set() - self.used_node_names = set() - self.node_context = {} - - def call(self, model: ir.Model) -> InlinePassResult: - self._reset(model) - id_count = self._inline_calls_in(model.graph) - model.functions.clear() - return InlinePassResult(model, modified=bool(id_count), id_count=id_count) - - def _instantiate_call(self, node: ir.Node, call_site_id: CallSiteId) -> NodeReplacement: - id = node.op_identifier() - function = self._functions[id] - - # check opset compatibility and update the opset imports - for key, value in function.opset_imports.items(): - if key not in self._opset_imports: - self._opset_imports[key] = value - elif self._opset_imports[key] != value: - raise ValueError( - f"Opset mismatch: {key} {self._opset_imports[key]} != {value}" - ) - - # Identify substitutions for both inputs and attributes of the function: - attributes: dict[str, ir.Attr | ir.RefAttr] = node.attributes - default_attr_values = { - attr.name: attr - for attr in function.attributes.values() - if attr.name not in attributes and attr.value is not None - } - if default_attr_values: - attributes = {**attributes, **default_attr_values} - if any( - attr.type in {ir.AttributeType.GRAPH, ir.AttributeType.GRAPHS} - for attr in attributes.values() - ): - raise ValueError( - "Inliner does not support graph attribute parameters to functions" - ) - - if len(node.inputs) > len(function.inputs): - raise ValueError(f"Input mismatch: {len(node.inputs)} > {len(function.inputs)}") - value_map = {} - for i, input in enumerate(node.inputs): - value_map[function.inputs[i]] = input - for i in range(len(node.inputs), len(function.inputs)): - value_map[function.inputs[i]] = None - - # Identify call-stack for node, used to generate unique names. - call_stack = self.node_context.get(node, []) - new_call_stack = [*call_stack, call_site_id] - - cloner = _CopyReplace(self, attributes, value_map, node.metadata_props, new_call_stack) - - # iterate over the nodes in the function, creating a copy of each node - # and replacing inputs with the corresponding values in the value map. - # Update the value map with the new values. - - nodes = [cloner.clone_node(node) for node in function] - output_values = [value_map[output] for output in function.outputs] - return nodes, output_values # type: ignore - - def _inline_calls_in(self, graph: ir.Graph) -> dict[ir.OperatorIdentifier, int]: - for input in graph.inputs: - if input.name is not None: - self.used_value_names.add(input.name) - for initializer in graph.initializers: - self.used_value_names.add(initializer) - - # Pre-processing: - # * Count the number of times each function is called in the graph. - # This is used for disambiguating names of values in the inlined functions. - # * And identify names of values that are used in the graph. - id_count: dict[ir.OperatorIdentifier, int] = defaultdict(int) - for node in graph: - if node.name: - self.used_node_names.add(node.name) - id = node.op_identifier() - if id in self._functions: - id_count[id] += 1 - for output in node.outputs: - if output.name is not None: - self.used_value_names.add(output.name) - next_id: dict[ir.OperatorIdentifier, int] = defaultdict(int) - for node in graph: - id = node.op_identifier() - if id in self._functions: - # If there are multiple calls to same function, we use a prefix to disambiguate - # the different call-sites: - if id_count[id] > 1: - call_site_prefix = f"_{next_id[id]}" - next_id[id] += 1 - else: - call_site_prefix = "" - call_site = node.name or ( - self._function_id_abbreviations[id] + call_site_prefix - ) - nodes, values = self._instantiate_call(node, call_site) - _ir_convenience.replace_nodes_and_values( - graph, - insertion_point=node, - old_nodes=[node], - new_nodes=nodes, - old_values=node.outputs, - new_values=values, - ) - else: - for attr in node.attributes.values(): - if not isinstance(attr, ir.Attr): - continue - if attr.type == ir.AttributeType.GRAPH: - self._inline_calls_in(attr.as_graph()) - elif attr.type == ir.AttributeType.GRAPHS: - for g in attr.as_graphs(): - self._inline_calls_in(g) - return id_count diff --git a/onnxscript/ir/passes/common/inliner_test.py b/onnxscript/ir/passes/common/inliner_test.py deleted file mode 100644 index 1a4be6ce8e..0000000000 --- a/onnxscript/ir/passes/common/inliner_test.py +++ /dev/null @@ -1,205 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -"""Tests for the inliner pass.""" - -from __future__ import annotations - -import unittest -from typing import Callable, Sequence - -import onnx - -from onnxscript import ir -from onnxscript.ir.passes.common import inliner - - -def _name_checker(renameable: Sequence[str] | None) -> Callable[[str, str], bool]: - """Construct function to check if actual value name matches expected value name. - - This is used to avoid hard-coding the expected names in the test cases. - """ - # Default to exact match if no renaming is allowed. - if renameable is None: - return lambda a, b: a == b - # If some names are allowed to be renamed, keep track of the renaming. - # And check that the renaming is consistent across all nodes. - renaming_map: dict[str, str] = {} - - def check(actual: str, expected: str) -> bool: - if expected in renameable: - # actual name can be different, as long as it is consistently used. - if expected in renaming_map: - return renaming_map[expected] == actual - renaming_map[expected] = actual - return True - else: - return actual == expected - - return check - - -class InlinerTest(unittest.TestCase): - def _check( - self, input_model: str, expected_model: str, renameable: Sequence[str] | None = None - ) -> None: - name_check = _name_checker(renameable) - model_ir = ir.from_onnx_text(input_model) - inliner.InlinePass()(model_ir) - proto = ir.serde.serialize_model(model_ir) - text = onnx.printer.to_text(proto) - print(text) - expected_ir = ir.from_onnx_text(expected_model) - self.assertEqual(len(model_ir.graph), len(expected_ir.graph)) - for node, expected_node in zip(model_ir.graph, expected_ir.graph): - # TODO: handle node renaming - self.assertEqual(node.op_type, expected_node.op_type) - self.assertEqual(len(node.inputs), len(expected_node.inputs)) - for input, expected_input in zip(node.inputs, expected_node.inputs): - self.assertEqual(input is None, expected_input is None) - if input is not None: - self.assertTrue(name_check(input.name, expected_input.name)) - self.assertEqual(len(node.attributes), len(expected_node.attributes)) - for key, value in node.attributes.items(): - self.assertIn(key, expected_node.attributes) - expected_value = expected_node.attributes[key] - self.assertTrue(isinstance(value, ir.Attr)) - self.assertTrue(isinstance(expected_value, ir.Attr)) - self.assertEqual(value.type, expected_value.type) - if value.type not in (ir.AttributeType.GRAPH, ir.AttributeType.GRAPHS): - self.assertEqual(value.value, expected_value.value) - else: - self.fail("Graph attributes are not supported yet") - # TODO: handle graph attributes - self.assertEqual(len(node.outputs), len(expected_node.outputs)) - for output, expected_output in zip(node.outputs, expected_node.outputs): - self.assertTrue(name_check(output.name, expected_output.name)) - - def test_single_call(self): - input_model = """ - - agraph (float[N] X) => (float[N] Y) - { - Y = local.foo (X) - } - - - foo (x) => (y) { - temp = Add(x, x) - y = Mul(temp, temp) - } - """ - expected_model = """ - - agraph (float[N] X) => (float[N] Y) - { - temp = Add(X, X) - Y = Mul(temp, temp) - } - """ - self._check(input_model, expected_model, renameable=["temp"]) - - def test_two_calls(self): - input_model = """ - - agraph (float[N] X) => (float[N] Y) - { - T = local.foo (X) - Y = local.foo (T) - } - - - foo (x) => (y) { - temp = Add(x, x) - y = Mul(temp, temp) - } - """ - expected_model = """ - - agraph (float[N] X) => (float[N] Y) - { - temp1 = Add(X, X) - T = Mul(temp1, temp1) - temp2 = Add(T, T) - Y = Mul(temp2, temp2) - } - """ - self._check(input_model, expected_model, renameable=["temp1", "temp2"]) - - def test_nested_call(self): - input_model = """ - - agraph (float[N] X) => (float[N] Y) - { - Y = local.foo (X) - } - - - foo (x) => (y) { - temp = Add(x, x) - y = local.bar(temp) - } - - - bar (x) => (y) { - y = Mul (x, x) - } - """ - expected_model = """ - - agraph (float[N] X) => (float[N] Y) - { - temp = Add(X, X) - Y = Mul(temp, temp) - } - """ - self._check(input_model, expected_model, renameable=["temp"]) - - def test_attr_parameter(self): - input_model = """ - - agraph (float[N] X) => (float[N] Y) - { - Y = local.foo (X) - } - - - foo (x) => (y) { - y = Selu (x) - } - """ - expected_model = """ - - agraph (float[N] X) => (float[N] Y) - { - Y = Selu (X) - } - """ - self._check(input_model, expected_model) - - def test_attr_parameter_with_default_value(self): - input_model = """ - - agraph (float[N] X) => (float[N] Y) - { - T = local.foo (X) - Y = local.foo (T) - } - - - foo (x) => (y) { - y = Selu (x) - } - """ - expected_model = """ - - agraph (float[N] X) => (float[N] Y) - { - T = Selu (X) - Y = Selu (T) - } - """ - self._check(input_model, expected_model) - - -if __name__ == "__main__": - unittest.main() diff --git a/onnxscript/ir/passes/common/onnx_checker.py b/onnxscript/ir/passes/common/onnx_checker.py deleted file mode 100644 index b815629641..0000000000 --- a/onnxscript/ir/passes/common/onnx_checker.py +++ /dev/null @@ -1,57 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -"""Passes for debugging purposes.""" - -from __future__ import annotations - -__all__ = [ - "CheckerPass", -] - -from typing import Literal - -import onnx - -from onnxscript import ir -from onnxscript.ir.passes.common import _c_api_utils - - -class CheckerPass(ir.passes.PassBase): - """Run onnx checker on the model.""" - - @property - def in_place(self) -> Literal[True]: - """This pass does not create a new model.""" - return True - - @property - def changes_input(self) -> Literal[False]: - """This pass does not change the input model.""" - return False - - def __init__( - self, - full_check: bool = False, - skip_opset_compatibility_check: bool = False, - check_custom_domain: bool = False, - ): - super().__init__() - self.full_check = full_check - self.skip_opset_compatibility_check = skip_opset_compatibility_check - self.check_custom_domain = check_custom_domain - - def call(self, model: ir.Model) -> ir.passes.PassResult: - """Run the onnx checker on the model.""" - - def _partial_check_model(proto: onnx.ModelProto) -> None: - """Partial function to check the model.""" - onnx.checker.check_model( - proto, - full_check=self.full_check, - skip_opset_compatibility_check=self.skip_opset_compatibility_check, - check_custom_domain=self.check_custom_domain, - ) - - _c_api_utils.call_onnx_api(func=_partial_check_model, model=model) - # The model is not modified - return ir.passes.PassResult(model, False) diff --git a/onnxscript/ir/passes/common/onnx_checker_test.py b/onnxscript/ir/passes/common/onnx_checker_test.py deleted file mode 100644 index 144225416d..0000000000 --- a/onnxscript/ir/passes/common/onnx_checker_test.py +++ /dev/null @@ -1,79 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -from __future__ import annotations - -import unittest - -from onnxscript import ir -from onnxscript.ir.passes.common import onnx_checker - - -class TestCheckerPass(unittest.TestCase): - def test_pass_is_no_op(self): - checker_pass = onnx_checker.CheckerPass() - self.assertTrue(checker_pass.in_place) - self.assertFalse(checker_pass.changes_input) - - def test_check_simple_model(self): - inputs = [ - ir.Value( - name="input_a", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((1, 2)) - ), - ir.Value( - name="input_b", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((1, 2)) - ), - ] - - tape = ir.tape.Tape() - - output = tape.op("Add", inputs=inputs) - output.shape = ir.Shape((1, 2)) - output.dtype = ir.DataType.FLOAT - - model = ir.Model( - ir.Graph( - inputs=inputs, - outputs=[output], - nodes=tape.nodes, - opset_imports={"": 20}, - name="test_model", - ), - ir_version=10, - ) - # No exception should be raised - onnx_checker.CheckerPass()(model) - - def test_check_invalid_model(self): - inputs = [ - ir.Value( - name="input_a", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((1, 2)) - ), - ir.Value( - name="input_b", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((1, 2)) - ), - ] - - tape = ir.tape.Tape() - - output = tape.op("Add", inputs=inputs) - output.shape = ir.Shape((1, 2)) - output.dtype = ir.DataType.FLOAT - - model = ir.Model( - ir.Graph( - inputs=inputs, - outputs=[output], - nodes=tape.nodes, - opset_imports={"": 20}, - ), - ir_version=10, - ) - - with self.assertRaisesRegex( - Exception, "Field 'name' of 'graph' is required to be non-empty" - ): - onnx_checker.CheckerPass()(model) - - -if __name__ == "__main__": - unittest.main() diff --git a/onnxscript/ir/passes/common/shape_inference.py b/onnxscript/ir/passes/common/shape_inference.py deleted file mode 100644 index 586fa5b417..0000000000 --- a/onnxscript/ir/passes/common/shape_inference.py +++ /dev/null @@ -1,112 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -"""Shape inference pass using onnx.shape_inference.""" - -from __future__ import annotations - -__all__ = [ - "ShapeInferencePass", - "infer_shapes", -] - -import logging - -import onnx - -from onnxscript import ir -from onnxscript.ir.passes.common import _c_api_utils - -logger = logging.getLogger(__name__) - - -def _merge_func(model: ir.Model, inferred_proto: onnx.ModelProto) -> bool: - """Merge the shape inferred model with the original model. - - Args: - model: The original IR model. - inferred_proto: The ONNX model with shapes and types inferred. - - Returns: - A tuple containing the modified model and a boolean indicating whether the model was modified. - """ - inferred_model = ir.serde.deserialize_model(inferred_proto) - modified = False - for original_graph, inferred_graph in zip(model.graphs(), inferred_model.graphs()): - original_values = ir.convenience.create_value_mapping(original_graph) - inferred_values = ir.convenience.create_value_mapping(inferred_graph) - for name, value in original_values.items(): - if name in inferred_values: - inferred_value = inferred_values[name] - if value.shape != inferred_value.shape and inferred_value.shape is not None: - value.shape = inferred_value.shape - modified = True - if value.dtype != inferred_value.dtype and inferred_value.dtype is not None: - value.dtype = inferred_value.dtype - modified = True - else: - logger.warning( - "Value %s not found in inferred graph %s", name, inferred_graph.name - ) - return modified - - -class ShapeInferencePass(ir.passes.InPlacePass): - """This pass performs shape inference on the graph.""" - - def __init__( - self, check_type: bool = True, strict_mode: bool = True, data_prop: bool = True - ) -> None: - """Initialize the shape inference pass. - - If inference fails, the model is left unchanged. - - Args: - check_type: If True, check the types of the inputs and outputs. - strict_mode: If True, use strict mode for shape inference. - data_prop: If True, use data propagation for shape inference. - """ - super().__init__() - self.check_type = check_type - self.strict_mode = strict_mode - self.data_prop = data_prop - - def call(self, model: ir.Model) -> ir.passes.PassResult: - def partial_infer_shapes(proto: onnx.ModelProto) -> onnx.ModelProto: - return onnx.shape_inference.infer_shapes( - proto, - check_type=self.check_type, - strict_mode=self.strict_mode, - data_prop=self.data_prop, - ) - - try: - inferred_model_proto = _c_api_utils.call_onnx_api(partial_infer_shapes, model) - except Exception as e: # pylint: disable=broad-exception-caught - logger.warning("Shape inference failed: %s. Model is left unchanged", exc_info=e) - return ir.passes.PassResult(model, False) - - modified = _merge_func(model, inferred_model_proto) - return ir.passes.PassResult(model, modified=modified) - - -def infer_shapes( - model: ir.Model, - *, - check_type: bool = True, - strict_mode: bool = True, - data_prop: bool = True, -) -> ir.Model: - """Perform shape inference on the model. - - Args: - model: The model to perform shape inference on. - check_type: If True, check the types of the inputs and outputs. - strict_mode: If True, use strict mode for shape inference. - data_prop: If True, use data propagation for shape inference. - - Returns: - The model with shape inference applied. - """ - return ShapeInferencePass( - check_type=check_type, strict_mode=strict_mode, data_prop=data_prop - )(model).model diff --git a/onnxscript/ir/passes/common/shape_inference_test.py b/onnxscript/ir/passes/common/shape_inference_test.py deleted file mode 100644 index 5a2f02c64e..0000000000 --- a/onnxscript/ir/passes/common/shape_inference_test.py +++ /dev/null @@ -1,137 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -from __future__ import annotations - -import unittest - -import numpy as np - -from onnxscript import ir -from onnxscript.ir.passes.common import _c_api_utils, shape_inference - - -class TestShapeInferencePass(unittest.TestCase): - def test_pass_is_in_place(self): - self.assertTrue(shape_inference.ShapeInferencePass().in_place) - - def test_pass(self): - # Create a simple ONNX model with shape inference - # Define the model - inputs = [ - ir.Value( - name="input_a", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((1, 2)) - ), - ir.Value( - name="input_b", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((1, 2)) - ), - ] - - tape = ir.tape.Tape() - - output = tape.op("Add", inputs=inputs) - - model = ir.Model( - ir.Graph( - inputs=inputs, - outputs=[output], - nodes=tape.nodes, - opset_imports={"": 20}, - ), - ir_version=10, - ) - self.assertIsNone(output.shape) - self.assertIsNone(output.dtype) - - # Perform shape inference - result = shape_inference.ShapeInferencePass()(model) - self.assertTrue(result.modified) - self.assertEqual(result.model.graph.node(0).outputs[0].shape, ir.Shape((1, 2))) - self.assertEqual(result.model.graph.node(0).outputs[0].dtype, ir.DataType.FLOAT) - self.assertEqual(result.model.graph.outputs[0].shape, ir.Shape((1, 2))) - self.assertEqual(result.model.graph.outputs[0].dtype, ir.DataType.FLOAT) - - def test_pass_with_initializers(self): - # _BIG_TENSOR_SIZE_LIMIT is in bytes, but we create big_dim as size - # of a tensor. This is fine as we just need to create a big tensor whose size - # passes _BIG_TENSOR_SIZE_LIMIT - big_dim = _c_api_utils._BIG_TENSOR_SIZE_LIMIT * 2 # pylint: disable=protected-access - inputs = [ - ir.Value( - name="input_a", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((1, 2)) - ), - ir.Value( - name="input_b", - type=ir.TensorType(ir.DataType.FLOAT), - shape=ir.Shape((big_dim, 1)), - const_value=ir.tensor([[42]] * big_dim, dtype=ir.DataType.FLOAT), - ), - ] - - tape = ir.tape.Tape() - - # Shape and type are not explicitly set for the initializer but it should still work - initializer = ir.Value( - name="initializer", const_value=ir.tensor([[2, 3]], dtype=ir.DataType.FLOAT) - ) - val_add = tape.op("Add", inputs=inputs) - val_mul = tape.op("Mul", inputs=[val_add, initializer]) - - model = ir.Model( - ir.Graph( - inputs=inputs, - outputs=[val_mul], - nodes=tape.nodes, - opset_imports={"": 20}, - initializers=[inputs[1], initializer], - ), - ir_version=10, - ) - - self.assertIsNone(val_add.shape) - self.assertIsNone(val_add.dtype) - self.assertIsNone(val_mul.shape) - self.assertIsNone(val_mul.dtype) - self.assertIsNone(initializer.shape) - self.assertIsNone(initializer.dtype) - - # Perform shape inference - result = shape_inference.ShapeInferencePass()(model) - self.assertTrue(result.modified) - self.assertEqual(result.model.graph.node(0).outputs[0].shape, ir.Shape((big_dim, 2))) - self.assertEqual(result.model.graph.node(0).outputs[0].dtype, ir.DataType.FLOAT) - self.assertEqual(result.model.graph.node(1).outputs[0].shape, ir.Shape((big_dim, 2))) - self.assertEqual(result.model.graph.node(1).outputs[0].dtype, ir.DataType.FLOAT) - self.assertEqual( - result.model.graph.initializers["initializer"].shape, ir.Shape((1, 2)) - ) - self.assertEqual( - result.model.graph.initializers["initializer"].dtype, ir.DataType.FLOAT - ) - self.assertEqual(result.model.graph.outputs[0].shape, ir.Shape((big_dim, 2))) - self.assertEqual(result.model.graph.outputs[0].dtype, ir.DataType.FLOAT) - - # Check that the initializer correctly appears in the result - self.assertEqual(len(result.model.graph.inputs), 2) - self.assertEqual(len(result.model.graph.initializers), 2) - np.testing.assert_array_equal( - result.model.graph.initializers["input_b"].const_value.numpy(), - np.array([[42]] * big_dim, dtype=np.float32), - strict=True, - ) - self.assertEqual( - result.model.graph.initializers["input_b"].const_value.dtype, - ir.DataType.FLOAT, - ) - np.testing.assert_array_equal( - result.model.graph.initializers["initializer"].const_value.numpy(), - np.array([[2.0, 3.0]], dtype=np.float32), - strict=True, - ) - self.assertEqual( - result.model.graph.initializers["initializer"].const_value.dtype, - ir.DataType.FLOAT, - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/onnxscript/ir/passes/common/topological_sort.py b/onnxscript/ir/passes/common/topological_sort.py deleted file mode 100644 index 9be183cf01..0000000000 --- a/onnxscript/ir/passes/common/topological_sort.py +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -"""Pass for topologically sorting the graphs.""" - -from __future__ import annotations - -__all__ = [ - "TopologicalSortPass", -] - - -from onnxscript import ir - - -class TopologicalSortPass(ir.passes.InPlacePass): - """Topologically sort graphs and functions in a model.""" - - def call(self, model: ir.Model) -> ir.passes.PassResult: - original_nodes = list(model.graph) - model.graph.sort() - sorted_nodes = list(model.graph) - for function in model.functions.values(): - original_nodes.extend(function) - function.sort() - sorted_nodes.extend(function) - - # Compare node orders to determine if any changes were made - modified = False - for node, new_node in zip(original_nodes, sorted_nodes): - if node is not new_node: - modified = True - break - return ir.passes.PassResult(model=model, modified=modified) diff --git a/onnxscript/ir/passes/common/topological_sort_test.py b/onnxscript/ir/passes/common/topological_sort_test.py deleted file mode 100644 index 8680761f1e..0000000000 --- a/onnxscript/ir/passes/common/topological_sort_test.py +++ /dev/null @@ -1,85 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -"""Unit tests for the TopologicalSortPass.""" - -import unittest - -from onnxscript import ir -from onnxscript.ir.passes.common import topological_sort - - -class TopologicalSortPassTest(unittest.TestCase): - def setUp(self): - self.node_a = ir.node("A", inputs=[], name="node_a") - self.node_b = ir.node("B", inputs=self.node_a.outputs, name="node_b") - self.node_c = ir.node("C", inputs=self.node_b.outputs, name="node_c") - - def test_topological_sort_modified_true(self): - graph = ir.Graph( - inputs=self.node_a.inputs, - outputs=self.node_c.outputs, - nodes=[self.node_c, self.node_b, self.node_a], # Unsorted nodes - name="test_graph", - ) - model = ir.Model(graph, ir_version=10) - result = topological_sort.TopologicalSortPass()(model) - self.assertTrue(result.modified) - self.assertEqual( - tuple(result.model.graph), - (self.node_a, self.node_b, self.node_c), - ) - - def test_topological_sort_modified_false(self): - """Test that modified is False when the input model is already sorted.""" - sorted_graph = ir.Graph( - inputs=self.node_a.inputs, - outputs=self.node_c.outputs, - nodes=[self.node_a, self.node_b, self.node_c], # Sorted nodes - name="test_graph", - ) - sorted_model = ir.Model(sorted_graph, ir_version=10) - result = topological_sort.TopologicalSortPass()(sorted_model) - self.assertFalse(result.modified) - self.assertEqual( - tuple(result.model.graph), - (self.node_a, self.node_b, self.node_c), - ) - - def test_topological_sort_on_functions(self): - """Test that TopologicalSortPass works on functions in a model.""" - # Create a function with unsorted nodes - func_graph = ir.Graph( - inputs=self.node_a.inputs, - outputs=self.node_c.outputs, - nodes=[self.node_c, self.node_b, self.node_a], # Unsorted nodes - ) - function = ir.Function( - domain="test_domain", - name="test_function", - graph=func_graph, - attributes=[], - ) - - # Create a model with the function - graph = ir.Graph( - inputs=[], - outputs=[], - nodes=[], - name="test_graph", - ) - model = ir.Model(graph, ir_version=10, functions=[function]) - - # Apply the TopologicalSortPass - result = topological_sort.TopologicalSortPass()(model) - - # Verify that the nodes in the function are sorted - sorted_func_nodes = (self.node_a, self.node_b, self.node_c) - self.assertTrue(result.modified) - self.assertEqual( - tuple(result.model.functions[function.identifier()]), - sorted_func_nodes, - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/onnxscript/ir/passes/common/unused_removal.py b/onnxscript/ir/passes/common/unused_removal.py deleted file mode 100644 index fe9cc28b19..0000000000 --- a/onnxscript/ir/passes/common/unused_removal.py +++ /dev/null @@ -1,196 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -from __future__ import annotations - -__all__ = [ - "RemoveUnusedNodesPass", - "RemoveUnusedFunctionsPass", - "RemoveUnusedOpsetsPass", -] - -import logging - -import onnx - -from onnxscript import ir - -logger = logging.getLogger(__name__) - - -def _remove_unused_optional_outputs( - node: ir.Node, graph_outputs: frozenset[ir.Value], onnx_opset_version: int -) -> None: - try: - if node.domain not in {"", "onnx.ai"}: - return - op_schema = onnx.defs.get_schema(node.op_type, onnx_opset_version, domain=node.domain) - except Exception: # pylint: disable=broad-exception-caught - logger.info( - "Failed to get schema for %s, skipping optional output removal", - node, - stack_info=True, - ) - return - - if node.op_type == "BatchNormalization": - # BatchNormalization op has 3 outputs: Y, running_mean, running_var - # If running_mean and running_var are not used, remove them, and the training_mode attribute - def is_used_output(i: int) -> bool: - if i < len(node.outputs): - val = node.outputs[i] - return val in graph_outputs or bool(val.uses()) - return False - - if is_used_output(1) or is_used_output(2): - return - if len(node.outputs) > 1: - node.outputs[1].name = "" - if len(node.outputs) > 2: - node.outputs[2].name = "" - node.attributes.pop("training_mode", None) - return - - optional_info = [] - for o in op_schema.outputs: - # Current ops do not have optional outputs if they have variable number of outputs - if o.option == onnx.defs.OpSchema.FormalParameterOption.Variadic: - return - optional_info.append(o.option == onnx.defs.OpSchema.FormalParameterOption.Optional) - # If no optional outputs in spec, skip delete operations - if len([o == 1 for o in optional_info]) == 0: - return - - for i, out in enumerate(node.outputs): - if out not in graph_outputs and (not out.uses()) and optional_info[i] is True: - out.name = "" - - -def _remove_unused_nodes_in_graph_like(function_or_graph: ir.Function | ir.Graph) -> int: - graph_outputs = frozenset(function_or_graph.outputs) - onnx_opset_version = function_or_graph.opset_imports.get("", None) - count = 0 - for node in reversed(function_or_graph): - removable = True - for output in node.outputs: - if output in graph_outputs or output.uses(): - removable = False - break - if removable: - function_or_graph.remove(node, safe=True) - count += 1 - else: - if onnx_opset_version is not None: - _remove_unused_optional_outputs(node, graph_outputs, onnx_opset_version) - for attr in node.attributes.values(): - if not isinstance(attr, ir.Attr): - continue - if attr.type == ir.AttributeType.GRAPH: - count += _remove_unused_nodes_in_graph_like(attr.as_graph()) - elif attr.type == ir.AttributeType.GRAPHS: - for graph in attr.as_graphs(): - count += _remove_unused_nodes_in_graph_like(graph) - return count - - -class RemoveUnusedNodesPass(ir.passes.InPlacePass): - """Pass for removing unused nodes and initializers (dead code elimination). - - This pass does not modify the model signature (inputs and outputs). It ensures - that unused nodes and initializers are removed while preserving the original - contract of the model. - """ - - def call(self, model: ir.Model) -> ir.passes.PassResult: - count = _remove_unused_nodes_in_graph_like(model.graph) - graph_outputs = frozenset(model.graph.outputs) - graph_inputs = frozenset(model.graph.inputs) - initializers = model.graph.initializers - for init in list(initializers.values()): - if not (init.uses() or init in graph_outputs or init in graph_inputs): - assert init.name is not None - del initializers[init.name] - count += 1 - for function in model.functions.values(): - count += _remove_unused_nodes_in_graph_like(function) - if count: - logger.info("Removed %s unused nodes", count) - return ir.passes.PassResult(model, modified=bool(count)) - - -class RemoveUnusedFunctionsPass(ir.passes.InPlacePass): - def __init__(self): - super().__init__() - self._used: set[ir.OperatorIdentifier] | None = None - - def call(self, model: ir.Model) -> ir.passes.PassResult: - self._used = set() - for node in ir.traversal.RecursiveGraphIterator(model.graph): - self._call_node(model, node) - - # Update the model to remove unused functions - unused = set(model.functions) - self._used - if not unused: - logger.info("No unused functions to remove") - return ir.passes.PassResult(model, modified=False) - - for op_identifier in unused: - del model.functions[op_identifier] - - logger.info("Removed %s unused functions", len(unused)) - logger.debug("Functions left: %s", list(model.functions)) - logger.debug("Functions removed: %s", unused) - - self._used = None - return ir.passes.PassResult(model, modified=bool(unused)) - - def _call_function(self, model: ir.Model, function: ir.Function) -> None: - assert self._used is not None - if function.identifier() in self._used: - # The function and its nodes are already recorded as used - return - self._used.add(function.identifier()) - for node in ir.traversal.RecursiveGraphIterator(function): - self._call_node(model, node) - - def _call_node(self, model: ir.Model, node: ir.Node) -> None: - op_identifier = node.op_identifier() - if op_identifier not in model.functions: - return - self._call_function(model, model.functions[op_identifier]) - - -class RemoveUnusedOpsetsPass(ir.passes.InPlacePass): - """Remove unused opset imports from the model and functions. - - Attributes: - process_functions: Whether to process functions in the model. If True, the pass will - remove unused opset imports from functions as well. If False, only the main graph - will be processed. - """ - - def __init__(self, process_functions: bool = True): - super().__init__() - self.process_functions = process_functions - - def _process_graph_like( - self, graph_like: ir.Graph | ir.Function, used_domains: set[str] - ) -> bool: - for node in ir.traversal.RecursiveGraphIterator(graph_like): - used_domains.add(node.domain) - unused = set(graph_like.opset_imports) - used_domains - for domain in unused: - del graph_like.opset_imports[domain] - return bool(unused) - - def call(self, model: ir.Model) -> ir.passes.PassResult: - # Record domains of all functions - used_domains = {""} # By default always retain the onnx (default) domain - for function in model.functions.values(): - used_domains.add(function.domain) - modified = self._process_graph_like(model.graph, used_domains=used_domains) - - if self.process_functions: - for function in model.functions.values(): - modified |= self._process_graph_like(function, used_domains={""}) - - return ir.passes.PassResult(model, modified=modified) diff --git a/onnxscript/ir/passes/common/unused_removal_test.py b/onnxscript/ir/passes/common/unused_removal_test.py deleted file mode 100644 index 04d554555f..0000000000 --- a/onnxscript/ir/passes/common/unused_removal_test.py +++ /dev/null @@ -1,257 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -import unittest - -import onnx -import parameterized - -import onnxscript.optimizer -from onnxscript import ir - - -@parameterized.parameterized_class(("using_ir",), [(False,), (True,)]) -class RemoveUnusedTest(unittest.TestCase): - using_ir: bool - - def remove_unused_nodes(self, model: onnx.ModelProto): - if self.using_ir: - model_ir = ir.serde.deserialize_model(model) - onnxscript.optimizer.remove_unused_nodes(model_ir) - model = ir.serde.serialize_model(model_ir) - return model - onnxscript.optimizer.remove_unused_nodes(model) - return model - - def test_remove_unused_nodes(self): - model = onnx.parser.parse_model( - """ - - agraph (float[N] x) => (float[N] z) { - two = Constant () - four = Add(two, two) - z = Mul(x, x) - } - """ - ) - model = self.remove_unused_nodes(model) - self.assertEqual(len(model.graph.node), 1) - self.assertEqual(model.graph.node[0].op_type, "Mul") - - def test_remove_unused_initializers(self): - model = onnx.parser.parse_model( - """ - - agraph (float[N] x) => (float[N] z) - { - four = Add(two, two) - z = Mul(x, x) - } - """ - ) - self.assertEqual(len(model.graph.initializer), 1) - model = self.remove_unused_nodes(model) - self.assertEqual(len(model.graph.node), 1) - self.assertEqual(model.graph.node[0].op_type, "Mul") - self.assertEqual(len(model.graph.initializer), 0) - - def test_unused_initialized_inputs_are_kept(self): - model = onnx.parser.parse_model( - """ - - agraph (float[N] x, float[N] two) => (float[N] z) - { - four = Add(two, two) - z = Mul(x, x) - } - """ - ) - model = self.remove_unused_nodes(model) - self.assertEqual(len(model.graph.node), 1) - self.assertEqual(model.graph.node[0].op_type, "Mul") - self.assertEqual(len(model.graph.input), 2) - self.assertEqual(len(model.graph.initializer), 1) - - def test_unused_inputs_are_not_removed(self): - # preserve inputs as part of interface - model = onnx.parser.parse_model( - """ - - agraph (float[N] x, float[N] two) => (float[N] z) - { - four = Add(two, two) - z = Mul(x, x) - } - """ - ) - model = self.remove_unused_nodes(model) - self.assertEqual(len(model.graph.node), 1) - self.assertEqual(model.graph.node[0].op_type, "Mul") - self.assertEqual(len(model.graph.input), 2) - - def test_partially_used_nodes(self): - model = onnx.parser.parse_model( - """ - - agraph (float[N] x) => (float[M] z) { - w1, w2, w3 = Split (x) - z = Mul(w3, w3) - } - """ - ) - model = self.remove_unused_nodes(model) - self.assertEqual(len(model.graph.node), 2) - self.assertEqual(model.graph.node[0].op_type, "Split") - - def test_remove_unused_optional_outputs_maxpool(self): - model = onnx.parser.parse_model( - """ - - agraph (float[1, 1, 5, 5] x) => (float[1, 1, 5, 5] z) { - z, indices = MaxPool (x) - } - """ - ) - self.assertEqual(len(model.graph.node), 1) - self.assertEqual(model.graph.node[0].op_type, "MaxPool") - self.assertEqual(len(model.graph.node[0].output), 2) - model = self.remove_unused_nodes(model) - self.assertEqual(len(model.graph.node), 1) - self.assertEqual(model.graph.node[0].op_type, "MaxPool") - self.assertEqual(model.graph.node[0].output, ["z"]) - - def test_remove_unused_optional_outputs_dropout_in_function(self): - model = onnx.parser.parse_model( - """ - - agraph (float[1, 1, 5, 5] x) => (float[1, 1, 5, 5] z) - { - z = pkg.custom.afunction (x) - } - - afunction (x) => (z) - { - z, indices = MaxPool (x) - } - """ - ) - self.assertEqual(len(model.functions), 1) - self.assertEqual(len(model.functions[0].node), 1) - self.assertEqual(model.functions[0].node[0].op_type, "MaxPool") - self.assertEqual(len(model.functions[0].node[0].output), 2) - model = self.remove_unused_nodes(model) - self.assertEqual(len(model.functions), 1) - self.assertEqual(len(model.functions[0].node), 1) - self.assertEqual(model.functions[0].node[0].op_type, "MaxPool") - self.assertEqual(model.functions[0].node[0].output, ["z"]) - - def test_remove_used_optional_outputs_maxpool(self): - model = onnx.parser.parse_model( - """ - - agraph (float[1, 1, 5, 5] x) => (float[1, 1, 5, 5] y, float[1, 1, 5, 5] z) { - y, z = MaxPool (x) - } - """ - ) - self.assertEqual(len(model.graph.node), 1) - self.assertEqual(model.graph.node[0].op_type, "MaxPool") - self.assertEqual(len(model.graph.node[0].output), 2) - model = self.remove_unused_nodes(model) - self.assertEqual(len(model.graph.node), 1) - self.assertEqual(model.graph.node[0].op_type, "MaxPool") - self.assertEqual(model.graph.node[0].output, ["y", "z"]) - - def test_remove_multiple_unused_optional_outputs_layernorm(self): - model = onnx.parser.parse_model( - """ - - agraph (float[1, 3, 5, 5] x) => (float[1, 3, 5, 5] z) { - scale = Constant () - B = Constant () - z, mean, InvStdDev = LayerNormalization(x, scale, B) - } - """ - ) - self.assertEqual(len(model.graph.node), 3) - self.assertEqual(model.graph.node[2].op_type, "LayerNormalization") - self.assertEqual(len(model.graph.node[2].output), 3) - model = self.remove_unused_nodes(model) - self.assertEqual(len(model.graph.node), 3) - self.assertEqual(model.graph.node[2].op_type, "LayerNormalization") - self.assertEqual(list(model.graph.node[2].output), ["z"]) - - def test_remove_trailing_unused_optional_outputs_layernorm(self): - model = onnx.parser.parse_model( - """ - - agraph (float[1, 3, 5, 5] x) => (float[1, 3, 5, 5] z, float[1, 3, 5, 5] mean) { - scale = Constant () - B = Constant () - z, mean, InvStdDev = LayerNormalization(x, scale, B) - } - """ - ) - self.assertEqual(len(model.graph.node), 3) - self.assertEqual(model.graph.node[2].op_type, "LayerNormalization") - self.assertEqual(len(model.graph.node[2].output), 3) - model = self.remove_unused_nodes(model) - self.assertEqual(len(model.graph.node), 3) - self.assertEqual(model.graph.node[2].op_type, "LayerNormalization") - self.assertEqual(list(model.graph.node[2].output), ["z", "mean"]) - - def test_avoid_remove_non_trailing_unused_optional_outputs_layernorm(self): - model = onnx.parser.parse_model( - """ - - agraph (float[1, 3, 5, 5] x) => (float[1, 3, 5, 5] z, float[1, 3, 5, 5] InvStdDev) { - scale = Constant () - B = Constant () - z, mean, InvStdDev = LayerNormalization(x, scale, B) - } - """ - ) - self.assertEqual(len(model.graph.node), 3) - self.assertEqual(model.graph.node[2].op_type, "LayerNormalization") - self.assertEqual(len(model.graph.node[2].output), 3) - model = self.remove_unused_nodes(model) - self.assertEqual(len(model.graph.node), 3) - self.assertEqual(model.graph.node[2].op_type, "LayerNormalization") - self.assertEqual(list(model.graph.node[2].output), ["z", "", "InvStdDev"]) - - def test_remove_trailing_unused_optional_outputs_batchnorm(self): - model = onnx.parser.parse_model( - """ - - agraph (float[1, 3, 5, 5] x, float[3] scale, float[3] B) => (float[1, 3, 5, 5] z) { - z, mean_out, var_out = BatchNormalization (x, scale, B, mean, var) - } - """ - ) - self.assertEqual(len(model.graph.node[0].attribute), 1) - model = self.remove_unused_nodes(model) - self.assertEqual(len(model.graph.node), 1) - self.assertEqual(model.graph.node[0].op_type, "BatchNormalization") - # Check that both the mean/var outputs are removed, and training_mode attribute is removed. - self.assertEqual(list(model.graph.node[0].output), ["z"]) - self.assertEqual(len(model.graph.node[0].attribute), 0) - - def test_avoid_remove_used_optional_outputs_batchnorm(self): - model = onnx.parser.parse_model( - """ - - agraph (float[1, 3, 5, 5] x, float[3] scale, float[3] B) => (float[1, 3, 5, 5] z, float[3] mean_out, float[3] var_out) { - z, mean_out, var_out = BatchNormalization (x, scale, B, mean, var) - } - """ - ) - self.assertEqual(len(model.graph.node[0].attribute), 1) - model = self.remove_unused_nodes(model) - self.assertEqual(len(model.graph.node), 1) - self.assertEqual(model.graph.node[0].op_type, "BatchNormalization") - # Check that the mean/var outputs are NOT removed, and training_mode attribute is NOT removed. - self.assertEqual(list(model.graph.node[0].output), ["z", "mean_out", "var_out"]) - self.assertEqual(len(model.graph.node[0].attribute), 1) - - -if __name__ == "__main__": - unittest.main() diff --git a/onnxscript/ir/serde.py b/onnxscript/ir/serde.py deleted file mode 100644 index b5be445aef..0000000000 --- a/onnxscript/ir/serde.py +++ /dev/null @@ -1,1725 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -"""Serialize and deserialize the intermediate representation to/from ONNX protos.""" - -# NOTES for developers: -# NOTE: Do not import pathlib in the IR. It is slow. Use os.path methods instead. -# -# NOTE: Protobuf serialization -# Initializing a protobuf message with initialized protobuf messages incurs -# a copy and is slow. Instead, use proto.add() to add to a repeated field. -# or initialize the message first and then set the fields if the fields are -# plain Python objects. - -from __future__ import annotations - -import functools -import typing - -__all__ = [ - # Tensors - "TensorProtoTensor", - # Deserialization - "from_proto", - "from_onnx_text", - "deserialize_attribute", - "deserialize_dimension", - "deserialize_function", - "deserialize_graph", - "deserialize_metadata_props", - "deserialize_model", - "deserialize_node", - "deserialize_opset_import", - "deserialize_tensor", - "deserialize_tensor_shape", - "deserialize_type_proto_for_shape", - "deserialize_type_proto_for_type", - "deserialize_value_info_proto", - # Serialization - "to_proto", - "serialize_attribute_into", - "serialize_attribute", - "serialize_dimension_into", - "serialize_function_into", - "serialize_function", - "serialize_graph_into", - "serialize_graph", - "serialize_model_into", - "serialize_model", - "serialize_node_into", - "serialize_node", - "serialize_shape_into", - "serialize_reference_attribute_into", - "serialize_tensor_into", - "serialize_tensor", - "serialize_type_into", - "serialize_type", - "serialize_value_into", - "serialize_value", - "SerdeError", -] - -import collections -import logging -import os -from typing import Any, Callable, List, Mapping, Sequence - -import numpy as np -import onnx -import onnx.external_data_helper - -from onnxscript.ir import _core, _enums, _protocols, _type_casting - -if typing.TYPE_CHECKING: - import google.protobuf.internal.containers as proto_containers - import numpy.typing as npt - -logger = logging.getLogger(__name__) - -_PLEASE_CONTRIBUTE = ( - "Please contribute by creating a PR at https://github.com/microsoft/onnxscript." -) -_FUNCTION_VALUE_INFO_SUPPORTED_VERSION = ( - 10 # ONNX IR version where value info in functions was introduced -) -_QUANT_PARAMETER_TENSOR_NAMES_FIELD = "quant_parameter_tensor_names" -_T = typing.TypeVar("_T", bound=Callable[..., Any]) - - -class SerdeError(RuntimeError): - """Error during serialization or deserialization.""" - - -def _capture_errors(arg_capturer: Callable[..., str]) -> Callable[[_T], _T]: - """Decorator to capture errors and display the stack.""" - - def decorator(func: _T) -> _T: - @functools.wraps(func) - def wrapper(*args: Any, **kwargs: Any) -> Any: - try: - return func(*args, **kwargs) - except Exception as e: - raise SerdeError( - f"Error calling {func.__name__} with: {arg_capturer(*args, **kwargs)}" - ) from e - - return wrapper # type: ignore - - return decorator - - -def _little_endian_dtype(dtype) -> np.dtype: - """Create a small endian dtype on all platforms. - - This is useful because ONNX always stores raw_data in small endian. On big - endian platforms, we still need to interpret the raw_data in small endian. - """ - return np.dtype(dtype).newbyteorder("<") - - -def _unflatten_complex( - array: npt.NDArray[np.float32 | np.float64], -) -> npt.NDArray[np.complex64 | np.complex128]: - """Convert the real representation of a complex dtype to the complex dtype.""" - return array[::2] + 1j * array[1::2] - - -@typing.overload -def from_proto(proto: onnx.ModelProto) -> _core.Model: ... # type: ignore[overload-overlap] -@typing.overload -def from_proto(proto: onnx.GraphProto) -> _core.Graph: ... # type: ignore[overload-overlap] -@typing.overload -def from_proto(proto: onnx.NodeProto) -> _core.Node: ... # type: ignore[overload-overlap] -@typing.overload -def from_proto(proto: onnx.TensorProto) -> _protocols.TensorProtocol: ... # type: ignore[overload-overlap] -@typing.overload -def from_proto(proto: onnx.AttributeProto) -> _core.Attr: ... # type: ignore[overload-overlap] -@typing.overload -def from_proto(proto: onnx.ValueInfoProto) -> _core.Value: ... # type: ignore[overload-overlap] -@typing.overload -def from_proto(proto: onnx.TypeProto) -> _core.TypeAndShape: ... # type: ignore[overload-overlap] -@typing.overload -def from_proto(proto: onnx.FunctionProto) -> _core.Function: ... # type: ignore[overload-overlap] -@typing.overload -def from_proto(proto: onnx.TensorShapeProto) -> _core.Shape: ... # type: ignore[overload-overlap] -@typing.overload -def from_proto( # type: ignore[overload-overlap] - proto: onnx.TensorShapeProto.Dimension, -) -> tuple[int | _core.SymbolicDim, str | None]: ... -@typing.overload -def from_proto(proto: Sequence[onnx.OperatorSetIdProto]) -> dict[str, int]: ... # type: ignore[overload-overlap] -@typing.overload -def from_proto(proto: Sequence[onnx.StringStringEntryProto]) -> dict[str, str]: ... # type: ignore[overload-overlap] - - -def from_proto(proto: object) -> object: - """Deserialize an ONNX proto message to an IR object.""" - if isinstance(proto, onnx.ModelProto): - return deserialize_model(proto) - if isinstance(proto, onnx.GraphProto): - return deserialize_graph(proto) - if isinstance(proto, onnx.NodeProto): - return deserialize_node(proto) - if isinstance(proto, onnx.TensorProto): - return deserialize_tensor(proto) - if isinstance(proto, onnx.AttributeProto): - return deserialize_attribute(proto) - if isinstance(proto, onnx.ValueInfoProto): - return deserialize_value_info_proto(proto, None) - if isinstance(proto, onnx.TypeProto): - return _core.TypeAndShape( - deserialize_type_proto_for_type(proto), - deserialize_type_proto_for_shape(proto), - ) - if isinstance(proto, onnx.FunctionProto): - return deserialize_function(proto) - if isinstance(proto, onnx.TensorShapeProto): - return deserialize_tensor_shape(proto) - if isinstance(proto, onnx.TensorShapeProto.Dimension): - return deserialize_dimension(proto) - if isinstance(proto, Sequence) and all( - isinstance(p, onnx.OperatorSetIdProto) for p in proto - ): - return deserialize_opset_import(proto) - if isinstance(proto, Sequence) and all( - isinstance(p, onnx.StringStringEntryProto) for p in proto - ): - return deserialize_metadata_props(proto) - raise NotImplementedError( - f"Deserialization of {type(proto)} in from_proto is not implemented. " - "Use a specific ir.serde.deserialize* function instead." - ) - - -def from_onnx_text(model_text: str, /) -> _core.Model: - """Convert the ONNX textual representation to an IR model. - - Read more about the textual representation at: https://onnx.ai/onnx/repo-docs/Syntax.html - """ - proto = onnx.parser.parse_model(model_text) - return deserialize_model(proto) - - -@typing.overload -def to_proto(ir_object: _protocols.ModelProtocol) -> onnx.ModelProto: ... # type: ignore[overload-overlap] -@typing.overload -def to_proto(ir_object: _protocols.GraphProtocol) -> onnx.GraphProto: ... # type: ignore[overload-overlap] -@typing.overload -def to_proto(ir_object: _protocols.NodeProtocol) -> onnx.NodeProto: ... # type: ignore[overload-overlap] -@typing.overload -def to_proto(ir_object: _protocols.TensorProtocol) -> onnx.TensorProto: ... # type: ignore[overload-overlap] -@typing.overload -def to_proto(ir_object: _protocols.AttributeProtocol) -> onnx.AttributeProto: ... # type: ignore[overload-overlap] -@typing.overload -def to_proto(ir_object: _protocols.ReferenceAttributeProtocol) -> onnx.AttributeProto: ... # type: ignore[overload-overlap] -@typing.overload -def to_proto(ir_object: _protocols.ValueProtocol) -> onnx.ValueInfoProto: ... # type: ignore[overload-overlap] -@typing.overload -def to_proto(ir_object: _protocols.TypeProtocol) -> onnx.TypeProto: ... # type: ignore[overload-overlap] -@typing.overload -def to_proto(ir_object: _protocols.FunctionProtocol) -> onnx.FunctionProto: ... # type: ignore[overload-overlap] -@typing.overload -def to_proto(ir_object: _protocols.GraphViewProtocol) -> onnx.GraphProto: ... # type: ignore[overload-overlap] - - -def to_proto(ir_object: object) -> object: - """Serialize an IR object to a proto.""" - if isinstance(ir_object, _protocols.ModelProtocol): - return serialize_model(ir_object) - if isinstance(ir_object, _protocols.GraphProtocol): - return serialize_graph(ir_object) - if isinstance(ir_object, _protocols.NodeProtocol): - return serialize_node(ir_object) - if isinstance(ir_object, _protocols.TensorProtocol): - return serialize_tensor(ir_object) - if isinstance(ir_object, _protocols.ValueProtocol): - return serialize_value(ir_object) - if isinstance(ir_object, _protocols.AttributeProtocol): - return serialize_attribute(ir_object) - if isinstance(ir_object, _protocols.ReferenceAttributeProtocol): - return serialize_reference_attribute_into(onnx.AttributeProto(), ir_object) - if isinstance(ir_object, _protocols.TypeProtocol): - return serialize_type_into(onnx.TypeProto(), ir_object) - if isinstance(ir_object, _protocols.GraphViewProtocol): - return serialize_graph(ir_object) - if isinstance(ir_object, _protocols.FunctionProtocol): - return serialize_function(ir_object) - raise NotImplementedError( - f"Serialization of {type(ir_object)} in to_proto is not implemented. " - "Use a specific ir.serde.serialize* function instead." - ) - - -class TensorProtoTensor(_core.TensorBase): # pylint: disable=too-many-ancestors - """A tensor initialized from a tensor proto.""" - - __slots__ = ("_proto",) - - def __init__(self, proto: onnx.TensorProto) -> None: - super().__init__(metadata_props=deserialize_metadata_props(proto.metadata_props)) - self._proto = proto - - @property - def name(self) -> str: - return self._proto.name - - @name.setter - def name(self, value: str | None) -> None: - if value is None: - self._proto.ClearField("name") - else: - self._proto.name = value - - @property - def shape(self) -> _core.Shape: - return _core.Shape(self._proto.dims, frozen=True) - - @property - def dtype(self) -> _enums.DataType: - return _enums.DataType(self._proto.data_type) - - @property # type: ignore[misc] - def doc_string(self) -> str: - return self._proto.doc_string - - @property - def raw(self) -> onnx.TensorProto: - return self._proto - - def __repr__(self) -> str: - if self.size <= 10: - tensor_lines = repr(self.numpy()).split("\n") - tensor_text = " ".join(line.strip() for line in tensor_lines) - return f"{self._repr_base()}({tensor_text}, name={self.name!r})" - return f"{self._repr_base()}(name={self.name!r})" - - def __array__(self, dtype: Any = None) -> np.ndarray: - """Return the tensor as a numpy array, compatible with np.array.""" - return self.numpy().__array__(dtype) - - def __dlpack__(self, *, stream: Any = None) -> Any: - return self.numpy().__dlpack__(stream=stream) - - def __dlpack_device__(self) -> tuple[int, int]: - return self.numpy().__dlpack_device__() - - def numpy(self) -> np.ndarray: - """Return the tensor as a numpy array. - - This is an improved version of onnx.numpy_helper.to_array. - It first reads the data using the dtype corresponding to the tensor - proto data field, then converts it to the correct dtype and shape. - Special cases are bfloat16, complex and int4 where we need to - reinterpret the data. Other types can simply be casted. - - When the data type is not supported by numpy, the dtypes from the ``ml_dtype`` - package are used. The values can be reinterpreted as bit representations - using the ``.view()`` method. - - When the data type is a string, this method returns a numpy array - of bytes instead of a numpy array of strings, to follow the ONNX - specification. - - External tensors are not supported by this class. Use - :class:`onnxscript.ir.ExternalTensor` instead. - - Raises: - ValueError: If the data type is UNDEFINED. - """ - dtype = self.dtype - if dtype == _enums.DataType.UNDEFINED: - raise ValueError("Cannot convert UNDEFINED tensor to numpy array.") - if self._proto.data_location == onnx.TensorProto.EXTERNAL: - raise ValueError( - "Cannot convert external tensor to numpy array. Use ir.ExternalTensor instead." - ) - - if self._proto.HasField("raw_data"): - array = np.frombuffer(self._proto.raw_data, dtype=dtype.numpy().newbyteorder("<")) - # Cannot return now, because we may need to unpack 4bit tensors - elif dtype == _enums.DataType.STRING: - return np.array(self._proto.string_data).reshape(self._proto.dims) - elif self._proto.int32_data: - array = np.array(self._proto.int32_data, dtype=_little_endian_dtype(np.int32)) - if dtype in {_enums.DataType.FLOAT16, _enums.DataType.BFLOAT16}: - # Reinterpret the int32 as float16 or bfloat16 - array = array.astype(np.uint16).view(dtype.numpy()) - elif dtype in { - _enums.DataType.FLOAT8E4M3FN, - _enums.DataType.FLOAT8E4M3FNUZ, - _enums.DataType.FLOAT8E5M2, - _enums.DataType.FLOAT8E5M2FNUZ, - }: - array = array.astype(np.uint8).view(dtype.numpy()) - elif self._proto.int64_data: - array = np.array(self._proto.int64_data, dtype=_little_endian_dtype(np.int64)) - elif self._proto.uint64_data: - array = np.array(self._proto.uint64_data, dtype=_little_endian_dtype(np.uint64)) - elif self._proto.float_data: - array = np.array(self._proto.float_data, dtype=_little_endian_dtype(np.float32)) - if dtype == _enums.DataType.COMPLEX64: - array = _unflatten_complex(array) - elif self._proto.double_data: - array = np.array(self._proto.double_data, dtype=_little_endian_dtype(np.float64)) - if dtype == _enums.DataType.COMPLEX128: - array = _unflatten_complex(array) - else: - # Empty tensor - if not self._proto.dims: - # When dims not precent and there is no data, we return an empty array - return np.array([], dtype=dtype.numpy()) - else: - # Otherwise we return a size 0 array with the correct shape - return np.zeros(self._proto.dims, dtype=dtype.numpy()) - - if dtype == _enums.DataType.INT4: - return _type_casting.unpack_int4(array.astype(np.uint8), self._proto.dims) - elif dtype == _enums.DataType.UINT4: - return _type_casting.unpack_uint4(array.astype(np.uint8), self._proto.dims) - elif dtype == _enums.DataType.FLOAT4E2M1: - return _type_casting.unpack_float4e2m1(array.astype(np.uint8), self._proto.dims) - else: - # Otherwise convert to the correct dtype and reshape - # Note we cannot use view() here because the storage dtype may not be the same size as the target - return array.astype(dtype.numpy()).reshape(self._proto.dims) - - def tobytes(self) -> bytes: - """Return the tensor as a byte string conformed to the ONNX specification, in little endian. - - Raises: - ValueError: If the tensor is a string tensor or an external tensor. - ValueError: If the tensor is of UNDEFINED data type. - """ - if self._proto.data_location == onnx.TensorProto.EXTERNAL: - raise ValueError( - "Cannot convert external tensor to bytes. Use ir.ExternalTensor instead." - ) - if self.dtype == _enums.DataType.STRING: - raise ValueError("Cannot convert string tensor to bytes.") - if self.dtype == _enums.DataType.UNDEFINED: - raise ValueError("Cannot convert UNDEFINED tensor to bytes.") - - if self._proto.HasField("raw_data"): - return self._proto.raw_data - if self._proto.float_data: - return np.array( - self._proto.float_data, dtype=_little_endian_dtype(np.float32) - ).tobytes() - if self._proto.int32_data: - array = np.array(self._proto.int32_data, dtype=np.int32) - if self.dtype in { - _enums.DataType.INT16, - _enums.DataType.UINT16, - _enums.DataType.FLOAT16, - _enums.DataType.BFLOAT16, - }: - return array.astype(_little_endian_dtype(np.uint16)).tobytes() - if self.dtype in { - _enums.DataType.INT8, - _enums.DataType.UINT8, - _enums.DataType.BOOL, - _enums.DataType.FLOAT8E4M3FN, - _enums.DataType.FLOAT8E4M3FNUZ, - _enums.DataType.FLOAT8E5M2, - _enums.DataType.FLOAT8E5M2FNUZ, - _enums.DataType.INT4, - _enums.DataType.UINT4, - _enums.DataType.FLOAT4E2M1, - }: - # uint4 and int4 values are already packed, even when stored as int32 - # so we don't need to pack them again - return array.astype(_little_endian_dtype(np.uint8)).tobytes() - assert self.dtype == _enums.DataType.INT32 - return array.tobytes() - if self._proto.int64_data: - return np.array( - self._proto.int64_data, dtype=_little_endian_dtype(np.int64) - ).tobytes() - if self._proto.double_data: - return np.array( - self._proto.double_data, dtype=_little_endian_dtype(np.float64) - ).tobytes() - if self._proto.uint64_data: - array = np.array(self._proto.uint64_data, dtype=_little_endian_dtype(np.uint64)) - if self.dtype == _enums.DataType.UINT32: - return array.astype(_little_endian_dtype(np.uint32)).tobytes() - assert self.dtype == _enums.DataType.UINT64 - return array.tobytes() - # The repeating fields can be empty and still valid. - # For example, int32_data can be empty and still be a valid tensor. - return b"" - - -def _get_field(proto: Any, field: str) -> Any: - if proto.HasField(field): - return getattr(proto, field) - return None - - -# Deserialization - - -def deserialize_opset_import( - protos: Sequence[onnx.OperatorSetIdProto], -) -> dict[str, int]: - return {opset.domain: opset.version for opset in protos} - - -def _parse_experimental_function_value_info_name( - name: str, -) -> tuple[str, str, str] | None: - """Get the function domain, name and value name if the value info is for a function. - - The experimental format is: - {function_domain}::{function_name}/{value_name} - - Args: - name: The name stored in the value info. - - Returns: - A tuple of the function domain, function name and value name if the value info is for a function. - None otherwise. - """ - parts = name.split("/") - expected_parts = 2 - if len(parts) != expected_parts: - return None - function, value_name = parts - parts = function.split("::") - if len(parts) != expected_parts: - return None - # NOTE: There will not be overload because overloads are introduced in ONNX IR v10, which also - # introduces the ValueInfoProto for functions - function_domain, function_name = parts - return function_domain, function_name, value_name - - -def deserialize_model(proto: onnx.ModelProto) -> _core.Model: - graph = _deserialize_graph(proto.graph, []) - graph.opset_imports.update(deserialize_opset_import(proto.opset_import)) - - functions = [] - for func in proto.functions: - functions.append(deserialize_function(func)) - - model = _core.Model( - graph, - ir_version=proto.ir_version, - producer_name=_get_field(proto, "producer_name"), - producer_version=_get_field(proto, "producer_version"), - domain=_get_field(proto, "domain"), - model_version=_get_field(proto, "model_version"), - doc_string=_get_field(proto, "doc_string"), - functions=functions, - meta_data_props=deserialize_metadata_props(proto.metadata_props), - ) - - # Handle experimental value info for functions created by the dynamo exporter in IR version 9 - if model.ir_version < _FUNCTION_VALUE_INFO_SUPPORTED_VERSION: - _deserialized_experimental_value_info_for_function_ir9( - model.functions, proto.graph.value_info - ) - - return model - - -def _deserialized_experimental_value_info_for_function_ir9( - functions: Mapping[_protocols.OperatorIdentifier, _core.Function], - value_info_protos: Sequence[onnx.ValueInfoProto], -) -> None: - """Deserialize value info for functions when they are stored in an experimental format. - - The experimental format is: - {function_domain}::{function_name}/{value_name} - """ - # Parse value info for functions from the main graph - function_value_value_info_mapping: collections.defaultdict[ - _protocols.OperatorIdentifier, - dict[str, onnx.ValueInfoProto], - ] = collections.defaultdict(dict) - for value_info_proto in value_info_protos: - if ( - parsed := _parse_experimental_function_value_info_name(value_info_proto.name) - ) is None: - continue - function_domain, function_name, value_name = parsed - function_overload = "" - # TODO(justinchuby): Create a constructor for OperatorIdentifier so we don't create tuples manually - function_id = (function_domain, function_name, function_overload) - function = functions.get(function_id) - if function is None: - # Function not found - logger.debug( - "Function with ID '%s' not found in model functions. Value info '%s' will be ignored.", - function_id, - value_info_proto.name, - ) - continue - function_value_value_info_mapping[function_id][value_name] = value_info_proto - for function_id, function in functions.items(): - for input in function.inputs: - if input.name in function_value_value_info_mapping[function_id]: - deserialize_value_info_proto( - function_value_value_info_mapping[function_id][input.name], input - ) - for node in function: - for output in node.outputs: - if output.name in function_value_value_info_mapping[function_id]: - deserialize_value_info_proto( - function_value_value_info_mapping[function_id][output.name], - output, - ) - # The function outputs are handled as well because they are also node outputs - - -def deserialize_graph(proto: onnx.GraphProto) -> _core.Graph: - """Deserialize a graph proto, recursively if needed. - - Args: - proto: The graph proto to deserialize. - - Returns: - IR Graph. - - .. versionadded:: 0.3 - Support for *quantization_annotation* is added. - """ - return _deserialize_graph(proto, []) - - -@_capture_errors(lambda proto, scoped_values: proto.name) -def _deserialize_graph( - proto: onnx.GraphProto, scoped_values: list[dict[str, _core.Value]] -) -> _core.Graph: - """Deserialize a graph proto, recursively if needed. - - Args: - proto: The graph proto to deserialize. - scoped_values: A list of dictionaries mapping value names to their corresponding Value objects. - Every time we enter a new graph, a new scope is created and appended to this list to include - all values defined in the scope. - scoped_value_info: A list of dictionaries mapping value names to their corresponding ValueInfoProto. - - Returns: - IR Graph. - """ - # Process TensorAnnotation for quantization - quantization_annotations = { - annotation.tensor_name: annotation for annotation in proto.quantization_annotation - } - - # Create values for initializers and inputs - initializer_tensors = [deserialize_tensor(tensor) for tensor in proto.initializer] - inputs = [_core.Input(info.name) for info in proto.input] - for info, value in zip(proto.input, inputs): - deserialize_value_info_proto(info, value) - - # Add TensorAnnotation for inputs if they exist - if value.name in quantization_annotations: - _deserialize_quantization_annotation(quantization_annotations[value.name], value) - - # Initialize the values dictionary for this graph scope with the inputs and initializers - values: dict[str, _core.Value] = {v.name: v for v in inputs} # type: ignore[misc] - - # Enter the graph scope by pushing the values for this scope to the stack - scoped_values.append(values) - - initializer_values = [] - for i, tensor in enumerate(initializer_tensors): - initializer_name = tensor.name - if not initializer_name: - logger.warning( - "Initializer tensor must have a name but the %s-th initializer does not. Skipping this initializer.", - i, - ) - continue - if initializer_name in values: - # The initializer is for an input - initializer_value = values[initializer_name] - initializer_value.const_value = tensor - else: - # The initializer is for some other value. Create this value first - initializer_value = _core.Value( - None, - index=None, - name=initializer_name, - # Include shape and type even if the shape or type is not provided as ValueInfoProto. - # Users expect initialized values to have shape and type information. - type=_core.TensorType(tensor.dtype), - shape=tensor.shape, # type: ignore[arg-type] - const_value=tensor, - ) - if initializer_value.name in quantization_annotations: - _deserialize_quantization_annotation( - quantization_annotations[initializer_value.name], initializer_value - ) - values[initializer_name] = initializer_value - initializer_values.append(initializer_value) - - # Build the value info dictionary to allow for quick lookup for this graph scope - value_info = {info.name: info for info in proto.value_info} - - # Deserialize nodes with all known values - nodes = [ - _deserialize_node(node, scoped_values, value_info, quantization_annotations) - for node in proto.node - ] - - outputs = [] - for info in proto.output: - # Fill in values for graph outputs - output_name = info.name - if output_name not in values: - # Handle (invalid) graph outputs that do not have any producers - logger.warning( - "Output '%s' is not produced by any node. The graph has an invalid output", - output_name, - ) - value = _core.Value(name=output_name) - else: - # A valid, normal graph output - value = values[output_name] - # Fill in shape/type information - deserialize_value_info_proto(info, value) - outputs.append(value) - - # Exit the graph scope by popping the values for this scope from the stack - scoped_values.pop() - - return _core.Graph( - inputs, - outputs, - nodes=nodes, - initializers=initializer_values, - doc_string=_get_field(proto, "doc_string"), - name=_get_field(proto, "name"), - metadata_props=deserialize_metadata_props(proto.metadata_props), - ) - - -@_capture_errors(lambda proto: proto.name) -def deserialize_function(proto: onnx.FunctionProto) -> _core.Function: - inputs = [_core.Input(name) for name in proto.input] - values: dict[str, _core.Value] = {v.name: v for v in inputs} # type: ignore[misc] - value_info = {info.name: info for info in getattr(proto, "value_info", [])} - - # TODO(justinchuby): Handle unsorted nodes - nodes = [ - _deserialize_node(node, [values], value_info=value_info, quantization_annotations={}) - for node in proto.node - ] - outputs = [values[name] for name in proto.output] - graph = _core.Graph( - inputs, - outputs, - nodes=nodes, - initializers=(), - doc_string=_get_field(proto, "doc_string"), - opset_imports=deserialize_opset_import(proto.opset_import), - name=( - f"{proto.name}_{proto.domain}" + f"__{proto.overload}" - if hasattr(proto, "overload") and proto.overload - else "" - ), - metadata_props=deserialize_metadata_props(proto.metadata_props), - ) - attributes = [_deserialize_attribute(attr, []) for attr in proto.attribute_proto] - # Attributes without defaults - attributes += [ - _core.Attr(name, _enums.AttributeType.UNDEFINED, None) for name in proto.attribute - ] - return _core.Function( - domain=proto.domain, - name=proto.name, - overload=getattr(proto, "overload", ""), - graph=graph, - attributes=typing.cast(List[_core.Attr], attributes), - ) - - -@_capture_errors(lambda proto, value: str(proto)) -def deserialize_value_info_proto( - proto: onnx.ValueInfoProto, value: _core.Value | None -) -> _core.Value: - if value is None: - value = _core.Value(name=proto.name) - value.shape = deserialize_type_proto_for_shape(proto.type) - value.type = deserialize_type_proto_for_type(proto.type) - metadata_props = deserialize_metadata_props(proto.metadata_props) - if metadata_props is not None: - value.metadata_props.update(metadata_props) - value.doc_string = _get_field(proto, "doc_string") - return value - - -@_capture_errors(lambda proto, value: str(proto)) -def _deserialize_quantization_annotation( - proto: onnx.TensorAnnotation, value: _core.Value -) -> None: - """Deserialize a quantization_annotation as TensorAnnotation into a Value. - - This function is marked private because we don't expect users to call it directly. - """ - value.meta[_QUANT_PARAMETER_TENSOR_NAMES_FIELD] = _deserialize_string_string_maps( - proto.quant_parameter_tensor_names - ) - - -@_capture_errors(str) -def deserialize_tensor_shape(proto: onnx.TensorShapeProto) -> _core.Shape: - # This logic handles when the shape is [] as well - dim_protos = proto.dim - deserialized_dim_denotations = [ - deserialize_dimension(dim_proto) for dim_proto in dim_protos - ] - dims = [dim for dim, _ in deserialized_dim_denotations] - denotations = [denotation for _, denotation in deserialized_dim_denotations] - return _core.Shape(dims, denotations=denotations, frozen=True) - - -@_capture_errors(str) -def deserialize_type_proto_for_shape(proto: onnx.TypeProto) -> _core.Shape | None: - if proto.HasField("tensor_type"): - if (shape_proto := _get_field(proto.tensor_type, "shape")) is None: - return None - return deserialize_tensor_shape(shape_proto) - if proto.HasField("sparse_tensor_type"): - if (shape_proto := _get_field(proto.sparse_tensor_type, "shape")) is None: - return None - return deserialize_tensor_shape(shape_proto) - if proto.HasField("sequence_type"): - if (elem_type := _get_field(proto.sequence_type, "elem_type")) is None: - return None - return deserialize_type_proto_for_shape(elem_type) - if proto.HasField("optional_type"): - if (elem_type := _get_field(proto.optional_type, "elem_type")) is None: - return None - return deserialize_type_proto_for_shape(elem_type) - if proto.HasField("map_type"): - # TODO(justinchuby): Do we need to support map types? - raise NotImplementedError(f"Map types are not supported yet. {_PLEASE_CONTRIBUTE}") - - return None - - -@_capture_errors(str) -def deserialize_type_proto_for_type( - proto: onnx.TypeProto, -) -> _protocols.TypeProtocol | None: - denotation = _get_field(proto, "denotation") - if proto.HasField("tensor_type"): - if (elem_type := _get_field(proto.tensor_type, "elem_type")) is None: - return None - return _core.TensorType(_enums.DataType(elem_type), denotation=denotation) - if proto.HasField("sparse_tensor_type"): - if (elem_type := _get_field(proto.sparse_tensor_type, "elem_type")) is None: - return None - return _core.SparseTensorType(_enums.DataType(elem_type), denotation=denotation) - if proto.HasField("sequence_type"): - # FIXME(justinchuby): Allow nested types being None - if (elem_type := _get_field(proto.sequence_type, "elem_type")) is None: - raise ValueError(f"SequenceTypeProto must have elem_type set: {proto}") - nested_type = deserialize_type_proto_for_type(elem_type) - if nested_type is None: - raise ValueError(f"SequenceType must have elem_type set: {proto}") - return _core.SequenceType(nested_type, denotation=denotation) - if proto.HasField("optional_type"): - # FIXME(justinchuby): Allow nested types being None - if (elem_type := _get_field(proto.optional_type, "elem_type")) is None: - raise ValueError(f"SequenceTypeProto must have elem_type set: {proto}") - nested_type = deserialize_type_proto_for_type(elem_type) - if nested_type is None: - raise ValueError(f"SequenceType must have elem_type set: {proto}") - return _core.OptionalType(nested_type, denotation=denotation) - if proto.HasField("map_type"): - # TODO(justinchuby): Do we need to support map types? - raise NotImplementedError(f"Map types are not supported yet. {_PLEASE_CONTRIBUTE}") - - return None - - -@_capture_errors(str) -def deserialize_dimension( - proto: onnx.TensorShapeProto.Dimension, -) -> tuple[int | _core.SymbolicDim, str | None]: - """Deserialize a dimension proto into (dimension, denotation). - - Args: - proto: The dimension proto to deserialize. - - Returns: - A tuple of the dimension and its denotation. - """ - value_field = proto.WhichOneof("value") - denotation = _get_field(proto, "denotation") - if value_field is not None: - value = getattr(proto, value_field) - if value_field == "dim_value": - return value, denotation - if value_field == "dim_param": - return _core.SymbolicDim(value), denotation - return _core.SymbolicDim(None), denotation - - -@_capture_errors(lambda proto, base_path: proto.name) -def deserialize_tensor( - proto: onnx.TensorProto, base_path: str | os.PathLike = "" -) -> _protocols.TensorProtocol: - # TODO: Sanitize base_path - if proto.data_location == onnx.TensorProto.EXTERNAL: - external_info = onnx.external_data_helper.ExternalDataInfo(proto) - return _core.ExternalTensor( - external_info.location, - offset=external_info.offset, - length=external_info.length, - dtype=_enums.DataType(proto.data_type), - base_dir=base_path, - name=_get_field(proto, "name"), - shape=_core.Shape(proto.dims), - doc_string=_get_field(proto, "doc_string"), - metadata_props=deserialize_metadata_props(proto.metadata_props), - ) - if proto.data_type == _enums.DataType.STRING: - name = _get_field(proto, "name") - doc_string = _get_field(proto, "doc_string") - metadata_props = deserialize_metadata_props(proto.metadata_props) - return _core.StringTensor( - proto.string_data, - shape=_core.Shape(proto.dims), - name=name, - doc_string=doc_string, - metadata_props=metadata_props, - ) - return TensorProtoTensor(proto) - - -def deserialize_metadata_props( - proto: Sequence[onnx.StringStringEntryProto], -) -> dict[str, str] | None: - if len(proto) == 0: - # Avoid creating an empty dictionary to save memory - return None - return {entry.key: entry.value for entry in proto} - - -_deserialize_string_string_maps = deserialize_metadata_props - - -def deserialize_attribute(proto: onnx.AttributeProto) -> _core.Attr | _core.RefAttr: - return _deserialize_attribute(proto, []) - - -@_capture_errors(lambda proto, scoped_values: str(proto)) -def _deserialize_attribute( - proto: onnx.AttributeProto, scoped_values: list[dict[str, _core.Value]] -) -> _core.Attr | _core.RefAttr: - name = proto.name - doc_string = _get_field(proto, "doc_string") - type_ = _enums.AttributeType(proto.type) - ref_attr_name = _get_field(proto, "ref_attr_name") - if ref_attr_name: - return _core.RefAttr(name, ref_attr_name, type_, doc_string=doc_string) - - if type_ == _enums.AttributeType.INT: - return _core.AttrInt64(name, proto.i, doc_string=doc_string) - if type_ == _enums.AttributeType.FLOAT: - return _core.AttrFloat32(name, proto.f, doc_string=doc_string) - if type_ == _enums.AttributeType.STRING: - return _core.AttrString(name, proto.s.decode("utf-8"), doc_string=doc_string) - if type_ == _enums.AttributeType.INTS: - return _core.AttrInt64s(name, proto.ints, doc_string=doc_string) - if type_ == _enums.AttributeType.FLOATS: - return _core.AttrFloat32s(name, proto.floats, doc_string=doc_string) - if type_ == _enums.AttributeType.STRINGS: - return _core.AttrStrings( - name, [s.decode("utf-8") for s in proto.strings], doc_string=doc_string - ) - if type_ == _enums.AttributeType.TENSOR: - return _core.AttrTensor(name, deserialize_tensor(proto.t), doc_string=doc_string) - if type_ == _enums.AttributeType.GRAPH: - return _core.AttrGraph( - name, _deserialize_graph(proto.g, scoped_values), doc_string=doc_string - ) - if type_ == _enums.AttributeType.TENSORS: - return _core.AttrTensors( - name, - [deserialize_tensor(t) for t in proto.tensors], - doc_string=doc_string, - ) - if type_ == _enums.AttributeType.GRAPHS: - return _core.AttrGraphs( - name, - [_deserialize_graph(g, scoped_values) for g in proto.graphs], - doc_string=doc_string, - ) - if type_ == _enums.AttributeType.SPARSE_TENSOR: - raise NotImplementedError( - f"Sparse tensors are not supported yet. {_PLEASE_CONTRIBUTE}" - ) - if type_ == _enums.AttributeType.SPARSE_TENSORS: - raise NotImplementedError( - f"Sparse tensors are not supported yet. {_PLEASE_CONTRIBUTE}" - ) - if type_ == _enums.AttributeType.TYPE_PROTO: - ir_type = deserialize_type_proto_for_type(proto.tp) - shape = deserialize_type_proto_for_shape(proto.tp) - return _core.AttrTypeProto( - name, _core.TypeAndShape(ir_type, shape), doc_string=doc_string - ) - if type_ == _enums.AttributeType.TYPE_PROTOS: - type_and_shapes = [] - for type_proto in proto.type_protos: - ir_type = deserialize_type_proto_for_type(type_proto) - shape = deserialize_type_proto_for_shape(type_proto) - type_and_shapes.append(_core.TypeAndShape(ir_type, shape)) - return _core.AttrTypeProtos(name, type_and_shapes, doc_string=doc_string) - if type_ == _enums.AttributeType.UNDEFINED: - return _core.Attr(name, type_, None, doc_string=doc_string) - raise ValueError(f"Unsupported attribute type: '{type_}'") - - -def deserialize_node(proto: onnx.NodeProto) -> _core.Node: - return _deserialize_node( - proto, scoped_values=[{}], value_info={}, quantization_annotations={} - ) - - -@_capture_errors(lambda proto, scoped_values, value_info, quantization_annotations: str(proto)) -def _deserialize_node( - proto: onnx.NodeProto, - scoped_values: list[dict[str, _core.Value]], - value_info: dict[str, onnx.ValueInfoProto], - quantization_annotations: dict[str, onnx.TensorAnnotation], -) -> _core.Node: - node_inputs: list[_core.Value | None] = [] - for input_name in proto.input: - if input_name == "": - # Empty input - node_inputs.append(None) - continue - - # Find the input in all value scopes - found = False - for values in reversed(scoped_values): - if input_name not in values: - continue - node_inputs.append(values[input_name]) - found = True - del values # Remove the reference so it is not used by mistake - break - if not found: - # If the input is not found, we know the graph may be unsorted and - # the input may be a supposed-to-be initializer or an output of a node that comes later. - # Here we create the value with the name and add it to the current scope. - # Nodes need to check the value pool for potentially initialized outputs - logger.warning( - "Input '%s' of node '%s(%s::%s:%s)' not found in any scope. " - "The graph may be unsorted. Creating a new input (current depth: %s) .", - input_name, - proto.name, - proto.domain, - proto.op_type, - getattr(proto, "overload", ""), - len(scoped_values), - ) - if len(scoped_values) > 1: - logger.warning( - "Caveat: The value is created in the subgraph. If " - "the node is referencing a value that is not in the current graph, " - "it is impossible to create it in the correct scope.", - ) - value = _core.Value(name=input_name) - # Fill in shape/type information if they exist - if input_name in value_info: - deserialize_value_info_proto(value_info[input_name], value) - if input_name in quantization_annotations: - _deserialize_quantization_annotation( - quantization_annotations[input_name], value - ) - node_inputs.append(value) - # We can only create the value in the current scope. If the subgraph is - # referencing a value that is not in the current scope, it is impossible - # to create it in the correct scope. - scoped_values[-1][input_name] = value - - # Build the output values for the node. - node_outputs: list[_core.Value] = [] - for output_name in proto.output: - if output_name == "": - # Empty output - node_outputs.append(_core.Value(name="")) - continue - - # 1. When the graph is unsorted, we may be able to find the output already created - # as an input to some other nodes in the current scope. - # Note that a value is always owned by the producing node. Even though a value - # can be created when parsing inputs of other nodes, the new node created here - # that produces the value will assume ownership. It is then impossible to transfer - # the ownership to any other node. - - # The output can only be found in the current scope. It is impossible for - # a node to produce an output that is not in its own scope. - current_scope = scoped_values[-1] - if output_name in current_scope: - value = current_scope[output_name] - else: - # 2. Common scenario: the graph is sorted and this is the first time we see the output. - # Create the value and add it to the current scope. - value = _core.Value(name=output_name) - current_scope[output_name] = value - # Fill in shape/type information if they exist - if output_name in value_info: - deserialize_value_info_proto(value_info[output_name], value) - else: - logger.debug( - "ValueInfoProto not found for output '%s' in node '%s' of type '%s'", - output_name, - proto.name, - proto.op_type, - ) - if output_name in quantization_annotations: - _deserialize_quantization_annotation(quantization_annotations[output_name], value) - node_outputs.append(value) - return _core.Node( - proto.domain, - proto.op_type, - node_inputs, - [_deserialize_attribute(a, scoped_values) for a in proto.attribute], - overload=getattr(proto, "overload", ""), - outputs=node_outputs, - name=proto.name, - doc_string=_get_field(proto, "doc_string"), - metadata_props=deserialize_metadata_props(proto.metadata_props), - ) - - -# Serialization - - -def serialize_model(model: _protocols.ModelProtocol) -> onnx.ModelProto: - return serialize_model_into(onnx.ModelProto(), from_=model) - - -@_capture_errors( - lambda model_proto, from_: ( - f"ir_version={from_.ir_version}, producer_name={from_.producer_name}, " - f"producer_version={from_.producer_version}, domain={from_.domain}, " - ) -) -def serialize_model_into( - model_proto: onnx.ModelProto, from_: _protocols.ModelProtocol -) -> onnx.ModelProto: - """Serialize an IR model to an ONNX model proto.""" - model_proto.ir_version = from_.ir_version - if from_.producer_name: - model_proto.producer_name = from_.producer_name - if from_.producer_version: - model_proto.producer_version = from_.producer_version - if from_.domain: - model_proto.domain = from_.domain - if from_.model_version: - model_proto.model_version = from_.model_version - if from_.doc_string: - model_proto.doc_string = from_.doc_string - # Sort names for deterministic serialization - _serialize_opset_imports_into(model_proto.opset_import, from_.opset_imports) - if from_.metadata_props: - _serialize_metadata_props_into(model_proto.metadata_props, from_.metadata_props) - serialize_graph_into(model_proto.graph, from_.graph) - - create_value_info_in_functions = from_.ir_version >= _FUNCTION_VALUE_INFO_SUPPORTED_VERSION - for func in from_.functions.values(): - serialize_function_into( - model_proto.functions.add(), - from_=func, - create_value_info=create_value_info_in_functions, - ) - if not create_value_info_in_functions: - # Create them in the main graph instead - _serialize_experimental_value_info_for_function_ir9_into(model_proto.graph, func) - return model_proto - - -def _should_create_value_info_for_value(value: _protocols.ValueProtocol) -> bool: - """Check if value info should be created for a value. - - Args: - value: The value to check. - - Returns: - True if value info should be created for the value. - """ - # No need to serialize value info if it is not set - if value.shape is None and value.type is None: - return False - if not value.name: - logger.debug("Did not serialize '%s' because its name is empty", value) - return False - return True - - -def _serialize_experimental_value_info_for_function_ir9_into( - graph_proto: onnx.GraphProto, function: _protocols.FunctionProtocol -) -> None: - """Serialize value info for functions in an experimental format for IR version 9. - - Because IRv9 and older does not have ValueInfoProto for functions, we give the value info - special names and store them in the main graph instead. - - The experimental format is: - {function_domain}::{function_name}/{value_name} - - Args: - graph_proto: The graph proto to create ValueInfoProto in. - function: The function to serialize. - """ - # TODO(justinchuby): In the future, we can decide if it is a good idea to simply iterate over - # all values in the function and call serialize_value_into instead. - function_qualified_name = f"{function.domain}::{function.name}" - - def format_name(value_name: str) -> str: - return f"{function_qualified_name}/{value_name}" - - for input in function.inputs: - if not input.name: - logger.warning( - "Function '%s': Value name not set for function input: %s", - function_qualified_name, - input, - ) - continue - if not _should_create_value_info_for_value(input): - # No need to serialize value info if it is not set - continue - serialize_value_into(graph_proto.value_info.add(), input, name=format_name(input.name)) - for node in function: - for node_output in node.outputs: - if not node_output.name: - logger.warning( - "Function '%s': Value name not set for node output: %s", - function_qualified_name, - node_output, - ) - continue - if not _should_create_value_info_for_value(node_output): - # No need to serialize value info if it is not set - continue - serialize_value_into( - graph_proto.value_info.add(), - node_output, - name=format_name(node_output.name), - ) - - -def _serialize_opset_imports_into( - opset_ids: proto_containers.RepeatedCompositeFieldContainer[onnx.OperatorSetIdProto], - from_: Mapping[str, int], -) -> None: - """Serialize opset imports into a repeated field of OperatorSetId protos. - - Args: - opset_ids: The repeated field to serialize into. - from_: The mapping of opset domains to versions to serialize. - """ - # Sort names for deterministic serialization - for domain, version in from_.items(): - opset_ids.add(domain=domain, version=version) - - -def _serialize_string_string_maps( - string_string_entries: proto_containers.RepeatedCompositeFieldContainer[ - onnx.StringStringEntryProto - ], - from_: Mapping[str, str], -) -> None: - """Serialize a mapping into a repeated field of string-string entries. - - Args: - string_string_entries: The repeated field to serialize into. - from_: The mapping of a mapping to serialize. - """ - # Sort names for deterministic serialization - for key in sorted(from_): - string_string_entries.add(key=key, value=from_[key]) - - -_serialize_metadata_props_into = _serialize_string_string_maps - - -def _maybe_add_quantization_annotation( - graph_proto: onnx.GraphProto, value: _protocols.ValueProtocol -) -> None: - if quantization_annotation := value.meta.get(_QUANT_PARAMETER_TENSOR_NAMES_FIELD): - _serialize_tensor_annotation_into( - graph_proto.quantization_annotation.add(), value.name, quantization_annotation - ) - - -def _serialize_tensor_annotation_into( - tensor_annotation_proto: onnx.TensorAnnotation, - tensor_name: str, - quant_parameter_tensor_names: dict[str, str], -) -> None: - tensor_annotation_proto.tensor_name = tensor_name - _serialize_string_string_maps( - tensor_annotation_proto.quant_parameter_tensor_names, quant_parameter_tensor_names - ) - - -def serialize_graph( - graph: _protocols.GraphProtocol | _protocols.GraphViewProtocol, -) -> onnx.GraphProto: - """Serializes the given graph into an :class:`onnx.GraphProto`. - - When the graph initializers do not have `const_value` set, they will be skipped. - - Args: - graph: The graph to be serialized. - - Returns: - The serialized ONNX GraphProto object. - """ - graph_proto = onnx.GraphProto() - serialize_graph_into(graph_proto, from_=graph) - return graph_proto - - -@_capture_errors( - lambda graph_proto, from_: ( - f"name={from_.name}, doc_string={from_.doc_string}, " - f"len(inputs)={len(from_.inputs)}, len(initializers)={len(from_.initializers)}, " - f"len(nodes)={len(from_)}, len(outputs)={len(from_.outputs)}, metadata_props={from_.metadata_props}" - ) -) -def serialize_graph_into( - graph_proto: onnx.GraphProto, - from_: _protocols.GraphProtocol | _protocols.GraphViewProtocol, -) -> None: - if from_.name: - graph_proto.name = from_.name - if from_.doc_string: - graph_proto.doc_string = from_.doc_string - for input_ in from_.inputs: - serialize_value_into(graph_proto.input.add(), input_) - if input_.name not in from_.initializers: - # Annotations for initializers will be added below to avoid double adding - # TODO(justinchuby): We should add a method is_initializer() on Value when - # the initializer list is tracked - _maybe_add_quantization_annotation(graph_proto, input_) - input_names = {input_.name for input_ in from_.inputs} - # TODO(justinchuby): Support sparse_initializer - for value in from_.initializers.values(): - _maybe_add_quantization_annotation(graph_proto, value) - if _should_create_value_info_for_value(value) and value.name not in input_names: - # Serialize information about all initializers into value_info, - # except for those that are also graph inputs - serialize_value_into(graph_proto.value_info.add(), value) - if value.const_value is None: - # Skip initializers without constant values - logger.warning("Initializer '%s' does not have a constant value set.", value.name) - continue - # Make sure the tensor's name is the same as the value's name - value.const_value.name = value.name - serialize_tensor_into(graph_proto.initializer.add(), from_=value.const_value) - for node in from_: - serialize_node_into(graph_proto.node.add(), from_=node) - for node_output in node.outputs: - if node_output.is_graph_output(): - # No need to serialize info for these outputs because they are handled as graph outputs - continue - _maybe_add_quantization_annotation(graph_proto, node_output) - if not _should_create_value_info_for_value(node_output): # pylint: disable=no-else-continue - # No need to serialize value info if it is not set - continue - else: - serialize_value_into(graph_proto.value_info.add(), node_output) - for output in from_.outputs: - serialize_value_into(graph_proto.output.add(), from_=output) - _maybe_add_quantization_annotation(graph_proto, output) - if from_.metadata_props: - _serialize_metadata_props_into(graph_proto.metadata_props, from_.metadata_props) - - -def serialize_function( - function: _protocols.FunctionProtocol, *, create_value_info: bool = True -) -> onnx.FunctionProto: - """Serialize an IR function as a FunctionProto. - - Args: - function: The function to serialize. - create_value_info: Whether to create ValueInfoProto for nodes in the function. This is supported - starting from ONNX IR version 10. - """ - function_proto = onnx.FunctionProto() - serialize_function_into( - function_proto, from_=function, create_value_info=create_value_info - ) - return function_proto - - -@_capture_errors(lambda function_proto, from_, create_value_info: repr(from_)) -def serialize_function_into( - function_proto: onnx.FunctionProto, - from_: _protocols.FunctionProtocol, - *, - create_value_info: bool = True, -) -> None: - """Serialize an IR function into a FunctionProto. - - Args: - function_proto: The proto to serialize into. - from_: The function to serialize. - create_value_info: Whether to create ValueInfoProto for nodes in the function. This is supported - starting from ONNX IR version 10. - """ - if from_.domain: - function_proto.domain = from_.domain - if from_.name: - function_proto.name = from_.name - if from_.overload: - function_proto.overload = from_.overload - if from_.doc_string: - function_proto.doc_string = from_.doc_string - if from_.opset_imports: - # A valid ONNX graph should have at least one opset import, that is - # the default ONNX opset. - # Here we check for emptiness before serializing to keep the logic consistent - _serialize_opset_imports_into(function_proto.opset_import, from_.opset_imports) - if from_.metadata_props: - _serialize_metadata_props_into(function_proto.metadata_props, from_.metadata_props) - for input_ in from_.inputs: - function_proto.input.append(input_.name) - if not _should_create_value_info_for_value(input_): - # No need to serialize value info if it is not set - continue - if not create_value_info: - continue - serialize_value_into(function_proto.value_info.add(), input_) - for attr in from_.attributes.values(): - if attr.value is not None: - serialize_attribute_into(function_proto.attribute_proto.add(), from_=attr) - else: - # ONNX does not record type information if the attribute does not have a default - function_proto.attribute.append(attr.name) - for func_output in from_.outputs: - function_proto.output.append(func_output.name) - # No need to serialize value info for function outputs because they are - # also node outputs - for node in from_: - serialize_node_into(function_proto.node.add(), from_=node) - # Record value info for outputs - for node_output in node.outputs: - if not _should_create_value_info_for_value(node_output): - # No need to serialize value info if it is not set - continue - if not create_value_info: - continue - serialize_value_into(function_proto.value_info.add(), node_output) - - -def serialize_node(node: _protocols.NodeProtocol) -> onnx.NodeProto: - node_proto = onnx.NodeProto() - serialize_node_into(node_proto, from_=node) - return node_proto - - -def _remove_trailing_outputs( - outputs: Sequence[_protocols.ValueProtocol], -) -> Sequence[_protocols.ValueProtocol]: - """Remove trailing outputs that have empty names. - - Args: - outputs: The outputs to remove trailing outputs from. - - Returns: - The outputs with trailing outputs removed. - """ - for i, output in enumerate(reversed(outputs)): - if output.name: - return outputs[: len(outputs) - i] - return [] - - -@_capture_errors(lambda node_proto, from_: repr(from_)) -def serialize_node_into(node_proto: onnx.NodeProto, from_: _protocols.NodeProtocol) -> None: - node_proto.op_type = from_.op_type - if from_.domain: - # If the domain is "", we can assume the default domain and not set it - node_proto.domain = from_.domain - if from_.name: - node_proto.name = from_.name - if from_.overload: - node_proto.overload = from_.overload - if from_.doc_string: - node_proto.doc_string = from_.doc_string - if from_.metadata_props: - _serialize_metadata_props_into(node_proto.metadata_props, from_.metadata_props) - for input_ in from_.inputs: - if input_ is None: - node_proto.input.append("") - else: - node_proto.input.append(input_.name) - - # Do not include the trailing outputs that have empty names - for output in _remove_trailing_outputs(from_.outputs): - node_proto.output.append(output.name) - - for attr in from_.attributes.values(): - if isinstance(attr, _core.Attr): - serialize_attribute_into(node_proto.attribute.add(), from_=attr) - elif isinstance(attr, _core.RefAttr): - serialize_reference_attribute_into(node_proto.attribute.add(), from_=attr) - # Handle protocol attributes for completeness. We do not check them first because - # calling isinstance on a protocol can be slow. - # Most of the time, we will have Attr or RefAttr so the two branches below - # will not be taken. - elif isinstance(attr, _protocols.AttributeProtocol): - serialize_attribute_into(node_proto.attribute.add(), from_=attr) - elif isinstance(attr, _protocols.ReferenceAttributeProtocol): - serialize_reference_attribute_into(node_proto.attribute.add(), from_=attr) - else: - raise TypeError(f"Unsupported attribute type: {type(attr)}") - - -def serialize_tensor(tensor: _protocols.TensorProtocol) -> onnx.TensorProto: - tensor_proto = onnx.TensorProto() - serialize_tensor_into(tensor_proto, from_=tensor) - return tensor_proto - - -@_capture_errors(lambda tensor_proto, from_: repr(from_)) -def serialize_tensor_into( - tensor_proto: onnx.TensorProto, from_: _protocols.TensorProtocol -) -> None: - if isinstance(from_, TensorProtoTensor): - # Directly copy from the tensor proto if it is available - tensor_proto.CopyFrom(from_.raw) - if from_.metadata_props: - _serialize_metadata_props_into(tensor_proto.metadata_props, from_.metadata_props) - return - - if from_.name: - tensor_proto.name = from_.name - if from_.doc_string: - tensor_proto.doc_string = from_.doc_string - tensor_proto.data_type = from_.dtype.value - tensor_proto.dims.extend(from_.shape.numpy()) - if isinstance(from_, _core.ExternalTensor): - # Store external tensors as is - tensor_proto.data_location = onnx.TensorProto.EXTERNAL - for k, v in { - "location": os.fspath(from_.location), - "offset": from_.offset, - "length": from_.length, - }.items(): - if v is not None: - entry = tensor_proto.external_data.add() - entry.key = k - entry.value = str(v) - elif isinstance(from_, _core.StringTensor): - tensor_proto.string_data.extend(from_.string_data()) - else: - tensor_proto.raw_data = from_.tobytes() - _serialize_metadata_props_into(tensor_proto.metadata_props, from_.metadata_props) - - -def serialize_attribute(attribute: _protocols.AttributeProtocol) -> onnx.AttributeProto: - attribute_proto = onnx.AttributeProto() - serialize_attribute_into(attribute_proto, from_=attribute) - return attribute_proto - - -@_capture_errors(lambda attribute_proto, from_: repr(from_)) -def serialize_attribute_into( - attribute_proto: onnx.AttributeProto, from_: _protocols.AttributeProtocol -) -> None: - attribute_proto.name = from_.name - if from_.doc_string: - attribute_proto.doc_string = from_.doc_string - _fill_in_value_for_attribute(attribute_proto, from_.type, from_.value) - - -def _fill_in_value_for_attribute( - attribute_proto: onnx.AttributeProto, type_: _enums.AttributeType, value: Any -) -> None: - if type_ == _enums.AttributeType.INT: - # value: int - attribute_proto.i = value - attribute_proto.type = onnx.AttributeProto.INT - elif type_ == _enums.AttributeType.FLOAT: - # value: float - attribute_proto.f = value - attribute_proto.type = onnx.AttributeProto.FLOAT - elif type_ == _enums.AttributeType.STRING: - # value: str - attribute_proto.s = value.encode("utf-8") - attribute_proto.type = onnx.AttributeProto.STRING - elif type_ == _enums.AttributeType.INTS: - # value: Sequence[int] - attribute_proto.ints.extend(value) - attribute_proto.type = onnx.AttributeProto.INTS - elif type_ == _enums.AttributeType.FLOATS: - # value: Sequence[float] - attribute_proto.floats.extend(value) - attribute_proto.type = onnx.AttributeProto.FLOATS - elif type_ == _enums.AttributeType.STRINGS: - # value: Sequence[str] - attribute_proto.strings.extend([s.encode("utf-8") for s in value]) - attribute_proto.type = onnx.AttributeProto.STRINGS - elif type_ == _enums.AttributeType.TENSOR: - # value: _protocols.TensorProtocol - serialize_tensor_into(attribute_proto.t, value) - attribute_proto.type = onnx.AttributeProto.TENSOR - elif type_ == _enums.AttributeType.GRAPH: - # value: _protocols.GraphProtocol - serialize_graph_into(attribute_proto.g, value) - attribute_proto.type = onnx.AttributeProto.GRAPH - elif type_ == _enums.AttributeType.TENSORS: - # value: Sequence[_protocols.TensorProtocol] - for tensor in value: - serialize_tensor_into(attribute_proto.tensors.add(), tensor) - attribute_proto.type = onnx.AttributeProto.TENSORS - elif type_ == _enums.AttributeType.GRAPHS: - # value: Sequence[_protocols.GraphProtocol] - for graph in value: - serialize_graph_into(attribute_proto.graphs.add(), graph) - attribute_proto.type = onnx.AttributeProto.GRAPHS - elif type_ == _enums.AttributeType.SPARSE_TENSOR: - raise NotImplementedError( - f"Sparse tensors are not supported yet. {_PLEASE_CONTRIBUTE}" - ) - elif type_ == _enums.AttributeType.SPARSE_TENSORS: - raise NotImplementedError( - f"Sparse tensors are not supported yet. {_PLEASE_CONTRIBUTE}" - ) - elif type_ == _enums.AttributeType.TYPE_PROTO: - # value: _core.TypeAndShape - if value.type is not None: - serialize_type_into(attribute_proto.tp, value.type) - # Need to create the type _before_ writing the shape - if value.shape is not None: - serialize_shape_into(attribute_proto.tp, value.shape) - attribute_proto.type = onnx.AttributeProto.TYPE_PROTO - elif type_ == _enums.AttributeType.TYPE_PROTOS: - for ir_type in value: - # ir_type: _core.TypeAndShape - type_proto = attribute_proto.type_protos.add() - if ir_type.type is not None: - serialize_type_into(type_proto, ir_type.type) - # Need to create the type _before_ writing the shape so that the shape can be written to the leaf type proto - if ir_type.shape is not None: - serialize_shape_into(type_proto, ir_type.shape) - attribute_proto.type = onnx.AttributeProto.TYPE_PROTOS - else: - raise TypeError(f"Unsupported attribute type: {type_}") - - -@_capture_errors(lambda attribute_proto, from_: repr(from_)) -def serialize_reference_attribute_into( - attribute_proto: onnx.AttributeProto, from_: _protocols.ReferenceAttributeProtocol -) -> None: - attribute_proto.name = from_.name - attribute_proto.ref_attr_name = from_.ref_attr_name - if from_.doc_string: - attribute_proto.doc_string = from_.doc_string - attribute_proto.type = typing.cast(onnx.AttributeProto.AttributeType, from_.type.value) - - -def serialize_value(value: _protocols.ValueProtocol, *, name: str = "") -> onnx.ValueInfoProto: - """Serialize a value into a ValueInfoProto. - - Args: - value: The proto to serialize into. - from_: The value to serialize. - name: A custom name to set for the value info. If not provided, the name from the value will be used. - """ - value_info_proto = onnx.ValueInfoProto() - serialize_value_into(value_info_proto, value, name=name) - return value_info_proto - - -@_capture_errors(lambda value_info_proto, from_: repr(from_)) -def serialize_value_into( - value_info_proto: onnx.ValueInfoProto, - from_: _protocols.ValueProtocol, - *, - name: str = "", -) -> None: - """Serialize a value into a ValueInfoProto. - - Args: - value_info_proto: The proto to serialize into. - from_: The value to serialize. - name: A custom name to set for the value info. If not provided, the name from the value will be used. - """ - if name: - value_info_proto.name = name - else: - value_info_proto.name = from_.name - if from_.metadata_props: - _serialize_metadata_props_into(value_info_proto.metadata_props, from_.metadata_props) - if from_.type is not None: - serialize_type_into(value_info_proto.type, from_.type) - # Need to create the type _before_ writing the shape so that the shape can be written to the leaf type proto - if from_.shape is not None: - serialize_shape_into(value_info_proto.type, from_.shape) - if from_.doc_string: - value_info_proto.doc_string = from_.doc_string - - -@_capture_errors(lambda type_proto, from_: repr(from_)) -def serialize_type_into(type_proto: onnx.TypeProto, from_: _protocols.TypeProtocol) -> None: - if from_.denotation: - type_proto.denotation = from_.denotation - if isinstance(from_, _core.TensorType): - tensor_type_proto = type_proto.tensor_type - tensor_type_proto.elem_type = from_.dtype.value - elif isinstance(from_, _core.SparseTensorType): - sparse_tensor_type_proto = type_proto.sparse_tensor_type - sparse_tensor_type_proto.elem_type = from_.dtype.value - elif isinstance(from_, _core.SequenceType): - sequence_type_proto = type_proto.sequence_type - serialize_type_into(sequence_type_proto.elem_type, from_.elem_type) - elif isinstance(from_, _core.OptionalType): - optional_type_proto = type_proto.optional_type - serialize_type_into(optional_type_proto.elem_type, from_.elem_type) - else: - raise TypeError(f"Unsupported type: {from_}") - - -def serialize_type(type_protocol: _protocols.TypeProtocol) -> onnx.TypeProto: - type_proto = onnx.TypeProto() - serialize_type_into(type_proto, from_=type_protocol) - return type_proto - - -@_capture_errors(lambda type_proto, from_: repr(from_)) -def serialize_shape_into(type_proto: onnx.TypeProto, from_: _protocols.ShapeProtocol) -> None: - value_field = type_proto.WhichOneof("value") - tensor_type = getattr(type_proto, value_field) - while not isinstance(tensor_type.elem_type, int): - # Find the leaf type that has the shape field - type_proto = tensor_type.elem_type - value_field = type_proto.WhichOneof("value") - tensor_type = getattr(type_proto, value_field) - # When from is empty, we still need to set the shape field to an empty list by touching it - tensor_type.shape.ClearField("dim") - for i, dim in enumerate(from_): - denotation = from_.get_denotation(i) - serialize_dimension_into(tensor_type.shape.dim.add(), dim, denotation) - - -@_capture_errors(lambda dim_proto, dim, denotation: repr(dim_proto)) -def serialize_dimension_into( - dim_proto: onnx.TensorShapeProto.Dimension, - dim: int | _protocols.SymbolicDimProtocol, - denotation: str | None = None, -) -> None: - if denotation: - dim_proto.denotation = denotation - if isinstance(dim, int): - dim_proto.dim_value = dim - elif isinstance(dim, (_core.SymbolicDim, _protocols.SymbolicDimProtocol)): - if dim.value is not None: - # TODO(justinchuby): None is probably not a valid value for dim_param - dim_proto.dim_param = str(dim.value) diff --git a/onnxscript/ir/serde_test.py b/onnxscript/ir/serde_test.py deleted file mode 100644 index 303f02761f..0000000000 --- a/onnxscript/ir/serde_test.py +++ /dev/null @@ -1,417 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -import unittest - -import google.protobuf.text_format -import ml_dtypes -import numpy as np -import onnx -import parameterized - -from onnxscript import ir -from onnxscript._internal import version_utils -from onnxscript.ir import serde - - -class ConvenienceFunctionsTest(unittest.TestCase): - @parameterized.parameterized.expand( - [ - ("model", onnx.ModelProto()), - ("graph", onnx.GraphProto()), - ("node", onnx.NodeProto(input=["X"], output=["Y"])), - ( - "tensor", - onnx.helper.make_tensor("test_tensor", onnx.TensorProto.FLOAT, [1], [1.0]), - ), - ("value_info", onnx.ValueInfoProto()), - ("type", onnx.TypeProto()), - ("attribute", onnx.AttributeProto()), - ] - ) - def test_from_proto(self, _: str, proto): - serde.from_proto(proto) - - @parameterized.parameterized.expand( - [ - ("model", ir.Model(ir.Graph([], [], nodes=[]), ir_version=1)), - ("graph", ir.Graph([], [], nodes=[])), - ( - "node", - ir.Node("", "Op", inputs=[], outputs=[ir.Value(name="value")]), - ), - ( - "tensor", - serde.TensorProtoTensor( - onnx.helper.make_tensor("test_tensor", onnx.TensorProto.FLOAT, [1], [1.0]) - ), - ), - ("value", ir.Value(name="value")), - ("type", ir.SequenceType(ir.OptionalType(ir.TensorType(ir.DataType.COMPLEX128)))), - ("attribute", ir.Attr("attribute", ir.AttributeType.FLOAT, 1)), - ("ref_attribute", ir.RefAttr("ref_attr", "attr", ir.AttributeType.FLOAT)), - ("graph_view", ir.GraphView([], [], nodes=[])), - ] - ) - def test_to_proto(self, _: str, ir_object): - serde.to_proto(ir_object) - - -class TensorProtoTensorTest(unittest.TestCase): - @parameterized.parameterized.expand( - [ - ("FLOAT", onnx.TensorProto.FLOAT), - ("BOOL", onnx.TensorProto.BOOL), - ("FLOAT16", onnx.TensorProto.FLOAT16), - ("DOUBLE", onnx.TensorProto.DOUBLE), - ] - ) - def test_tensor_proto_tensor(self, _: str, dtype: int): - tensor_proto = onnx.helper.make_tensor( - "test_tensor", dtype, [1, 9], [-3.0, -1.0, -0.5, -0.0, +0.0, 0.5, 1.0, 42.0, 2.0] - ) - tensor = serde.TensorProtoTensor(tensor_proto) - expected_array = onnx.numpy_helper.to_array(tensor_proto) - np.testing.assert_array_equal(tensor.numpy(), expected_array) - raw_data = tensor.tobytes() - tensor_proto_from_raw_data = onnx.TensorProto( - dims=tensor_proto.dims, - data_type=tensor_proto.data_type, - raw_data=raw_data, - ) - array_from_raw_data = onnx.numpy_helper.to_array(tensor_proto_from_raw_data) - np.testing.assert_array_equal(array_from_raw_data, expected_array) - # Test dlpack - if dtype == onnx.TensorProto.BOOL and version_utils.numpy_older_than("1.25"): - self.skipTest("numpy<1.25 does not support bool dtype in from_dlpack") - np.testing.assert_array_equal(np.from_dlpack(tensor), tensor.numpy()) - - @unittest.skipIf( - version_utils.onnx_older_than("1.17"), - "numpy_helper.to_array was not correctly implemented in onnx<1.17", - ) - def test_tensor_proto_tensor_bfloat16(self): - expected_array = np.array( - [[-3.0, -1.0, -0.5, -0.0, +0.0, 0.5, 1.0, 42.0, 2.0]], dtype=ml_dtypes.bfloat16 - ) - tensor_proto = onnx.helper.make_tensor( - "test_tensor", - onnx.TensorProto.BFLOAT16, - [1, 9], - np.array([[-3.0, -1.0, -0.5, -0.0, +0.0, 0.5, 1.0, 42.0, 2.0]]), - ) - tensor = serde.TensorProtoTensor(tensor_proto) - np.testing.assert_array_equal(tensor.numpy(), expected_array) - raw_data = tensor.tobytes() - tensor_proto_from_raw_data = onnx.TensorProto( - dims=tensor_proto.dims, - data_type=tensor_proto.data_type, - raw_data=raw_data, - ) - array_from_raw_data = onnx.numpy_helper.to_array(tensor_proto_from_raw_data) - np.testing.assert_array_equal( - array_from_raw_data.view(ml_dtypes.bfloat16), expected_array - ) - # Test dlpack - with self.assertRaises(BufferError): - # NumPy does not support bfloat16 in from_dlpack - np.testing.assert_array_equal(np.from_dlpack(tensor), tensor.numpy()) - - @parameterized.parameterized.expand( - [ - ( - "FLOAT8E4M3FN", - onnx.TensorProto.FLOAT8E4M3FN, - ml_dtypes.float8_e4m3fn, - ), - ( - "FLOAT8E4M3FNUZ", - onnx.TensorProto.FLOAT8E4M3FNUZ, - ml_dtypes.float8_e4m3fnuz, - ), - ( - "FLOAT8E5M2", - onnx.TensorProto.FLOAT8E5M2, - ml_dtypes.float8_e5m2, - ), - ( - "FLOAT8E5M2FNUZ", - onnx.TensorProto.FLOAT8E5M2FNUZ, - ml_dtypes.float8_e5m2fnuz, - ), - ] - ) - def test_tensor_proto_tensor_float8(self, _: str, dtype: int, np_dtype): - expected_array = np.array([[-3.0, -1.0, -0.5, -0.0, +0.0, 0.5, 1.0, 40.0, 2.0]]) - tensor_proto = onnx.helper.make_tensor("test_tensor", dtype, [1, 9], expected_array) - tensor = serde.TensorProtoTensor(tensor_proto) - np.testing.assert_array_equal( - tensor.numpy().view(np_dtype).astype(np.float32), expected_array - ) - raw_data = tensor.tobytes() - tensor_proto_from_raw_data = onnx.TensorProto( - dims=tensor_proto.dims, - data_type=tensor_proto.data_type, - raw_data=raw_data, - ) - array_from_raw_data = ( - serde.TensorProtoTensor(tensor_proto_from_raw_data) - .numpy() - .view(np_dtype) - .astype(np.float32) - ) - np.testing.assert_array_equal(array_from_raw_data, expected_array) - # Test dlpack - with self.assertRaises(BufferError): - # DL Pack does not support float8 - np.testing.assert_array_equal(np.from_dlpack(tensor), tensor.numpy()) - - @parameterized.parameterized.expand( - [ - ("INT8", onnx.TensorProto.INT8), - ("INT16", onnx.TensorProto.INT16), - ("INT32", onnx.TensorProto.INT32), - ("INT64", onnx.TensorProto.INT64), - ("INT4", onnx.TensorProto.INT4), - ] - ) - def test_tensor_proto_tensor_int(self, _: str, dtype: int): - tensor_proto = onnx.helper.make_tensor("test_tensor", dtype, [1, 4], [-1, 0, 1, 8]) - tensor = serde.TensorProtoTensor(tensor_proto) - expected_array = onnx.numpy_helper.to_array( - tensor_proto - ) # [-1, 0, 1, 7], 8 is clamped to 7 - np.testing.assert_array_equal(tensor.numpy(), expected_array) - raw_data = tensor.tobytes() - tensor_proto_from_raw_data = onnx.TensorProto( - dims=tensor_proto.dims, - data_type=tensor_proto.data_type, - raw_data=raw_data, - ) - array_from_raw_data = onnx.numpy_helper.to_array(tensor_proto_from_raw_data) - np.testing.assert_array_equal(array_from_raw_data, expected_array) - # Test dlpack - if dtype == onnx.TensorProto.INT4: - return # DL Pack does not support int4 - np.testing.assert_array_equal(np.from_dlpack(tensor), tensor.numpy()) - - @parameterized.parameterized.expand( - [ - ("UINT8", onnx.TensorProto.UINT8), - ("UINT16", onnx.TensorProto.UINT16), - ("UINT32", onnx.TensorProto.UINT32), - ("UINT64", onnx.TensorProto.UINT64), - ("UINT4", onnx.TensorProto.UINT4), - ] - ) - def test_tensor_proto_tensor_uint(self, _: str, dtype: int): - tensor_proto = onnx.helper.make_tensor("test_tensor", dtype, [1, 3], [0, 1, 8]) - tensor = serde.TensorProtoTensor(tensor_proto) - expected_array = onnx.numpy_helper.to_array(tensor_proto) - np.testing.assert_array_equal(tensor.numpy(), expected_array) - raw_data = tensor.tobytes() - tensor_proto_from_raw_data = onnx.TensorProto( - dims=tensor_proto.dims, - data_type=tensor_proto.data_type, - raw_data=raw_data, - ) - array_from_raw_data = onnx.numpy_helper.to_array(tensor_proto_from_raw_data) - np.testing.assert_array_equal(array_from_raw_data, expected_array) - # Test dlpack - if dtype == onnx.TensorProto.UINT4: - return # DL Pack does not support uint4 - np.testing.assert_array_equal(np.from_dlpack(tensor), tensor.numpy()) - - @parameterized.parameterized.expand( - [ - ("COMPLEX64", onnx.TensorProto.COMPLEX64, np.complex64), - ("COMPLEX128", onnx.TensorProto.COMPLEX128, np.complex128), - ] - ) - def test_tensor_proto_tensor_complex(self, _: str, dtype: int, np_dtype: np.dtype): - expected_array = np.array([[0.0 + 1j, 0.2 - 1j, 0.3]], dtype=np_dtype) - tensor_proto = onnx.helper.make_tensor( - "test_tensor", dtype, [1, 3], [0.0 + 1j, 0.2 - 1j, 0.3] - ) - tensor = serde.TensorProtoTensor(tensor_proto) - np.testing.assert_array_equal(tensor.numpy(), expected_array) - raw_data = tensor.tobytes() - tensor_proto_from_raw_data = onnx.TensorProto( - dims=tensor_proto.dims, - data_type=tensor_proto.data_type, - raw_data=raw_data, - ) - array_from_raw_data = onnx.numpy_helper.to_array(tensor_proto_from_raw_data) - np.testing.assert_array_equal(array_from_raw_data, expected_array) - # Test dlpack - np.testing.assert_array_equal(np.from_dlpack(tensor), tensor.numpy()) - - def test_tensor_proto_tensor_empty_tensor(self): - tensor_proto = onnx.helper.make_tensor("test_tensor", onnx.TensorProto.FLOAT, [0], []) - tensor = serde.TensorProtoTensor(tensor_proto) - expected_array = onnx.numpy_helper.to_array(tensor_proto) - np.testing.assert_array_equal(tensor.numpy(), expected_array) - raw_data = tensor.tobytes() - tensor_proto_from_raw_data = onnx.TensorProto( - dims=tensor_proto.dims, - data_type=tensor_proto.data_type, - raw_data=raw_data, - ) - array_from_raw_data = onnx.numpy_helper.to_array(tensor_proto_from_raw_data) - np.testing.assert_array_equal(array_from_raw_data, expected_array) - # Test dlpack - np.testing.assert_array_equal(np.from_dlpack(tensor), tensor.numpy()) - - -class DeserializeGraphTest(unittest.TestCase): - def test_deserialize_graph_handles_unsorted_graph(self): - node_0 = ir.Node( - "", - "Op_0", - inputs=[ir.Input("input_0"), ir.Input("input_1")], - num_outputs=2, - name="node_0", - ) - node_1 = ir.Node( - "", - "Op_1", - inputs=[node_0.outputs[0]], - num_outputs=1, - name="node_1", - ) - graph = ir.Graph( - inputs=node_0.inputs, # type: ignore - outputs=[node_1.outputs[0]], - # Unsorted nodes - nodes=[node_1, node_0], - name="test_graph", - ) - graph_proto = serde.serialize_graph(graph) - deserialized_graph = serde.deserialize_graph(graph_proto) - self.assertEqual(deserialized_graph[0].op_type, "Op_1") - self.assertEqual(deserialized_graph[1].op_type, "Op_0") - - def test_deserialize_graph_handles_invalid_output(self): - # The graph has an output that is not connected to any node, and it does not - # have shape/type information. - graph_with_invalid_output = ir.Graph( - inputs=[], - outputs=[ir.Value(name="invalid_output")], - nodes=[], - name="graph_with_invalid_output", - ) - graph_proto = serde.serialize_graph(graph_with_invalid_output) - deserialized_graph = serde.deserialize_graph(graph_proto) - self.assertEqual(len(deserialized_graph.outputs), 1) - self.assertEqual(deserialized_graph.outputs[0].name, "invalid_output") - self.assertEqual(deserialized_graph.outputs[0].type, None) - self.assertEqual(deserialized_graph.outputs[0].shape, None) - self.assertEqual(deserialized_graph.outputs[0].dtype, None) - - -class QuantizationAnnotationTest(unittest.TestCase): - """Test that quantization annotations are correctly serialized and deserialized.""" - - def setUp(self): - model_text = """\ -ir_version: 8 -producer_name: "pytorch" -producer_version: "2.1.1" -graph { - input { - name: "input" - type { - tensor_type { - elem_type: 1 - shape { - dim { - dim_value: 1 - } - } - } - } - } - output { - name: "output" - type { - tensor_type { - elem_type: 1 - shape { - dim { - dim_value: 1 - } - } - } - } - } - node { - input: "input" - output: "intermediate_value" - op_type: "TestOp1" - domain: "test_domain" - } - node { - input: "intermediate_value" - output: "output" - op_type: "TestOp2" - domain: "test_domain" - } - quantization_annotation { - tensor_name: "input" - quant_parameter_tensor_names { - key: "custom_key" - value: "arbitrary_value_input" - } - } - quantization_annotation { - tensor_name: "intermediate_value" - quant_parameter_tensor_names { - key: "custom_key" - value: "arbitrary_value_intermediate" - } - } - quantization_annotation { - tensor_name: "output" - quant_parameter_tensor_names { - key: "custom_key" - value: "arbitrary_value_output" - } - } -}""" - self.model = onnx.ModelProto() - google.protobuf.text_format.Parse(model_text, self.model) - - def test_deserialize_quantization_annotation(self): - model = serde.deserialize_model(self.model) - self.assertEqual( - model.graph.inputs[0].meta["quant_parameter_tensor_names"], - {"custom_key": "arbitrary_value_input"}, - ) - self.assertEqual( - model.graph.node(0).outputs[0].meta["quant_parameter_tensor_names"], - {"custom_key": "arbitrary_value_intermediate"}, - ) - self.assertEqual( - model.graph.outputs[0].meta["quant_parameter_tensor_names"], - {"custom_key": "arbitrary_value_output"}, - ) - - def test_serde_roundtrip(self): - model = serde.deserialize_model(self.model) - serialized_model = serde.serialize_model(model) - deserialized_model = serde.deserialize_model(serialized_model) - self.assertEqual( - deserialized_model.graph.inputs[0].meta["quant_parameter_tensor_names"], - {"custom_key": "arbitrary_value_input"}, - ) - self.assertEqual( - deserialized_model.graph.node(0).outputs[0].meta["quant_parameter_tensor_names"], - {"custom_key": "arbitrary_value_intermediate"}, - ) - self.assertEqual( - deserialized_model.graph.outputs[0].meta["quant_parameter_tensor_names"], - {"custom_key": "arbitrary_value_output"}, - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/onnxscript/ir/tape.py b/onnxscript/ir/tape.py deleted file mode 100644 index 9270dcdcec..0000000000 --- a/onnxscript/ir/tape.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -"""Taping module to facilitate building IR graphs.""" - -# NOTE: Be *selective* about what this module exports because it is part of the public API. - -from __future__ import annotations - -__all__ = [ - "Tape", -] - -from onnxscript.ir._tape import Tape - -Tape.__module__ = __name__ diff --git a/onnxscript/ir/tensor_adapters.py b/onnxscript/ir/tensor_adapters.py deleted file mode 100644 index 0a74e0a74c..0000000000 --- a/onnxscript/ir/tensor_adapters.py +++ /dev/null @@ -1,122 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -"""Compatible adapters implementing the TensorProtocol interface for various framework tensor types. - -This module provides public classes that implement the :class:`onnxscript.ir.TensorProtocol` -interface for various tensor types from popular deep learning frameworks. - -You can use these classes to create tensors and use them in the IR graph like any other tensor. - -Example:: - import torch - from onnxscript import ir - - # Create a PyTorch tensor - torch_tensor = torch.tensor([1, 2, 3]) - - # Wrap the PyTorch tensor in a TorchTensor object - ir_tensor = ir.tensor_adapters.TorchTensor(torch_tensor) - - # Use the IR tensor in the graph - attr = ir.AttrTensor("x", ir_tensor) - print(attr) -""" - -# pylint: disable=import-outside-toplevel - -# NOTE: DO NOT import any framework-specific modules here in the global namespace. - -from __future__ import annotations - -__all__ = [ - "TorchTensor", -] - -import ctypes -from typing import TYPE_CHECKING, Any - -import numpy.typing as npt - -from onnxscript import ir -from onnxscript.ir import _core - -if TYPE_CHECKING: - import torch - - -class TorchTensor(_core.Tensor): - def __init__( - self, tensor: torch.Tensor, name: str | None = None, doc_string: str | None = None - ): - # Pass the tensor as the raw data to ir.Tensor's constructor - import torch - - _TORCH_DTYPE_TO_ONNX: dict[torch.dtype, ir.DataType] = { - torch.bfloat16: ir.DataType.BFLOAT16, - torch.bool: ir.DataType.BOOL, - torch.complex128: ir.DataType.COMPLEX128, - torch.complex64: ir.DataType.COMPLEX64, - torch.float16: ir.DataType.FLOAT16, - torch.float32: ir.DataType.FLOAT, - torch.float64: ir.DataType.DOUBLE, - torch.float8_e4m3fn: ir.DataType.FLOAT8E4M3FN, - torch.float8_e4m3fnuz: ir.DataType.FLOAT8E4M3FNUZ, - torch.float8_e5m2: ir.DataType.FLOAT8E5M2, - torch.float8_e5m2fnuz: ir.DataType.FLOAT8E5M2FNUZ, - torch.int16: ir.DataType.INT16, - torch.int32: ir.DataType.INT32, - torch.int64: ir.DataType.INT64, - torch.int8: ir.DataType.INT8, - torch.uint8: ir.DataType.UINT8, - torch.uint16: ir.DataType.UINT16, - torch.uint32: ir.DataType.UINT32, - torch.uint64: ir.DataType.UINT64, - } - super().__init__( - tensor, dtype=_TORCH_DTYPE_TO_ONNX[tensor.dtype], name=name, doc_string=doc_string - ) - - def numpy(self) -> npt.NDArray: - import torch - - self.raw: torch.Tensor - if self.dtype == ir.DataType.BFLOAT16: - return self.raw.view(torch.uint16).numpy(force=True).view(self.dtype.numpy()) - if self.dtype in { - ir.DataType.FLOAT8E4M3FN, - ir.DataType.FLOAT8E4M3FNUZ, - ir.DataType.FLOAT8E5M2, - ir.DataType.FLOAT8E5M2FNUZ, - }: - return self.raw.view(torch.uint8).numpy(force=True).view(self.dtype.numpy()) - - return self.raw.numpy(force=True) - - def __array__(self, dtype: Any = None, copy: bool | None = None) -> npt.NDArray: - del copy # Unused, but needed for the signature - if dtype is None: - return self.numpy() - return self.numpy().__array__(dtype) - - def tobytes(self) -> bytes: - # Implement tobytes to support native PyTorch types so we can use types like bloat16 - # Reading from memory directly is also more efficient because - # it avoids copying to a NumPy array - import torch._subclasses.fake_tensor - - with torch._subclasses.fake_tensor.unset_fake_temporarily(): # pylint: disable=protected-access - # Disable any fake mode so calling detach() etc. will return a real tensor - tensor = self.raw.detach().cpu().contiguous() - - if isinstance(tensor, torch._subclasses.fake_tensor.FakeTensor): # pylint: disable=protected-access - raise TypeError( - f"Cannot take content out from the FakeTensor ('{self.name}'). Please replace the tensor " - "with a tensor backed by real data using ONNXProgram.apply_weights() " - "or save the model without initializers by setting include_initializers=False." - ) - - return bytes( - (ctypes.c_ubyte * tensor.element_size() * tensor.numel()).from_address( - tensor.data_ptr() - ) - ) diff --git a/onnxscript/ir/tensor_adapters_test.py b/onnxscript/ir/tensor_adapters_test.py deleted file mode 100644 index 4898cb42a4..0000000000 --- a/onnxscript/ir/tensor_adapters_test.py +++ /dev/null @@ -1,85 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -"""Unit tests for the tensor_adapters module.""" - -from __future__ import annotations - -import importlib.util -import unittest - -import ml_dtypes -import numpy as np -import parameterized -import torch - -from onnxscript.ir import tensor_adapters - - -def skip_if_no(module_name: str): - """Decorator to skip a test if a module is not installed.""" - if importlib.util.find_spec(module_name) is None: - return unittest.skip(f"{module_name} not installed") - return lambda func: func - - -@skip_if_no("torch") -class TorchTensorTest(unittest.TestCase): - @parameterized.parameterized.expand( - [ - (torch.bfloat16, ml_dtypes.bfloat16), - (torch.bool, np.bool_), - (torch.complex128, np.complex128), - (torch.complex64, np.complex64), - (torch.float16, np.float16), - (torch.float32, np.float32), - (torch.float64, np.float64), - (torch.float8_e4m3fn, ml_dtypes.float8_e4m3fn), - (torch.float8_e4m3fnuz, ml_dtypes.float8_e4m3fnuz), - (torch.float8_e5m2, ml_dtypes.float8_e5m2), - (torch.float8_e5m2fnuz, ml_dtypes.float8_e5m2fnuz), - (torch.int16, np.int16), - (torch.int32, np.int32), - (torch.int64, np.int64), - (torch.int8, np.int8), - (torch.uint16, np.uint16), - (torch.uint32, np.uint32), - (torch.uint64, np.uint64), - (torch.uint8, np.uint8), - ], - ) - def test_numpy_returns_correct_dtype(self, dtype: torch.dtype, np_dtype): - tensor = tensor_adapters.TorchTensor(torch.tensor([1], dtype=dtype)) - self.assertEqual(tensor.numpy().dtype, np_dtype) - self.assertEqual(tensor.__array__().dtype, np_dtype) - self.assertEqual(np.array(tensor).dtype, np_dtype) - - @parameterized.parameterized.expand( - [ - (torch.bfloat16,), - (torch.bool,), - (torch.complex128,), - (torch.complex64,), - (torch.float16,), - (torch.float32,), - (torch.float64,), - (torch.float8_e4m3fn,), - (torch.float8_e4m3fnuz,), - (torch.float8_e5m2,), - (torch.float8_e5m2fnuz,), - (torch.int16,), - (torch.int32,), - (torch.int64,), - (torch.int8,), - (torch.uint16,), - (torch.uint32,), - (torch.uint64,), - (torch.uint8,), - ], - ) - def test_tobytes(self, dtype: torch.dtype): - tensor = tensor_adapters.TorchTensor(torch.tensor([1], dtype=dtype)) - self.assertEqual(tensor.tobytes(), tensor.numpy().tobytes()) - - -if __name__ == "__main__": - unittest.main() diff --git a/onnxscript/ir/traversal.py b/onnxscript/ir/traversal.py deleted file mode 100644 index 5fa9a9acf7..0000000000 --- a/onnxscript/ir/traversal.py +++ /dev/null @@ -1,82 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -"""Utilities for traversing the IR graph.""" - -from __future__ import annotations - -__all__ = [ - "RecursiveGraphIterator", -] - -from typing import Callable, Iterator, Reversible, Union - -from typing_extensions import Self - -from onnxscript.ir import _core, _enums - -GraphLike = Union[_core.Graph, _core.Function, _core.GraphView] - - -class RecursiveGraphIterator(Iterator[_core.Node], Reversible[_core.Node]): - def __init__( - self, - graph_like: GraphLike, - *, - recursive: Callable[[_core.Node], bool] | None = None, - reverse: bool = False, - ): - """Iterate over the nodes in the graph, recursively visiting subgraphs. - - Args: - graph_like: The graph to traverse. - recursive: A callback that determines whether to recursively visit the subgraphs - contained in a node. If not provided, all nodes in subgraphs are visited. - reverse: Whether to iterate in reverse order. - """ - self._graph = graph_like - self._recursive = recursive - self._reverse = reverse - self._iterator = self._recursive_node_iter(graph_like) - - def __iter__(self) -> Self: - self._iterator = self._recursive_node_iter(self._graph) - return self - - def __next__(self) -> _core.Node: - return next(self._iterator) - - def _recursive_node_iter( - self, graph: _core.Graph | _core.Function | _core.GraphView - ) -> Iterator[_core.Node]: - iterable = reversed(graph) if self._reverse else graph - for node in iterable: # type: ignore[union-attr] - yield node - if self._recursive is not None and not self._recursive(node): - continue - yield from self._iterate_subgraphs(node) - - def _iterate_subgraphs(self, node: _core.Node): - for attr in node.attributes.values(): - if not isinstance(attr, _core.Attr): - continue - if attr.type == _enums.AttributeType.GRAPH: - yield from RecursiveGraphIterator( - attr.value, - recursive=self._recursive, - reverse=self._reverse, - ) - elif attr.type == _enums.AttributeType.GRAPHS: - graphs = reversed(attr.value) if self._reverse else attr.value - for graph in graphs: - yield from RecursiveGraphIterator( - graph, - recursive=self._recursive, - reverse=self._reverse, - ) - - def __reversed__(self) -> Iterator[_core.Node]: - return RecursiveGraphIterator( - self._graph, - recursive=self._recursive, - reverse=not self._reverse, - ) diff --git a/onnxscript/ir/traversal_test.py b/onnxscript/ir/traversal_test.py deleted file mode 100644 index 5ed4d31473..0000000000 --- a/onnxscript/ir/traversal_test.py +++ /dev/null @@ -1,81 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -from __future__ import annotations - -import unittest - -import parameterized - -from onnxscript import ir -from onnxscript.ir import traversal - - -class RecursiveGraphIteratorTest(unittest.TestCase): - def setUp(self): - self.graph = ir.Graph( - [], - [], - nodes=[ - ir.Node("", "Node1", []), - ir.Node("", "Node2", []), - ir.Node( - "", - "If", - [], - attributes=[ - ir.AttrGraph( - "then_branch", - ir.Graph( - [], - [], - nodes=[ir.Node("", "Node3", []), ir.Node("", "Node4", [])], - name="then_graph", - ), - ), - ir.AttrGraph( - "else_branch", - ir.Graph( - [], - [], - nodes=[ir.Node("", "Node5", []), ir.Node("", "Node6", [])], - name="else_graph", - ), - ), - ], - ), - ], - name="main_graph", - ) - - @parameterized.parameterized.expand( - [ - ("forward", False, ("Node1", "Node2", "If", "Node3", "Node4", "Node5", "Node6")), - ("reversed", True, ("If", "Node4", "Node3", "Node6", "Node5", "Node2", "Node1")), - ] - ) - def test_recursive_graph_iterator(self, _: str, reverse: bool, expected: tuple[str, ...]): - iterator = traversal.RecursiveGraphIterator(self.graph) - if reverse: - iterator = reversed(iterator) - nodes = list(iterator) - self.assertEqual(tuple(node.op_type for node in nodes), expected) - - @parameterized.parameterized.expand( - [ - ("forward", False, ("Node1", "Node2", "If")), - ("reversed", True, ("If", "Node2", "Node1")), - ] - ) - def test_recursive_graph_iterator_recursive_controls_recursive_behavior( - self, _: str, reverse: bool, expected: list[str] - ): - nodes = list( - traversal.RecursiveGraphIterator( - self.graph, recursive=lambda node: node.op_type != "If", reverse=reverse - ) - ) - self.assertEqual(tuple(node.op_type for node in nodes), expected) - - -if __name__ == "__main__": - unittest.main() From 655aa4c742dddc71d8c7bce6f433f00e64b14bf4 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 21 May 2025 15:27:02 -0700 Subject: [PATCH 02/16] Update --- onnxscript/ir/_convenience/__init__.py | 378 ------------------ onnxscript/ir/_convenience/_constructors.py | 213 ---------- .../ir/_convenience/_constructors_test.py | 31 -- onnxscript/ir/_tape.py | 49 +++ onnxscript/ir/_tape_test.py | 76 ++++ onnxscript/ir/convenience.py | 1 + onnxscript/ir/passes/common/_c_api_utils.py | 77 ++++ onnxscript/optimizer/_constant_folding.py | 2 +- onnxscript/rewriter/__init__.py | 8 +- onnxscript/rewriter/_fusion_utils.py | 4 +- onnxscript/rewriter/_rewrite_rule.py | 4 +- onnxscript/rewriter/ort_fusions/_core.py | 4 +- .../rewriter/ort_fusions/attention_test.py | 4 +- onnxscript/rewriter/ort_fusions/mha_test.py | 6 +- tests/ir/public_api_test.py | 187 --------- 15 files changed, 219 insertions(+), 825 deletions(-) delete mode 100644 onnxscript/ir/_convenience/__init__.py delete mode 100644 onnxscript/ir/_convenience/_constructors.py delete mode 100644 onnxscript/ir/_convenience/_constructors_test.py create mode 100644 onnxscript/ir/_tape.py create mode 100644 onnxscript/ir/_tape_test.py create mode 100644 onnxscript/ir/convenience.py create mode 100644 onnxscript/ir/passes/common/_c_api_utils.py delete mode 100644 tests/ir/public_api_test.py diff --git a/onnxscript/ir/_convenience/__init__.py b/onnxscript/ir/_convenience/__init__.py deleted file mode 100644 index 839c5d330b..0000000000 --- a/onnxscript/ir/_convenience/__init__.py +++ /dev/null @@ -1,378 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -"""Convenience methods for constructing and manipulating the IR. - -This is an internal only module. We should choose to expose some of the methods -in convenience.py after they are proven to be useful. -""" - -from __future__ import annotations - -__all__ = [ - "convert_attribute", - "convert_attributes", - "replace_all_uses_with", - "create_value_mapping", - "replace_nodes_and_values", -] - -from typing import Mapping, Sequence, Union - -import onnx - -from onnxscript.ir import _core, _enums, _protocols, serde - -SupportedAttrTypes = Union[ - str, - int, - float, - Sequence[int], - Sequence[float], - Sequence[str], - _protocols.TensorProtocol, # This includes all in-memory tensor types - onnx.TensorProto, - _core.Attr, - _core.RefAttr, - _protocols.GraphProtocol, - Sequence[_protocols.GraphProtocol], - onnx.GraphProto, - _protocols.TypeProtocol, - Sequence[_protocols.TypeProtocol], - None, -] - - -def _infer_attribute_type(attr: SupportedAttrTypes) -> _enums.AttributeType: - """Infer the attribute type based on the type of the Python object.""" - if isinstance(attr, int): - return _enums.AttributeType.INT - if isinstance(attr, float): - return _enums.AttributeType.FLOAT - if isinstance(attr, str): - return _enums.AttributeType.STRING - if isinstance(attr, (_core.Attr, _core.RefAttr)): - return attr.type - if isinstance(attr, Sequence) and all(isinstance(x, int) for x in attr): - return _enums.AttributeType.INTS - if isinstance(attr, Sequence) and all(isinstance(x, float) for x in attr): - return _enums.AttributeType.FLOATS - if isinstance(attr, Sequence) and all(isinstance(x, str) for x in attr): - return _enums.AttributeType.STRINGS - if isinstance(attr, (_core.TensorBase, onnx.TensorProto, _protocols.TensorProtocol)): - # Be sure to check TensorProtocol last because isinstance checking on Protocols can be slower - return _enums.AttributeType.TENSOR - if isinstance(attr, Sequence) and all( - isinstance(x, (_core.TensorBase, onnx.TensorProto, _protocols.TensorProtocol)) - for x in attr - ): - return _enums.AttributeType.TENSORS - if isinstance(attr, (_core.Graph, onnx.GraphProto, _protocols.GraphProtocol)): - return _enums.AttributeType.GRAPH - if isinstance(attr, Sequence) and all( - isinstance(x, (_core.Graph, onnx.GraphProto, _protocols.GraphProtocol)) for x in attr - ): - return _enums.AttributeType.GRAPHS - if isinstance( - attr, - (_core.TensorType, _core.SequenceType, _core.OptionalType, _protocols.TypeProtocol), - ): - return _enums.AttributeType.TYPE_PROTO - if isinstance(attr, Sequence) and all( - isinstance( - x, - ( - _core.TensorType, - _core.SequenceType, - _core.OptionalType, - _protocols.TypeProtocol, - ), - ) - for x in attr - ): - return _enums.AttributeType.TYPE_PROTOS - raise TypeError(f"Unsupported attribute type: '{type(attr)}'") - - -def convert_attribute( - name: str, - attr: SupportedAttrTypes, - attr_type: _enums.AttributeType | None = None, -) -> _core.Attr | _core.RefAttr: - """Convert a Python object to a _core.Attr object. - - This method is useful when constructing nodes with attributes. It infers the - attribute type based on the type of the Python value. - - Args: - name: The name of the attribute. - attr: The value of the attribute. - attr_type: The type of the attribute. This is required when attr is None. - When provided, it overrides the inferred type. - - Returns: - A ``Attr`` object. - - Raises: - ValueError: If ``attr`` is ``None`` and ``attr_type`` is not provided. - TypeError: If the type of the attribute is not supported. - """ - if attr is None: - if attr_type is None: - raise ValueError("attr_type must be provided when attr is None") - return _core.Attr(name, attr_type, None) - - if isinstance(attr, (_core.Attr, _core.RefAttr)): - if attr.name != name: - raise ValueError( - f"Attribute name '{attr.name}' does not match provided name '{name}'" - ) - if attr_type is not None and attr.type != attr_type: - raise ValueError( - f"Attribute type '{attr.type}' does not match provided type '{attr_type}'" - ) - return attr - - if attr_type is None: - attr_type = _infer_attribute_type(attr) - - if attr_type == _enums.AttributeType.INT: - return _core.AttrInt64(name, attr) # type: ignore - if attr_type == _enums.AttributeType.FLOAT: - return _core.AttrFloat32(name, attr) # type: ignore - if attr_type == _enums.AttributeType.STRING: - return _core.AttrString(name, attr) # type: ignore - if attr_type == _enums.AttributeType.INTS: - return _core.AttrInt64s(name, attr) # type: ignore - if attr_type == _enums.AttributeType.FLOATS: - return _core.AttrFloat32s(name, attr) # type: ignore - if attr_type == _enums.AttributeType.STRINGS: - return _core.AttrStrings(name, attr) # type: ignore - if attr_type == _enums.AttributeType.TENSOR: - if isinstance(attr, (_core.TensorBase, _protocols.TensorProtocol)): - return _core.AttrTensor(name, attr) - if isinstance(attr, onnx.TensorProto): - return _core.AttrTensor(name, serde.deserialize_tensor(attr)) - if attr_type == _enums.AttributeType.TENSORS: - tensors = [] - for t in attr: # type: ignore[union-attr] - if isinstance(t, onnx.TensorProto): - tensors.append(_core.AttrTensor(name, serde.deserialize_tensor(t))) - else: - tensors.append(t) # type: ignore[arg-type] - return _core.AttrTensors(name, tensors) # type: ignore[arg-type] - if attr_type == _enums.AttributeType.GRAPH: - if isinstance(attr, onnx.GraphProto): - attr = serde.deserialize_graph(attr) - return _core.AttrGraph(name, attr) # type: ignore[arg-type] - if attr_type == _enums.AttributeType.GRAPHS: - graphs = [] - for graph in attr: # type: ignore[union-attr] - if isinstance(graph, onnx.GraphProto): - graphs.append(serde.deserialize_graph(graph)) - else: - graphs.append(graph) # type: ignore[arg-type] - return _core.AttrGraphs(name, graphs) # type: ignore[arg-type] - if attr_type == _enums.AttributeType.TYPE_PROTO: - return _core.AttrTypeProto(name, attr) # type: ignore[arg-type] - if attr_type == _enums.AttributeType.TYPE_PROTOS: - return _core.AttrTypeProtos(name, attr) # type: ignore[arg-type] - raise TypeError(f"Unsupported attribute type: '{type(attr)}'") - - -def convert_attributes( - attrs: Mapping[str, SupportedAttrTypes], -) -> list[_core.Attr | _core.RefAttr]: - """Convert a dictionary of attributes to a list of _core.Attr objects. - - It infers the attribute type based on the type of the value. The supported - types are: int, float, str, Sequence[int], Sequence[float], Sequence[str], - :class:`_core.Tensor`, and :class:`_core.Attr`:: - - >>> from onnxscript import ir - >>> import onnx - >>> import numpy as np - >>> attrs = { - ... "int": 1, - ... "float": 1.0, - ... "str": "hello", - ... "ints": [1, 2, 3], - ... "floats": [1.0, 2.0, 3.0], - ... "strings": ["hello", "world"], - ... "tensor": ir.Tensor(np.array([1.0, 2.0, 3.0])), - ... "tensor_proto": - ... onnx.TensorProto( - ... dims=[3], - ... data_type=onnx.TensorProto.FLOAT, - ... float_data=[1.0, 2.0, 3.0], - ... name="proto", - ... ), - ... "graph": ir.Graph([], [], nodes=[], name="graph0"), - ... "graphs": [ir.Graph([], [], nodes=[], name="graph1"), ir.Graph([], [], nodes=[], name="graph2")], - ... "type_proto": ir.TensorType(ir.DataType.FLOAT), - ... "type_protos": [ir.TensorType(ir.DataType.FLOAT), ir.TensorType(ir.DataType.FLOAT)], - ... } - >>> convert_attributes(attrs) - [Attr('int', INT, 1), Attr('float', FLOAT, 1.0), Attr('str', STRING, 'hello'), Attr('ints', INTS, [1, 2, 3]), Attr('floats', FLOATS, [1.0, 2.0, 3.0]), Attr('strings', STRINGS, ['hello', 'world']), Attr('tensor', TENSOR, Tensor(array([1., 2., 3.]), name=None)), Attr('tensor_proto', TENSOR, TensorProtoTensor(array([1., 2., 3.], dtype=float32), name='proto')), Attr('graph', INTS, Graph( - name='graph0', - inputs=( - - ), - outputs=( - - ), - len()=0 - )), Attr('graphs', GRAPHS, [Graph( - name='graph1', - inputs=( - - ), - outputs=( - - ), - len()=0 - ), Graph( - name='graph2', - inputs=( - - ), - outputs=( - - ), - len()=0 - )]), Attr('type_proto', TYPE_PROTO, Tensor(FLOAT)), Attr('type_protos', TYPE_PROTOS, [Tensor(FLOAT), Tensor(FLOAT)])] - - Args: - attrs: A dictionary of {: } to convert. - - Returns: - A list of _core.Attr objects. - """ - attributes: list[_core.Attr | _core.RefAttr] = [] - for name, attr in attrs.items(): - if attr is not None: - attributes.append(convert_attribute(name, attr)) - return attributes - - -def replace_all_uses_with( - values: _protocols.ValueProtocol | Sequence[_protocols.ValueProtocol], - replacements: _protocols.ValueProtocol | Sequence[_protocols.ValueProtocol], -) -> None: - """Replace all uses of the given values with the replacements. - - This is useful when nodes in the graph are replaced with new nodes, where - the old users need to be updated to use the outputs of the new nodes. - - For example, suppose we have the following graph:: - - A -> {B, C} - - We want to replace the node A with a new node D:: - - >>> from onnxscript import ir - >>> input = ir.Input("input") - >>> node_a = ir.Node("", "A", [input]) - >>> node_b = ir.Node("", "B", node_a.outputs) - >>> node_c = ir.Node("", "C", node_a.outputs) - >>> node_d = ir.Node("", "D", [input]) - >>> replace_all_uses_with(node_a.outputs, node_d.outputs) - >>> len(node_b.inputs) - 1 - >>> node_b.inputs[0].producer().op_type - 'D' - >>> len(node_c.inputs) - 1 - >>> node_c.inputs[0].producer().op_type - 'D' - >>> len(node_a.outputs[0].uses()) - 0 - - When values and replacements are sequences, they are zipped into pairs. All - users of the first value is replaced with the first replacement, and so on. - - .. note:: - You still need to update the graph outputs if any of the values being - replaced are part of the graph outputs. Be sure to remove the old nodes - from the graph using ``graph.remove()`` if they are no longer needed. - - Args: - values: The value or values to be replaced. - replacements: The new value or values to use as inputs. - """ - if not isinstance(values, Sequence): - values = (values,) - if not isinstance(replacements, Sequence): - replacements = (replacements,) - if len(values) != len(replacements): - raise ValueError("The number of values and replacements must match.") - for value, replacement in zip(values, replacements): - for user_node, index in tuple(value.uses()): - user_node.replace_input_with(index, replacement) - - -def create_value_mapping(graph: _core.Graph) -> dict[str, _core.Value]: - """Return a dictionary mapping names to values in the graph. - - The mapping does not include values from subgraphs. - - Args: - graph: The graph to extract the mapping from. - - Returns: - A dictionary mapping names to values. - """ - values: dict[str, _core.Value] = {} - values.update(graph.initializers) - # The names of the values can be None or "", which we need to exclude - for input in graph.inputs: - if not input.name: - continue - values[input.name] = input - for node in graph: - for value in node.outputs: - if not value.name: - continue - values[value.name] = value - return values - - -def replace_nodes_and_values( - graph_or_function: _core.Graph | _core.Function, - /, - insertion_point: _core.Node, - old_nodes: Sequence[_core.Node], - new_nodes: Sequence[_core.Node], - old_values: Sequence[_core.Value], - new_values: Sequence[_core.Value], -) -> None: - """Replaces nodes and values in the graph or function. - - Args: - graph_or_function: The graph or function to replace nodes and values in. - insertion_point: The node to insert the new nodes after. - old_nodes: The nodes to replace. - new_nodes: The nodes to replace with. - old_values: The values to replace. - new_values: The values to replace with. - """ - - for old_value, new_value in zip(old_values, new_values): - # Propagate relevant info from old value to new value - # TODO(Rama): Perhaps this should be a separate utility function. Also, consider - # merging old and new type/shape info. - new_value.type = old_value.type - new_value.shape = old_value.shape - new_value.const_value = old_value.const_value - new_value.name = old_value.name - - # Reconnect the users of the deleted values to use the new values - replace_all_uses_with(old_values, new_values) - # Update graph/function outputs if the node generates output - replacement_mapping = dict(zip(old_values, new_values)) - for idx, graph_or_function_output in enumerate(graph_or_function.outputs): - if graph_or_function_output in replacement_mapping: - graph_or_function.outputs[idx] = replacement_mapping[graph_or_function_output] - - # insert new nodes after the index node - graph_or_function.insert_after(insertion_point, new_nodes) - graph_or_function.remove(old_nodes, safe=True) diff --git a/onnxscript/ir/_convenience/_constructors.py b/onnxscript/ir/_convenience/_constructors.py deleted file mode 100644 index 33b738e569..0000000000 --- a/onnxscript/ir/_convenience/_constructors.py +++ /dev/null @@ -1,213 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -"""Convenience constructors for IR objects.""" - -from __future__ import annotations - -__all__ = [ - "tensor", - "node", -] - -import typing -from typing import Mapping, Sequence - -import numpy as np -import onnx - -from onnxscript.ir import _convenience, _core, _enums, _protocols, serde, tensor_adapters - -if typing.TYPE_CHECKING: - import numpy.typing as npt - - from onnxscript import ir - - -def tensor( - value: npt.ArrayLike | onnx.TensorProto | ir.DLPackCompatible | ir.ArrayCompatible, - dtype: _enums.DataType | None = None, - name: str | None = None, - doc_string: str | None = None, -) -> _protocols.TensorProtocol: - """Create a tensor value from an ArrayLike object or a TensorProto. - - The dtype must match the value. Reinterpretation of the value is - not supported, unless if the value is a plain Python object, in which case - it is converted to a numpy array with the given dtype. - - ``value`` can be a numpy array, a plain Python object, or a TensorProto. - - Example:: - - >>> from onnxscript import ir - >>> import numpy as np - >>> import ml_dtypes - >>> import onnx - >>> ir.tensor(np.array([1, 2, 3], dtype=np.int16)) - Tensor(array([1, 2, 3], dtype=int16), name=None) - >>> ir.tensor([1, 2, 3], dtype=ir.DataType.BFLOAT16) - Tensor(array([1, 2, 3], dtype=bfloat16), name=None) - >>> tp_tensor = ir.tensor(onnx.helper.make_tensor("tensor", onnx.TensorProto.FLOAT, dims=[], vals=[0.5])) - >>> tp_tensor.numpy() - array(0.5, dtype=float32) - >>> import torch - >>> ir.tensor(torch.tensor([1.0, 2.0]), name="torch_tensor") - TorchTensor(tensor([1., 2.]), name='torch_tensor') - - Args: - value: The numpy array to create the tensor from. - dtype: The data type of the tensor. - name: The name of the tensor. - doc_string: The documentation string of the tensor. - - Returns: - A tensor value. - - Raises: - ValueError: If the dtype does not match the value when value is not a plain Python - object like ``list[int]``. - """ - if isinstance(value, _protocols.TensorProtocol): - if dtype is not None and dtype != value.dtype: - raise ValueError( - f"The dtype must match the value when value is a Tensor. dtype={dtype}, value.dtype={value.dtype}. " - "You do not have to specify the dtype when value is a Tensor." - ) - return value - if isinstance(value, onnx.TensorProto): - tensor_ = serde.deserialize_tensor(value) - if name is not None: - tensor_.name = name - if doc_string is not None: - tensor_.doc_string = doc_string - if dtype is not None and dtype != tensor_.dtype: - raise ValueError( - f"The dtype must match the value when value is a TensorProto. dtype={dtype}, value.data_type={tensor_.dtype}" - "You do not have to specify the dtype when value is a TensorProto." - ) - return tensor_ - elif str(type(value)) == "": - # NOTE: We use str(type(...)) and do not import torch for type checking - # as it creates overhead during import - return tensor_adapters.TorchTensor(value, name=name, doc_string=doc_string) # type: ignore[arg-type] - elif isinstance(value, (_protocols.DLPackCompatible, _protocols.ArrayCompatible)): - return _core.Tensor(value, dtype=dtype, name=name, doc_string=doc_string) - - # Plain (numerical) Python object. Determine the numpy dtype and use np.array to construct the tensor - if dtype is not None: - if not isinstance(dtype, _enums.DataType): - raise TypeError(f"dtype must be an instance of DataType. dtype={dtype}") - numpy_dtype = dtype.numpy() - elif isinstance(value, Sequence) and not value: - raise ValueError("dtype must be specified when value is an empty sequence.") - elif isinstance(value, int) and not isinstance(value, bool): - # Specify int64 for ints because on Windows this may be int32 - numpy_dtype = np.dtype(np.int64) - elif isinstance(value, float): - # If the value is a single float, we use np.float32 as the default dtype - numpy_dtype = np.dtype(np.float32) - elif isinstance(value, Sequence) and value: - if all((isinstance(elem, int) and not isinstance(elem, bool)) for elem in value): - numpy_dtype = np.dtype(np.int64) - elif all(isinstance(elem, float) for elem in value): - # If the value is a sequence of floats, we use np.float32 as the default dtype - numpy_dtype = np.dtype(np.float32) - else: - numpy_dtype = None - else: - numpy_dtype = None - - array = np.array(value, dtype=numpy_dtype) - - # Handle string tensors by encoding them - if isinstance(value, str) or ( - isinstance(value, Sequence) and value and all(isinstance(elem, str) for elem in value) - ): - array = np.strings.encode(array, encoding="utf-8") - return _core.StringTensor( - array, - shape=_core.Shape(array.shape), - name=name, - doc_string=doc_string, - ) - - return _core.Tensor( - array, - dtype=dtype, - shape=_core.Shape(array.shape), - name=name, - doc_string=doc_string, - ) - - -def node( - op_type: str, - inputs: Sequence[ir.Value | None], - attributes: Mapping[str, _convenience.SupportedAttrTypes] | None = None, - *, - domain: str = "", - overload: str = "", - num_outputs: int | None = None, - outputs: Sequence[ir.Value] | None = None, - version: int | None = None, - graph: ir.Graph | None = None, - name: str | None = None, - doc_string: str | None = None, - metadata_props: dict[str, str] | None = None, -) -> ir.Node: - """Create an :class:`ir.Node`. - - This is a convenience constructor for creating a Node that supports Python - objects as attributes. - - Example:: - - >>> from onnxscript import ir - >>> input_a = ir.Input("A", shape=ir.Shape([1, 2]), type=ir.TensorType(ir.DataType.INT32)) - >>> input_b = ir.Input("B", shape=ir.Shape([1, 2]), type=ir.TensorType(ir.DataType.INT32)) - >>> node = ir.node( - ... "SomeOp", - ... inputs=[input_a, input_b], - ... attributes={"alpha": 1.0, "some_list": [1, 2, 3]}, - ... domain="some.domain", - ... name="node_name" - ... ) - >>> node.op_type - 'SomeOp' - - Args: - op_type: The name of the operator. - inputs: The input values. When an input is None, it is an empty input. - attributes: The attributes. RefAttr can be used only when the node is defined in a Function. - overload: The overload name when the node is invoking a function. - domain: The domain of the operator. For onnx operators, this is an empty string. - num_outputs: The number of outputs of the node. If not specified, the number is 1. - outputs: The output values. If None, the outputs are created during initialization. - version: The version of the operator. If None, the version is unspecified and will follow that of the graph. - graph: The graph that the node belongs to. If None, the node is not added to any graph. - A `Node` must belong to zero or one graph. - name: The name of the node. If None, the node is anonymous. - doc_string: The documentation string. - metadata_props: The metadata properties. - - Returns: - A node with the given op_type and inputs. - """ - if attributes is None: - attrs: Sequence[ir.Attr | ir.RefAttr] = () - else: - attrs = _convenience.convert_attributes(attributes) - return _core.Node( - domain=domain, - op_type=op_type, - inputs=inputs, - attributes=attrs, - overload=overload, - num_outputs=num_outputs, - outputs=outputs, - version=version, - graph=graph, - name=name, - doc_string=doc_string, - metadata_props=metadata_props, - ) diff --git a/onnxscript/ir/_convenience/_constructors_test.py b/onnxscript/ir/_convenience/_constructors_test.py deleted file mode 100644 index 6f291d8175..0000000000 --- a/onnxscript/ir/_convenience/_constructors_test.py +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -"""Unit tests for the _constructors module.""" - -import unittest - -import numpy as np - -from onnxscript import ir -from onnxscript.ir._convenience import _constructors - - -class ConstructorsTest(unittest.TestCase): - def test_tensor_accepts_torch_tensor(self): - import torch as some_random_name # pylint: disable=import-outside-toplevel - - torch_tensor = some_random_name.tensor([1, 2, 3]) - tensor = _constructors.tensor(torch_tensor) - np.testing.assert_array_equal(tensor, torch_tensor.numpy()) - - def test_tensor_raises_value_error_for_empty_sequence_without_dtype(self): - with self.assertRaises(ValueError): - _constructors.tensor([]) - - def test_tensor_handles_empty_sequence_with_dtype(self): - tensor = _constructors.tensor([], dtype=ir.DataType.FLOAT) - np.testing.assert_array_equal(tensor.numpy(), np.array([], dtype=np.float32)) - - -if __name__ == "__main__": - unittest.main() diff --git a/onnxscript/ir/_tape.py b/onnxscript/ir/_tape.py new file mode 100644 index 0000000000..8626983199 --- /dev/null +++ b/onnxscript/ir/_tape.py @@ -0,0 +1,49 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""Convenience methods for constructing the IR.""" + +from __future__ import annotations + +from typing import ( + Any, + Sequence, +) + +from onnx_ir import tape + + +class Builder(tape.Tape): + """An extension of the tape that provides a more convenient API for constructing the IR.""" + + def __getattr__(self, op_type: str) -> Any: + return lambda *args, **kwargs: self._make_node(op_type, args, kwargs) + + def _make_node(self, op_type: str, inputs: Sequence[ir.Value], kwargs: dict[str, Any]): + domain = kwargs.pop("_domain", "") + version = kwargs.pop("_version", None) + outputs = kwargs.pop("_outputs", 1) + if isinstance(outputs, Sequence): + num_outputs = len(outputs) + else: + assert isinstance(outputs, int) + num_outputs = outputs + + if num_outputs == 1: + value = super().op( + op_type, inputs=inputs, attributes=kwargs, domain=domain, version=version + ) + if isinstance(outputs, Sequence): + value.name = outputs[0] + return value + values = super().op_multi_out( + op_type, + inputs=inputs, + attributes=kwargs, + domain=domain, + version=version, + num_outputs=num_outputs, + ) + if isinstance(outputs, Sequence): + for value, name in zip(values, outputs): + value.name = name + return values diff --git a/onnxscript/ir/_tape_test.py b/onnxscript/ir/_tape_test.py new file mode 100644 index 0000000000..46cbcc23fe --- /dev/null +++ b/onnxscript/ir/_tape_test.py @@ -0,0 +1,76 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import unittest + +from onnxscript import ir + + +class TestTape(unittest.TestCase): + def test_op(self): + # Create a simple ONNX model with shape inference + # Define the model + inputs = [ + ir.Value( + name="input_a", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((1, 2)) + ), + ir.Value( + name="input_b", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((1, 2)) + ), + ] + + tape = ir.tape.Tape() + + _ = tape.op("Add", inputs=inputs) + + self.assertEqual([n.op_type for n in tape.nodes], ["Add"]) + + def test_initializers(self): + inputs = [ + ir.Value( + name="input_a", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((1, 2)) + ), + ir.Value( + name="input_b", + type=ir.TensorType(ir.DataType.FLOAT), + shape=ir.Shape((2, 1)), + const_value=ir.tensor([[42]] * 2, dtype=ir.DataType.FLOAT), + ), + ] + + tape = ir.tape.Tape() + + # Shape and type are not explicitly set for the initializer but it should still work + initializer = tape.initializer( + ir.tensor([[2, 3]], dtype=ir.DataType.FLOAT), name="initializer" + ) + val_add = tape.op("Add", inputs=inputs) + _ = tape.op("Mul", inputs=[val_add, initializer]) + + self.assertEqual([n.op_type for n in tape.nodes], ["Add", "Mul"]) + self.assertEqual(tape.initializers, (initializer,)) + + def test_op_multi_out(self): + inputs = [ + ir.Value( + name="input_a", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((1, 2)) + ), + ir.Value( + name="input_b", + type=ir.TensorType(ir.DataType.FLOAT), + shape=ir.Shape((2, 1)), + const_value=ir.tensor([[42]] * 2, dtype=ir.DataType.FLOAT), + ), + ] + + tape = ir.tape.Tape() + + out1, out2, out3 = tape.op_multi_out("SomeOp", inputs=inputs, num_outputs=3) # pylint: disable=unbalanced-tuple-unpacking + _ = tape.op("SomeOtherOp", inputs=[out1, out2, out3]) + + self.assertEqual([n.op_type for n in tape.nodes], ["SomeOp", "SomeOtherOp"]) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/ir/convenience.py b/onnxscript/ir/convenience.py new file mode 100644 index 0000000000..a0fff386b1 --- /dev/null +++ b/onnxscript/ir/convenience.py @@ -0,0 +1 @@ +from onnx_ir.convenience import * # type: ignore # noqa: F403 diff --git a/onnxscript/ir/passes/common/_c_api_utils.py b/onnxscript/ir/passes/common/_c_api_utils.py new file mode 100644 index 0000000000..bb2715c75c --- /dev/null +++ b/onnxscript/ir/passes/common/_c_api_utils.py @@ -0,0 +1,77 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""Utilities for interfacing with onnx C APIs.""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Callable, TypeVar + +from onnxscript import ir + +if TYPE_CHECKING: + import onnx + + +logger = logging.getLogger(__name__) +# Temporarily remove initializers larger than this size to keep model size down +# for the onnx.shape_inference call because it needs to serialize the model +_BIG_TENSOR_SIZE_LIMIT = 1000 # 1KB +_R = TypeVar("_R") + + +def call_onnx_api(func: Callable[[onnx.ModelProto], _R], model: ir.Model) -> _R: + """Call an ONNX C API function by temporarily removing initializers. + + This is necessary because the ONNX C API does not support large models + with initializers that have large tensor values. The input model is left + unchanged no matter the call succeeds or not. + + Args: + func: Partially applied function that takes a model proto and returns anything. + model: The IR model to pass to the API function. + + Returns: + The resulting ModelProto that contains the result of the API call. + """ + + # Store the original initializer values so they can be restored + initializer_values = tuple(model.graph.initializers.values()) + tensors = {v.name: v.const_value for v in initializer_values} + original_inputs_len = len(model.graph.inputs) + + # Turn the initializers into inputs and clear the initializers + # to limit the model size + for initializer in initializer_values: + # Make sure the initializer has its shape/type set + assert initializer.const_value is not None + if initializer.shape is None: + initializer.shape = initializer.const_value.shape # type: ignore[assignment] + if initializer.dtype is None: + initializer.dtype = initializer.const_value.dtype + if initializer not in model.graph.inputs: + model.graph.inputs.append(initializer) + if initializer.const_value.nbytes > _BIG_TENSOR_SIZE_LIMIT: + # Temporarily remove the initializer value to reduce model size + # for onnx.shape_inference + initializer.const_value = None + assert initializer.name is not None + model.graph.initializers.pop(initializer.name) + + proto = ir.serde.serialize_model(model) + + try: + # Call the ONNX C API function + result = func(proto) + finally: + # Restore the original initializer values so the model is unchanged + for initializer in initializer_values: + initializer.const_value = tensors[initializer.name] + model.graph.register_initializer(initializer) + + # Restore the original inputs + inputs = model.graph.inputs[:original_inputs_len] + model.graph.inputs.clear() + model.graph.inputs.extend(inputs) + + return result diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py index 7505770fb5..b321003b83 100644 --- a/onnxscript/optimizer/_constant_folding.py +++ b/onnxscript/optimizer/_constant_folding.py @@ -16,7 +16,7 @@ import onnx.reference.ops import onnxscript.ir as ir -import onnxscript.ir._tape as _tape +from onnxscript.ir import _tape import onnxscript.utils.utils as utils DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT = 1024 diff --git a/onnxscript/rewriter/__init__.py b/onnxscript/rewriter/__init__.py index 5efaf784b0..c2613acf55 100644 --- a/onnxscript/rewriter/__init__.py +++ b/onnxscript/rewriter/__init__.py @@ -13,7 +13,7 @@ import onnx from onnxscript import ir -from onnxscript.ir.passes.common import unused_removal +import onnxscript.ir.passes.common as common_passes from onnxscript.rewriter import ( broadcast_to_matmul, cast_constant_of_shape, @@ -90,9 +90,9 @@ def rewrite( rewrite_pass = ir.passes.PassManager( ( RewritePass(pattern_rewrite_rules), - unused_removal.RemoveUnusedNodesPass(), - unused_removal.RemoveUnusedFunctionsPass(), - unused_removal.RemoveUnusedOpsetsPass(), + common_passes.RemoveUnusedNodesPass(), + common_passes.RemoveUnusedFunctionsPass(), + common_passes.RemoveUnusedOpsetsPass(), ) ) model_ir = rewrite_pass(model_ir).model diff --git a/onnxscript/rewriter/_fusion_utils.py b/onnxscript/rewriter/_fusion_utils.py index 59bdf87bd0..c8051f8199 100644 --- a/onnxscript/rewriter/_fusion_utils.py +++ b/onnxscript/rewriter/_fusion_utils.py @@ -5,7 +5,7 @@ from typing import Callable, Sequence, Union import onnxscript.ir as ir -from onnxscript.ir.passes.common import shape_inference +import onnxscript.ir.passes.common as common_passes from onnxscript.rewriter import pattern Dim = Union[int, ir.SymbolicDim] @@ -38,7 +38,7 @@ def apply_to( ) -> int: count = rules.apply_to_model(model) if apply_shape_inference: - shape_inference.infer_shapes(model) + common_passes.ShapeInferencePass()(model) if count == 0 and debug: tracer = pattern.MatchingTracer() rules.apply_to_model(model, tracer=tracer) diff --git a/onnxscript/rewriter/_rewrite_rule.py b/onnxscript/rewriter/_rewrite_rule.py index f22374b753..303ee6d3c7 100644 --- a/onnxscript/rewriter/_rewrite_rule.py +++ b/onnxscript/rewriter/_rewrite_rule.py @@ -18,7 +18,7 @@ import onnxscript.rewriter._matcher as _matcher import onnxscript.rewriter._pattern_ir as _pattern_ir from onnxscript import ir -from onnxscript.ir import _convenience, _tape +from onnxscript.ir import convenience, _tape T = TypeVar("T") @@ -525,7 +525,7 @@ def _apply_to_graph_or_function( ) f = ir.Function(domain, name, overload, graph=graph, attributes=()) model.functions[f.identifier()] = f - _convenience.replace_nodes_and_values( + convenience.replace_nodes_and_values( graph_or_function, node, delta.match.nodes if rule.remove_nodes else [], diff --git a/onnxscript/rewriter/ort_fusions/_core.py b/onnxscript/rewriter/ort_fusions/_core.py index 070f6313ab..b2eec760e0 100644 --- a/onnxscript/rewriter/ort_fusions/_core.py +++ b/onnxscript/rewriter/ort_fusions/_core.py @@ -3,7 +3,7 @@ from __future__ import annotations import onnxscript.ir as ir -from onnxscript.ir.passes.common import shape_inference +import onnxscript.ir.passes.common as common_passes from onnxscript.optimizer import optimize from onnxscript.rewriter import rewrite from onnxscript.rewriter.ort_fusions import ( @@ -49,7 +49,7 @@ def _pre_optimize(model: ir.Model) -> ir.Model: # TODO: Do we need this dependence on ONNX's partial-data-propagation? There are some # extra shape-propagation and partial-data-propagation rules in ONNX that are not yet # incorporated in our optimizer. - shape_inference.infer_shapes(model) + common_passes.ShapeInferencePass()(model) optimize(model) return model diff --git a/onnxscript/rewriter/ort_fusions/attention_test.py b/onnxscript/rewriter/ort_fusions/attention_test.py index aaedc3fc0a..b77f1c2a8e 100644 --- a/onnxscript/rewriter/ort_fusions/attention_test.py +++ b/onnxscript/rewriter/ort_fusions/attention_test.py @@ -14,7 +14,7 @@ import onnxscript.rewriter.ort_fusions._core as xformers from onnxscript import FLOAT, script from onnxscript import opset18 as op -from onnxscript.ir.passes.common import shape_inference +import onnxscript.ir.passes.common as common_passes from onnxscript.rewriter.ort_fusions._test_utils import ORT_VERSION, assert_allclose, ort_run from onnxscript.rewriter.ort_fusions.models._whisper_encoder import whisper_encoder_test @@ -141,7 +141,7 @@ def test_model_with_mha(self, name, with_past): """Test the model with or without past inputs.""" inputs = self.random_inputs(with_past=with_past) model = self.create_model(with_past=with_past) - model = shape_inference.infer_shapes(model) + model = common_passes.ShapeInferencePass()(model).model test_with_ort = packaging.version.Version("1.20") <= ORT_VERSION if test_with_ort: diff --git a/onnxscript/rewriter/ort_fusions/mha_test.py b/onnxscript/rewriter/ort_fusions/mha_test.py index 8f4ed9715e..5e68ebc4c3 100644 --- a/onnxscript/rewriter/ort_fusions/mha_test.py +++ b/onnxscript/rewriter/ort_fusions/mha_test.py @@ -8,7 +8,7 @@ import onnxscript.optimizer import onnxscript.rewriter.ort_fusions._core as xformers -from onnxscript.ir.passes.common import shape_inference +import onnxscript.ir.passes.common as common_passes from onnxscript.rewriter.ort_fusions._test_utils import ORT_VERSION, assert_allclose, ort_run from onnxscript.rewriter.ort_fusions.models._smollm_2 import smollm_test_2 from onnxscript.rewriter.ort_fusions.models._whisper_decoder import whisper_decoder_test @@ -58,7 +58,7 @@ def test_whisper_encoder(self): # Fuse SDPA and MHA sdpa_count = xformers.fuse_sdpa(model) self.assertGreater(sdpa_count, 0) - model = shape_inference.infer_shapes(model) + model = common_passes.ShapeInferencePass()(model).model mha_count = xformers.fuse_mha(model) self.assertGreater(mha_count, 0) onnxscript.optimizer.optimize(model) @@ -83,7 +83,7 @@ def test_whisper_decoder(self): # Fuse SDPA and MHA sdpa_count = xformers.fuse_sdpa(model) self.assertGreater(sdpa_count, 0) - model = shape_inference.infer_shapes(model) + model = common_passes.ShapeInferencePass()(model).model mha_count = xformers.fuse_mha(model) self.assertGreater(mha_count, 0) onnxscript.optimizer.optimize(model) diff --git a/tests/ir/public_api_test.py b/tests/ir/public_api_test.py deleted file mode 100644 index ac2655cf43..0000000000 --- a/tests/ir/public_api_test.py +++ /dev/null @@ -1,187 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -# Adapted from -# https://github.com/pytorch/pytorch/blob/b505e8647547f029d0f7df408ee5f2968f757f89/test/test_public_bindings.py#L523 -# Original code PyTorch license https://github.com/pytorch/pytorch/blob/main/LICENSE -# Modifications Copyright (c) Microsoft Corporation. All rights reserved. -from __future__ import annotations - -import importlib -import itertools -import os -import pathlib -import pkgutil -import unittest -from typing import Iterable - -import onnxscript.ir - -IR_NAMESPACE = "onnxscript.ir" - - -def _find_all_importables(pkg): - """Find all importables in the project. - Return them in order. - """ - return sorted( - set( - itertools.chain.from_iterable( - _discover_path_importables(pathlib.Path(p), pkg.__name__) for p in pkg.__path__ - ), - ), - ) - - -def _discover_path_importables(pkg_path: os.PathLike, pkg_name: str) -> Iterable[str]: - """Yield all importables under a given path and package. - This is like pkgutil.walk_packages, but does *not* skip over namespace - packages. Taken from https://stackoverflow.com/questions/41203765/init-py-required-for-pkgutil-walk-packages-in-python3 - """ - for dir_path, _, file_names in os.walk(pkg_path): - pkg_dir_path = pathlib.Path(dir_path) - - if pkg_dir_path.parts[-1] == "__pycache__": - continue - - if all(pathlib.Path(_).suffix != ".py" for _ in file_names): - continue - - rel_pt = pkg_dir_path.relative_to(pkg_path) - pkg_pref = ".".join((pkg_name, *rel_pt.parts)) - yield from ( - pkg_path - for _, pkg_path, _ in pkgutil.walk_packages( - (str(pkg_dir_path),), - prefix=f"{pkg_pref}.", - ) - ) - - -def _is_mod_public(modname: str) -> bool: - split_strs = modname.split(".") - return all(not (elem.startswith("_") or "_test" in elem) for elem in split_strs) - - -def _validate_module(modname: str, failure_list: list[str]) -> None: - mod = importlib.import_module(modname) - if not _is_mod_public(modname): - return - - # verifies that each public API has the correct module name and naming semantics - def check_one_element(elem, modname, mod, *, is_public, is_all): - obj = getattr(mod, elem) - elem_module = getattr(obj, "__module__", None) - # Only used for nice error message below - why_not_looks_public = "" - if elem_module is None: - why_not_looks_public = "because it does not have a `__module__` attribute" - elem_modname_starts_with_mod = ( - elem_module is not None - and elem_module.startswith(IR_NAMESPACE) - and "._" not in elem_module - ) - if not why_not_looks_public and not elem_modname_starts_with_mod: - why_not_looks_public = ( - f"because its `__module__` attribute (`{elem_module}`) is not within the " - f"onnxscript.ir library or does not start with the submodule where it is defined (`{modname}`)" - ) - # elem's name must NOT begin with an `_` and it's module name - # SHOULD start with it's current module since it's a public API - looks_public = not elem.startswith("_") and elem_modname_starts_with_mod - if not why_not_looks_public and not looks_public: - why_not_looks_public = f"because it starts with `_` (`{elem}`)" - - if is_public != looks_public: - if is_public: - why_is_public = ( - f"it is inside the module's (`{modname}`) `__all__`" - if is_all - else "it is an attribute that does not start with `_` on a module that " - "does not have `__all__` defined" - ) - fix_is_public = ( - f"remove it from the modules's (`{modname}`) `__all__`" - if is_all - else f"either define a `__all__` for `{modname}` or add a `_` at the beginning of the name" - ) - else: - assert is_all - why_is_public = f"it is not inside the module's (`{modname}`) `__all__`" - fix_is_public = f"add it from the modules's (`{modname}`) `__all__`" - - if looks_public: - why_looks_public = ( - "it does look public because it follows the rules from the doc above " - "(does not start with `_` and has a proper `__module__`)." - ) - fix_looks_public = "make its name start with `_`" - else: - why_looks_public = why_not_looks_public - if not elem_modname_starts_with_mod: - fix_looks_public = ( - "make sure the `__module__` is properly set and points to a submodule " - f"of `{modname}`" - ) - else: - fix_looks_public = "remove the `_` at the beginning of the name" - - failure_list.append(f"# {modname}.{elem}:") - is_public_str = "" if is_public else " NOT" - failure_list.append(f" - Is{is_public_str} public: {why_is_public}") - looks_public_str = "" if looks_public else " NOT" - failure_list.append(f" - Does{looks_public_str} look public: {why_looks_public}") - # Swap the str below to avoid having to create the NOT again - failure_list.append( - " - You can do either of these two things to fix this problem:" - ) - failure_list.append(f" - To make it{looks_public_str} public: {fix_is_public}") - failure_list.append( - f" - To make it{is_public_str} look public: {fix_looks_public}" - ) - - if hasattr(mod, "__all__"): - public_api = mod.__all__ - all_api = dir(mod) - for elem in all_api: - check_one_element(elem, modname, mod, is_public=elem in public_api, is_all=True) - else: - all_api = dir(mod) - for elem in all_api: - if not elem.startswith("_"): - check_one_element(elem, modname, mod, is_public=True, is_all=False) - - -class TestPublicApiNamespace(unittest.TestCase): - tested_modules = (IR_NAMESPACE, *(_find_all_importables(onnxscript.ir))) - - def test_correct_module_names(self): - """ - An API is considered public, if its `__module__` starts with `onnxscript.ir` - and there is no name in `__module__` or the object itself that starts with "_". - Each public package should either: - - (preferred) Define `__all__` and all callables and classes in there must have their - `__module__` start with the current submodule's path. Things not in `__all__` should - NOT have their `__module__` start with the current submodule. - - (for simple python-only modules) Not define `__all__` and all the elements in `dir(submod)` must have their - `__module__` that start with the current submodule. - """ - failure_list = [] - - for modname in self.tested_modules: - _validate_module(modname, failure_list) - - msg = ( - "Make sure that everything that is public is expected (in particular that the module " - "has a properly populated `__all__` attribute) and that everything that is supposed to be public " - "does look public (it does not start with `_` and has a `__module__` that is properly populated)." - ) - - msg += "\n\nFull list:\n" - msg += "\n".join(failure_list) - - # empty lists are considered false in python - self.assertTrue(not failure_list, msg) - - -if __name__ == "__main__": - unittest.main() From e924b518997818dfdfb4175b57790cf0409b8c61 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 21 May 2025 15:30:38 -0700 Subject: [PATCH 03/16] Docs --- docs/ir/getting_started.ipynb | 386 ----------------------------- docs/ir/index.md | 22 +- docs/ir/ir_api/core.md | 65 ----- docs/ir/ir_api/index.md | 13 - docs/ir/ir_api/ir_convenience.md | 15 -- docs/ir/ir_api/ir_external_data.md | 20 -- docs/ir/ir_api/ir_passes.md | 39 --- docs/ir/ir_api/ir_passes_common.md | 12 - docs/ir/ir_api/ir_tape.md | 18 -- docs/ir/ir_api/ir_traversal.md | 13 - docs/ir/tensors.md | 330 ------------------------ 11 files changed, 2 insertions(+), 931 deletions(-) delete mode 100644 docs/ir/getting_started.ipynb delete mode 100644 docs/ir/ir_api/core.md delete mode 100644 docs/ir/ir_api/index.md delete mode 100644 docs/ir/ir_api/ir_convenience.md delete mode 100644 docs/ir/ir_api/ir_external_data.md delete mode 100644 docs/ir/ir_api/ir_passes.md delete mode 100644 docs/ir/ir_api/ir_passes_common.md delete mode 100644 docs/ir/ir_api/ir_tape.md delete mode 100644 docs/ir/ir_api/ir_traversal.md delete mode 100644 docs/ir/tensors.md diff --git a/docs/ir/getting_started.ipynb b/docs/ir/getting_started.ipynb deleted file mode 100644 index 68e1faaa74..0000000000 --- a/docs/ir/getting_started.ipynb +++ /dev/null @@ -1,386 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "da6e9cca-6893-4273-a558-3dc18d49615e", - "metadata": {}, - "source": [ - "# Getting started with ONNX IR 🌱\n", - "The ONNX IR ships with the ONNX Script package and is available as `onnxscript.ir`.\n", - "To create an IR object from ONNX file, load it as `ModelProto` and call\n", - "`ir.from_proto()`:" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "# Define an example model for this example\n", - "MODEL_TEXT = r\"\"\"\n", - "<\n", - " ir_version: 8,\n", - " opset_import: [\"\" : 18],\n", - " producer_name: \"pytorch\",\n", - " producer_version: \"2.0.0\"\n", - ">\n", - "torch_jit (float[5,5,5] input_0) => (float[5,5] val_19, float[5,5] val_6) {\n", - " val_1 = Constant ()\n", - " val_2 = Shape (val_1)\n", - " val_3 = Size (val_2)\n", - " val_4 = Constant ()\n", - " val_5 = Equal (val_3, val_4)\n", - " val_6 = ReduceMean (input_0, val_1)\n", - " val_7 = ReduceMean (input_0, val_1)\n", - " val_8 = Shape (input_0)\n", - " val_9 = Gather (val_8, val_1)\n", - " val_10 = ReduceProd (val_9)\n", - " val_11 = Sub (input_0, val_7)\n", - " val_12 = Mul (val_11, val_11)\n", - " val_13 = ReduceMean (val_12, val_1)\n", - " val_14 = Cast (val_10)\n", - " val_15 = Mul (val_13, val_14)\n", - " val_16 = Constant ()\n", - " val_17 = Sub (val_10, val_16)\n", - " val_18 = Cast (val_17)\n", - " val_19 = Div (val_15, val_18)\n", - "}\n", - "\"\"\"" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "cb5e7520-1aba-491b-b3e9-7d013e42d4ff", - "metadata": {}, - "outputs": [], - "source": [ - "import onnx\n", - "\n", - "from onnxscript import ir\n", - "\n", - "# Load the model as onnx.ModelProto\n", - "# You can also load the model from a file using onnx.load(\"model.onnx\")\n", - "model_proto = onnx.parser.parse_model(MODEL_TEXT)\n", - "\n", - "# Create an IR object from the model\n", - "model = ir.from_proto(model_proto)" - ] - }, - { - "cell_type": "markdown", - "id": "8f02f283-93c3-4e8f-b8f4-275f360ace61", - "metadata": {}, - "source": [ - "Now we can explore the IR object" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "969233d0-5e7a-4554-b4bc-ea06f448dd98", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "The main graph has 19 nodes.\n" - ] - } - ], - "source": [ - "print(f\"The main graph has {len(model.graph)} nodes.\")" - ] - }, - { - "cell_type": "markdown", - "id": "0422514a-72d3-40a0-9734-c58911ddefc9", - "metadata": {}, - "source": [ - "All inputs" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "7b5689d8-dd2e-468f-9a87-653e97be7cf9", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[Input('input_0', type=Tensor(FLOAT), shape=[5,5,5], producer=None, index=None)]\n" - ] - } - ], - "source": [ - "print(model.graph.inputs)" - ] - }, - { - "cell_type": "markdown", - "id": "d299db39-08f9-4646-856d-74e9cb18ee8a", - "metadata": {}, - "source": [ - "All outputs" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "e3fb01aa-2ca5-4839-80c4-2c2d1b916a1c", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[Value('val_19', type=Tensor(FLOAT), shape=[5,5], producer=, index=0), Value('val_6', type=Tensor(FLOAT), shape=[5,5], producer=, index=0)]\n" - ] - } - ], - "source": [ - "print(model.graph.outputs)" - ] - }, - { - "cell_type": "markdown", - "id": "1c52c8a2-52b4-40f3-996a-d44488e62623", - "metadata": {}, - "source": [ - "Nodes that uses the first input" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "c4894e97-7a8f-4f61-86dd-dd44aced02ed", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[(Node(name='', domain='', op_type='ReduceMean', inputs=(Input('input_0', type=Tensor(FLOAT), shape=[5,5,5], producer=None, index=None), Value('val_1', type=None, shape=None, producer=, index=0)), attributes=OrderedDict([('keepdims', AttrInt64('keepdims', 0)), ('noop_with_empty_axes', AttrInt64('noop_with_empty_axes', 0))]), overload='', outputs=(Value('val_6', type=Tensor(FLOAT), shape=[5,5], producer=, index=0),), version=None, doc_string=None), 0), (Node(name='', domain='', op_type='ReduceMean', inputs=(Input('input_0', type=Tensor(FLOAT), shape=[5,5,5], producer=None, index=None), Value('val_1', type=None, shape=None, producer=, index=0)), attributes=OrderedDict([('keepdims', AttrInt64('keepdims', 1)), ('noop_with_empty_axes', AttrInt64('noop_with_empty_axes', 0))]), overload='', outputs=(Value('val_7', type=None, shape=None, producer=, index=0),), version=None, doc_string=None), 0), (Node(name='', domain='', op_type='Shape', inputs=(Input('input_0', type=Tensor(FLOAT), shape=[5,5,5], producer=None, index=None),), attributes=OrderedDict([('start', AttrInt64('start', 0))]), overload='', outputs=(Value('val_8', type=None, shape=None, producer=, index=0),), version=None, doc_string=None), 0), (Node(name='', domain='', op_type='Sub', inputs=(Input('input_0', type=Tensor(FLOAT), shape=[5,5,5], producer=None, index=None), Value('val_7', type=None, shape=None, producer=, index=0)), attributes=OrderedDict(), overload='', outputs=(Value('val_11', type=None, shape=None, producer=, index=0),), version=None, doc_string=None), 0)]\n" - ] - } - ], - "source": [ - "print(list(model.graph.inputs[0].uses()))" - ] - }, - { - "cell_type": "markdown", - "id": "36d935b0-1910-4e7b-a2d8-57f6fa129670", - "metadata": {}, - "source": [ - "The node that produces the last output (as the i-th output)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "ac16cc49-9c82-4d5e-9c77-f0fd6260929b", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "%\"val_6\" ⬅️ ::ReduceMean(%\"input_0\", %\"val_1\") {keepdims=0, noop_with_empty_axes=0}\n", - "0\n" - ] - } - ], - "source": [ - "print(model.graph.outputs[-1].producer())\n", - "print(model.graph.outputs[-1].index())" - ] - }, - { - "cell_type": "markdown", - "id": "d70a097f-da71-4299-bbc4-63ad3cc7be67", - "metadata": {}, - "source": [ - "Print the graph" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "772e831d-8d9d-4446-81ed-e119e8f2c0d6", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
graph(\n",
-       "    name=torch_jit,\n",
-       "    inputs=(\n",
-       "        %\"input_0\"<FLOAT,[5,5,5]>\n",
-       "    ),\n",
-       "    outputs=(\n",
-       "        %\"val_19\"<FLOAT,[5,5]>,\n",
-       "        %\"val_6\"<FLOAT,[5,5]>\n",
-       "    ),\n",
-       ") {\n",
-       "     0 |  # :anonymous_node:128897555281104\n",
-       "          %\"val_1\"<?,?> ⬅️ ::Constant() {value_int=[1]}\n",
-       "     1 |  # :anonymous_node:128897554321872\n",
-       "          %\"val_2\"<?,?> ⬅️ ::Shape(%\"val_1\") {start=0}\n",
-       "     2 |  # :anonymous_node:128895578494032\n",
-       "          %\"val_3\"<?,?> ⬅️ ::Size(%\"val_2\")\n",
-       "     3 |  # :anonymous_node:128895578494352\n",
-       "          %\"val_4\"<?,?> ⬅️ ::Constant() {value=TensorProtoTensor<INT64,[]>(name='')}\n",
-       "     4 |  # :anonymous_node:128895578494512\n",
-       "          %\"val_5\"<?,?> ⬅️ ::Equal(%\"val_3\", %\"val_4\")\n",
-       "     5 |  # :anonymous_node:128895578494992\n",
-       "          %\"val_6\"<FLOAT,[5,5]> ⬅️ ::ReduceMean(%\"input_0\", %\"val_1\") {keepdims=0, noop_with_empty_axes=0}\n",
-       "     6 |  # :anonymous_node:128895578495312\n",
-       "          %\"val_7\"<?,?> ⬅️ ::ReduceMean(%\"input_0\", %\"val_1\") {keepdims=1, noop_with_empty_axes=0}\n",
-       "     7 |  # :anonymous_node:128895578495472\n",
-       "          %\"val_8\"<?,?> ⬅️ ::Shape(%\"input_0\") {start=0}\n",
-       "     8 |  # :anonymous_node:128895578495632\n",
-       "          %\"val_9\"<?,?> ⬅️ ::Gather(%\"val_8\", %\"val_1\") {axis=0}\n",
-       "     9 |  # :anonymous_node:128895578495952\n",
-       "          %\"val_10\"<?,?> ⬅️ ::ReduceProd(%\"val_9\") {keepdims=0, noop_with_empty_axes=0}\n",
-       "    10 |  # :anonymous_node:128895578496272\n",
-       "          %\"val_11\"<?,?> ⬅️ ::Sub(%\"input_0\", %\"val_7\")\n",
-       "    11 |  # :anonymous_node:128895578496592\n",
-       "          %\"val_12\"<?,?> ⬅️ ::Mul(%\"val_11\", %\"val_11\")\n",
-       "    12 |  # :anonymous_node:128895578497072\n",
-       "          %\"val_13\"<?,?> ⬅️ ::ReduceMean(%\"val_12\", %\"val_1\") {keepdims=0, noop_with_empty_axes=0}\n",
-       "    13 |  # :anonymous_node:128895578497712\n",
-       "          %\"val_14\"<?,?> ⬅️ ::Cast(%\"val_10\") {to=1}\n",
-       "    14 |  # :anonymous_node:128895578498192\n",
-       "          %\"val_15\"<?,?> ⬅️ ::Mul(%\"val_13\", %\"val_14\")\n",
-       "    15 |  # :anonymous_node:128895578498672\n",
-       "          %\"val_16\"<?,?> ⬅️ ::Constant() {value=TensorProtoTensor<INT64,[]>(name='')}\n",
-       "    16 |  # :anonymous_node:128895578498832\n",
-       "          %\"val_17\"<?,?> ⬅️ ::Sub(%\"val_10\", %\"val_16\")\n",
-       "    17 |  # :anonymous_node:128895578499152\n",
-       "          %\"val_18\"<?,?> ⬅️ ::Cast(%\"val_17\") {to=1}\n",
-       "    18 |  # :anonymous_node:128895578499632\n",
-       "          %\"val_19\"<FLOAT,[5,5]> ⬅️ ::Div(%\"val_15\", %\"val_18\")\n",
-       "    return %\"val_19\"<FLOAT,[5,5]>, %\"val_6\"<FLOAT,[5,5]>\n",
-       "}\n",
-       "
\n" - ], - "text/plain": [ - "\u001b[1;35mgraph\u001b[0m\u001b[1m(\u001b[0m\n", - " \u001b[33mname\u001b[0m=\u001b[35mtorch_jit\u001b[0m,\n", - " \u001b[33minputs\u001b[0m=\u001b[1m(\u001b[0m\n", - " %\u001b[32m\"input_0\"\u001b[0m\u001b[1m<\u001b[0m\u001b[1;95mFLOAT\u001b[0m\u001b[39m,\u001b[0m\u001b[1;39m[\u001b[0m\u001b[1;36m5\u001b[0m\u001b[39m,\u001b[0m\u001b[1;36m5\u001b[0m\u001b[39m,\u001b[0m\u001b[1;36m5\u001b[0m\u001b[1;39m]\u001b[0m\u001b[39m>\u001b[0m\n", - "\u001b[39m \u001b[0m\u001b[1;39m)\u001b[0m\u001b[39m,\u001b[0m\n", - "\u001b[39m \u001b[0m\u001b[33moutputs\u001b[0m\u001b[39m=\u001b[0m\u001b[1;39m(\u001b[0m\n", - "\u001b[39m %\u001b[0m\u001b[32m\"val_19\"\u001b[0m\u001b[39m,\u001b[0m\n", - "\u001b[39m %\u001b[0m\u001b[32m\"val_6\"\u001b[0m\u001b[39m\u001b[0m\n", - "\u001b[39m \u001b[0m\u001b[1;39m)\u001b[0m\u001b[39m,\u001b[0m\n", - "\u001b[1;39m)\u001b[0m\u001b[39m \u001b[0m\u001b[1;39m{\u001b[0m\n", - "\u001b[39m \u001b[0m\u001b[1;36m0\u001b[0m\u001b[39m | # :anonymous_no\u001b[0m\u001b[1;92mde:1288\u001b[0m\u001b[39m97555281104\u001b[0m\n", - "\u001b[39m %\u001b[0m\u001b[32m\"val_1\"\u001b[0m\u001b[39m ⬅️ ::\u001b[0m\u001b[1;35mConstant\u001b[0m\u001b[1;39m(\u001b[0m\u001b[1;39m)\u001b[0m\u001b[39m \u001b[0m\u001b[1;39m{\u001b[0m\u001b[33mvalue_int\u001b[0m\u001b[39m=\u001b[0m\u001b[1;39m[\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;39m]\u001b[0m\u001b[1;39m}\u001b[0m\n", - "\u001b[39m \u001b[0m\u001b[1;36m1\u001b[0m\u001b[39m | # :anonymous_no\u001b[0m\u001b[1;92mde:1288\u001b[0m\u001b[39m97554321872\u001b[0m\n", - "\u001b[39m %\u001b[0m\u001b[32m\"val_2\"\u001b[0m\u001b[39m ⬅️ ::\u001b[0m\u001b[1;35mShape\u001b[0m\u001b[1;39m(\u001b[0m\u001b[39m%\u001b[0m\u001b[32m\"val_1\"\u001b[0m\u001b[1;39m)\u001b[0m\u001b[39m \u001b[0m\u001b[1;39m{\u001b[0m\u001b[33mstart\u001b[0m\u001b[39m=\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;39m}\u001b[0m\n", - "\u001b[39m \u001b[0m\u001b[1;36m2\u001b[0m\u001b[39m | # :anonymous_no\u001b[0m\u001b[1;92mde:1288\u001b[0m\u001b[39m95578494032\u001b[0m\n", - "\u001b[39m %\u001b[0m\u001b[32m\"val_3\"\u001b[0m\u001b[39m ⬅️ ::\u001b[0m\u001b[1;35mSize\u001b[0m\u001b[1;39m(\u001b[0m\u001b[39m%\u001b[0m\u001b[32m\"val_2\"\u001b[0m\u001b[1;39m)\u001b[0m\n", - "\u001b[39m \u001b[0m\u001b[1;36m3\u001b[0m\u001b[39m | # :anonymous_no\u001b[0m\u001b[1;92mde:1288\u001b[0m\u001b[39m95578494352\u001b[0m\n", - "\u001b[39m %\u001b[0m\u001b[32m\"val_4\"\u001b[0m\u001b[39m ⬅️ ::\u001b[0m\u001b[1;35mConstant\u001b[0m\u001b[1;39m(\u001b[0m\u001b[1;39m)\u001b[0m\u001b[39m \u001b[0m\u001b[1;39m{\u001b[0m\u001b[33mvalue\u001b[0m\u001b[39m=\u001b[0m\u001b[35mTensorProtoTensor\u001b[0m\u001b[39m\u001b[0m\u001b[1;39m(\u001b[0m\u001b[33mname\u001b[0m\u001b[39m=\u001b[0m\u001b[32m''\u001b[0m\u001b[1;39m)\u001b[0m\u001b[1;39m}\u001b[0m\n", - "\u001b[39m \u001b[0m\u001b[1;36m4\u001b[0m\u001b[39m | # :anonymous_no\u001b[0m\u001b[1;92mde:1288\u001b[0m\u001b[39m95578494512\u001b[0m\n", - "\u001b[39m %\u001b[0m\u001b[32m\"val_5\"\u001b[0m\u001b[39m ⬅️ ::\u001b[0m\u001b[1;35mEqual\u001b[0m\u001b[1;39m(\u001b[0m\u001b[39m%\u001b[0m\u001b[32m\"val_3\"\u001b[0m\u001b[39m, %\u001b[0m\u001b[32m\"val_4\"\u001b[0m\u001b[1;39m)\u001b[0m\n", - "\u001b[39m \u001b[0m\u001b[1;36m5\u001b[0m\u001b[39m | # :anonymous_no\u001b[0m\u001b[1;92mde:1288\u001b[0m\u001b[39m95578494992\u001b[0m\n", - "\u001b[39m %\u001b[0m\u001b[32m\"val_6\"\u001b[0m\u001b[39m ⬅️ ::\u001b[0m\u001b[1;35mReduceMean\u001b[0m\u001b[1;39m(\u001b[0m\u001b[39m%\u001b[0m\u001b[32m\"input_0\"\u001b[0m\u001b[39m, %\u001b[0m\u001b[32m\"val_1\"\u001b[0m\u001b[1;39m)\u001b[0m\u001b[39m \u001b[0m\u001b[1;39m{\u001b[0m\u001b[33mkeepdims\u001b[0m\u001b[39m=\u001b[0m\u001b[1;36m0\u001b[0m\u001b[39m, \u001b[0m\u001b[33mnoop_with_empty_axes\u001b[0m\u001b[39m=\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;39m}\u001b[0m\n", - "\u001b[39m \u001b[0m\u001b[1;36m6\u001b[0m\u001b[39m | # :anonymous_no\u001b[0m\u001b[1;92mde:1288\u001b[0m\u001b[39m95578495312\u001b[0m\n", - "\u001b[39m %\u001b[0m\u001b[32m\"val_7\"\u001b[0m\u001b[39m ⬅️ ::\u001b[0m\u001b[1;35mReduceMean\u001b[0m\u001b[1;39m(\u001b[0m\u001b[39m%\u001b[0m\u001b[32m\"input_0\"\u001b[0m\u001b[39m, %\u001b[0m\u001b[32m\"val_1\"\u001b[0m\u001b[1;39m)\u001b[0m\u001b[39m \u001b[0m\u001b[1;39m{\u001b[0m\u001b[33mkeepdims\u001b[0m\u001b[39m=\u001b[0m\u001b[1;36m1\u001b[0m\u001b[39m, \u001b[0m\u001b[33mnoop_with_empty_axes\u001b[0m\u001b[39m=\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;39m}\u001b[0m\n", - "\u001b[39m \u001b[0m\u001b[1;36m7\u001b[0m\u001b[39m | # :anonymous_no\u001b[0m\u001b[1;92mde:1288\u001b[0m\u001b[39m95578495472\u001b[0m\n", - "\u001b[39m %\u001b[0m\u001b[32m\"val_8\"\u001b[0m\u001b[39m ⬅️ ::\u001b[0m\u001b[1;35mShape\u001b[0m\u001b[1;39m(\u001b[0m\u001b[39m%\u001b[0m\u001b[32m\"input_0\"\u001b[0m\u001b[1;39m)\u001b[0m\u001b[39m \u001b[0m\u001b[1;39m{\u001b[0m\u001b[33mstart\u001b[0m\u001b[39m=\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;39m}\u001b[0m\n", - "\u001b[39m \u001b[0m\u001b[1;36m8\u001b[0m\u001b[39m | # :anonymous_no\u001b[0m\u001b[1;92mde:1288\u001b[0m\u001b[39m95578495632\u001b[0m\n", - "\u001b[39m %\u001b[0m\u001b[32m\"val_9\"\u001b[0m\u001b[39m ⬅️ ::\u001b[0m\u001b[1;35mGather\u001b[0m\u001b[1;39m(\u001b[0m\u001b[39m%\u001b[0m\u001b[32m\"val_8\"\u001b[0m\u001b[39m, %\u001b[0m\u001b[32m\"val_1\"\u001b[0m\u001b[1;39m)\u001b[0m\u001b[39m \u001b[0m\u001b[1;39m{\u001b[0m\u001b[33maxis\u001b[0m\u001b[39m=\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;39m}\u001b[0m\n", - "\u001b[39m \u001b[0m\u001b[1;36m9\u001b[0m\u001b[39m | # :anonymous_no\u001b[0m\u001b[1;92mde:1288\u001b[0m\u001b[39m95578495952\u001b[0m\n", - "\u001b[39m %\u001b[0m\u001b[32m\"val_10\"\u001b[0m\u001b[39m ⬅️ ::\u001b[0m\u001b[1;35mReduceProd\u001b[0m\u001b[1;39m(\u001b[0m\u001b[39m%\u001b[0m\u001b[32m\"val_9\"\u001b[0m\u001b[1;39m)\u001b[0m\u001b[39m \u001b[0m\u001b[1;39m{\u001b[0m\u001b[33mkeepdims\u001b[0m\u001b[39m=\u001b[0m\u001b[1;36m0\u001b[0m\u001b[39m, \u001b[0m\u001b[33mnoop_with_empty_axes\u001b[0m\u001b[39m=\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;39m}\u001b[0m\n", - "\u001b[39m \u001b[0m\u001b[1;36m10\u001b[0m\u001b[39m | # :anonymous_no\u001b[0m\u001b[1;92mde:1288\u001b[0m\u001b[39m95578496272\u001b[0m\n", - "\u001b[39m %\u001b[0m\u001b[32m\"val_11\"\u001b[0m\u001b[39m ⬅️ ::\u001b[0m\u001b[1;35mSub\u001b[0m\u001b[1;39m(\u001b[0m\u001b[39m%\u001b[0m\u001b[32m\"input_0\"\u001b[0m\u001b[39m, %\u001b[0m\u001b[32m\"val_7\"\u001b[0m\u001b[1;39m)\u001b[0m\n", - "\u001b[39m \u001b[0m\u001b[1;36m11\u001b[0m\u001b[39m | # :anonymous_no\u001b[0m\u001b[1;92mde:1288\u001b[0m\u001b[39m95578496592\u001b[0m\n", - "\u001b[39m %\u001b[0m\u001b[32m\"val_12\"\u001b[0m\u001b[39m ⬅️ ::\u001b[0m\u001b[1;35mMul\u001b[0m\u001b[1;39m(\u001b[0m\u001b[39m%\u001b[0m\u001b[32m\"val_11\"\u001b[0m\u001b[39m, %\u001b[0m\u001b[32m\"val_11\"\u001b[0m\u001b[1;39m)\u001b[0m\n", - "\u001b[39m \u001b[0m\u001b[1;36m12\u001b[0m\u001b[39m | # :anonymous_no\u001b[0m\u001b[1;92mde:1288\u001b[0m\u001b[39m95578497072\u001b[0m\n", - "\u001b[39m %\u001b[0m\u001b[32m\"val_13\"\u001b[0m\u001b[39m ⬅️ ::\u001b[0m\u001b[1;35mReduceMean\u001b[0m\u001b[1;39m(\u001b[0m\u001b[39m%\u001b[0m\u001b[32m\"val_12\"\u001b[0m\u001b[39m, %\u001b[0m\u001b[32m\"val_1\"\u001b[0m\u001b[1;39m)\u001b[0m\u001b[39m \u001b[0m\u001b[1;39m{\u001b[0m\u001b[33mkeepdims\u001b[0m\u001b[39m=\u001b[0m\u001b[1;36m0\u001b[0m\u001b[39m, \u001b[0m\u001b[33mnoop_with_empty_axes\u001b[0m\u001b[39m=\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;39m}\u001b[0m\n", - "\u001b[39m \u001b[0m\u001b[1;36m13\u001b[0m\u001b[39m | # :anonymous_no\u001b[0m\u001b[1;92mde:1288\u001b[0m\u001b[39m95578497712\u001b[0m\n", - "\u001b[39m %\u001b[0m\u001b[32m\"val_14\"\u001b[0m\u001b[39m ⬅️ ::\u001b[0m\u001b[1;35mCast\u001b[0m\u001b[1;39m(\u001b[0m\u001b[39m%\u001b[0m\u001b[32m\"val_10\"\u001b[0m\u001b[1;39m)\u001b[0m\u001b[39m \u001b[0m\u001b[1;39m{\u001b[0m\u001b[33mto\u001b[0m\u001b[39m=\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;39m}\u001b[0m\n", - "\u001b[39m \u001b[0m\u001b[1;36m14\u001b[0m\u001b[39m | # :anonymous_no\u001b[0m\u001b[1;92mde:1288\u001b[0m\u001b[39m95578498192\u001b[0m\n", - "\u001b[39m %\u001b[0m\u001b[32m\"val_15\"\u001b[0m\u001b[39m ⬅️ ::\u001b[0m\u001b[1;35mMul\u001b[0m\u001b[1;39m(\u001b[0m\u001b[39m%\u001b[0m\u001b[32m\"val_13\"\u001b[0m\u001b[39m, %\u001b[0m\u001b[32m\"val_14\"\u001b[0m\u001b[1;39m)\u001b[0m\n", - "\u001b[39m \u001b[0m\u001b[1;36m15\u001b[0m\u001b[39m | # :anonymous_no\u001b[0m\u001b[1;92mde:1288\u001b[0m\u001b[39m95578498672\u001b[0m\n", - "\u001b[39m %\u001b[0m\u001b[32m\"val_16\"\u001b[0m\u001b[39m ⬅️ ::\u001b[0m\u001b[1;35mConstant\u001b[0m\u001b[1;39m(\u001b[0m\u001b[1;39m)\u001b[0m\u001b[39m \u001b[0m\u001b[1;39m{\u001b[0m\u001b[33mvalue\u001b[0m\u001b[39m=\u001b[0m\u001b[35mTensorProtoTensor\u001b[0m\u001b[39m\u001b[0m\u001b[1;39m(\u001b[0m\u001b[33mname\u001b[0m\u001b[39m=\u001b[0m\u001b[32m''\u001b[0m\u001b[1;39m)\u001b[0m\u001b[1;39m}\u001b[0m\n", - "\u001b[39m \u001b[0m\u001b[1;36m16\u001b[0m\u001b[39m | # :anonymous_no\u001b[0m\u001b[1;92mde:1288\u001b[0m\u001b[39m95578498832\u001b[0m\n", - "\u001b[39m %\u001b[0m\u001b[32m\"val_17\"\u001b[0m\u001b[39m ⬅️ ::\u001b[0m\u001b[1;35mSub\u001b[0m\u001b[1;39m(\u001b[0m\u001b[39m%\u001b[0m\u001b[32m\"val_10\"\u001b[0m\u001b[39m, %\u001b[0m\u001b[32m\"val_16\"\u001b[0m\u001b[1;39m)\u001b[0m\n", - "\u001b[39m \u001b[0m\u001b[1;36m17\u001b[0m\u001b[39m | # :anonymous_no\u001b[0m\u001b[1;92mde:1288\u001b[0m\u001b[39m95578499152\u001b[0m\n", - "\u001b[39m %\u001b[0m\u001b[32m\"val_18\"\u001b[0m\u001b[39m ⬅️ ::\u001b[0m\u001b[1;35mCast\u001b[0m\u001b[1;39m(\u001b[0m\u001b[39m%\u001b[0m\u001b[32m\"val_17\"\u001b[0m\u001b[1;39m)\u001b[0m\u001b[39m \u001b[0m\u001b[1;39m{\u001b[0m\u001b[33mto\u001b[0m\u001b[39m=\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;39m}\u001b[0m\n", - "\u001b[39m \u001b[0m\u001b[1;36m18\u001b[0m\u001b[39m | # :anonymous_no\u001b[0m\u001b[1;92mde:1288\u001b[0m\u001b[39m95578499632\u001b[0m\n", - "\u001b[39m %\u001b[0m\u001b[32m\"val_19\"\u001b[0m\u001b[39m ⬅️ ::\u001b[0m\u001b[1;35mDiv\u001b[0m\u001b[1;39m(\u001b[0m\u001b[39m%\u001b[0m\u001b[32m\"val_15\"\u001b[0m\u001b[39m, %\u001b[0m\u001b[32m\"val_18\"\u001b[0m\u001b[1;39m)\u001b[0m\n", - "\u001b[39m return %\u001b[0m\u001b[32m\"val_19\"\u001b[0m\u001b[39m, %\u001b[0m\u001b[32m\"val_6\"\u001b[0m\u001b[39m\u001b[0m\n", - "\u001b[1m}\u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "model.graph.display(\n", - " page=False\n", - ") # Set page=True to use a pager in the terminal so long outputs are scrollable" - ] - }, - { - "cell_type": "markdown", - "id": "cf19aa88-2063-4fee-9dd8-5fdca1dab398", - "metadata": {}, - "source": [ - "Convert from the IR object back to ModelProto" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "3b146b60-602a-4cb1-a5f8-d8d22c2a6a72", - "metadata": {}, - "outputs": [], - "source": [ - "model_proto_back = ir.to_proto(model)" - ] - }, - { - "cell_type": "markdown", - "id": "85a23c5b-81b8-4a73-96e0-c8553712d46f", - "metadata": {}, - "source": [ - "## Next steps\n", - "\n", - "Read the introductions for a more detailed introduction of the IR\n", - "(Documentation in progress 🚧)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "onnx", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.9" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/docs/ir/index.md b/docs/ir/index.md index 807dbddb51..ae6b0802b5 100644 --- a/docs/ir/index.md +++ b/docs/ir/index.md @@ -1,23 +1,5 @@ # ONNX IR -An in-memory IR that supports the full ONNX spec, designed for graph construction, analysis and transformation. +ONNX IR is now an official ONNX project! Documentation has been migrated to [onnx.ai/ir-py/](https://onnx.ai/ir-py/). -## Features ✨ - -- Full ONNX spec support: all valid models representable by ONNX protobuf, and a subset of invalid models (so you can load and fix them). -- Low memory footprint: mmap'ed external tensors; unified interface for ONNX TensorProto, Numpy arrays and PyTorch Tensors etc. No tensor size limitation. Zero copies. -- Straightforward access patterns: Access value information and traverse the graph topology at ease. -- Robust mutation: Create as many iterators as you like on the graph while mutating it. -- Speed: Performant graph manipulation, serialization/deserialization to Protobuf. -- Pythonic and familiar APIs: Classes define Pythonic apis and still map to ONNX protobuf concepts in an intuitive way. -- No protobuf dependency: The IR does not require protobuf once the model is converted to the IR representation, decoupling from the serialization format. - -## Get started - -```{toctree} -:maxdepth: 1 - -getting_started -tensors -ir_api/index -``` +You may continue to use `onnxscript.ir` unchanged for compatibility with older (<0.3) versions of ONNX Script. diff --git a/docs/ir/ir_api/core.md b/docs/ir/ir_api/core.md deleted file mode 100644 index ad11a9a751..0000000000 --- a/docs/ir/ir_api/core.md +++ /dev/null @@ -1,65 +0,0 @@ -# onnxscript.ir - -```{eval-rst} -.. automodule::onnxscript.ir -.. currentmodule:: onnxscript -``` - -## Functions and constructors - -```{eval-rst} -.. autosummary:: - :toctree: generated - :template: functiontemplate.rst - :nosignatures: - - ir.load - ir.save - ir.from_proto - ir.from_onnx_text - ir.to_proto - ir.tensor - ir.node -``` - -## Classes - -```{eval-rst} -.. autosummary:: - :toctree: generated - :template: classtemplate_inherited.rst - :nosignatures: - - ir.TensorProtocol - ir.Value - ir.Node - ir.Graph - ir.Model - ir.GraphView - ir.Function - ir.Attr - ir.RefAttr - ir.Shape - ir.SymbolicDim - ir.TypeAndShape - ir.TensorType - ir.SparseTensorType - ir.SequenceType - ir.OptionalType - ir.Tensor - ir.ExternalTensor - ir.StringTensor - ir.LazyTensor -``` - -## Enums - -```{eval-rst} -.. autosummary:: - :toctree: generated - :template: classtemplate.rst - :nosignatures: - - ir.DataType - ir.AttributeType -``` diff --git a/docs/ir/ir_api/index.md b/docs/ir/ir_api/index.md deleted file mode 100644 index c8ed762621..0000000000 --- a/docs/ir/ir_api/index.md +++ /dev/null @@ -1,13 +0,0 @@ -# IR APIs - -```{toctree} -:maxdepth: 1 - -core -ir_convenience -ir_external_data -ir_passes -ir_passes_common -ir_traversal -ir_tape -``` diff --git a/docs/ir/ir_api/ir_convenience.md b/docs/ir/ir_api/ir_convenience.md deleted file mode 100644 index 77f09bfe81..0000000000 --- a/docs/ir/ir_api/ir_convenience.md +++ /dev/null @@ -1,15 +0,0 @@ -# ir.convenience - -```{eval-rst} -.. automodule::onnxscript.ir.convenience -.. currentmodule:: onnxscript.ir.convenience -``` - - -```{eval-rst} -.. autofunction:: convert_attribute -.. autofunction:: convert_attributes -.. autofunction:: replace_all_uses_with -.. autofunction:: replace_nodes_and_values -.. autofunction:: create_value_mapping -``` diff --git a/docs/ir/ir_api/ir_external_data.md b/docs/ir/ir_api/ir_external_data.md deleted file mode 100644 index faf34514f1..0000000000 --- a/docs/ir/ir_api/ir_external_data.md +++ /dev/null @@ -1,20 +0,0 @@ -# ir.external_data - -```{eval-rst} -.. automodule::onnxscript.ir.external_data -.. currentmodule:: onnxscript.ir.external_data -``` - -The `ir.external_data` module provides utilities for handling external data in ONNX models. It enables the conversion of tensors to and from external data files, allowing for efficient storage and manipulation of large tensor data. This is particularly useful for models with large initializers that exceed memory constraints. - -## Functions - -```{eval-rst} -.. autofunction:: load_to_model -.. autofunction:: unload_from_model -.. autofunction:: convert_tensors_to_external -.. autofunction:: convert_tensors_from_external -.. autofunction:: set_base_dir -``` - - diff --git a/docs/ir/ir_api/ir_passes.md b/docs/ir/ir_api/ir_passes.md deleted file mode 100644 index ba759a0aee..0000000000 --- a/docs/ir/ir_api/ir_passes.md +++ /dev/null @@ -1,39 +0,0 @@ -# ir.passes - -```{eval-rst} -.. automodule::onnxscript.ir.passes -.. currentmodule:: onnxscript -``` - -## Use built-in passes - -Common, reusable passes are implemented in `ir.passes.common`. You can use {py:class}`ir.passes.Sequential ` to chain passes or use {py:class}`ir.passes.PassManager ` which supports early stopping if no changes are made. - -## Pass infrastructure - -Inherent {py:class}`ir.passes.InPlacePass ` or {py:class}`ir.passes.FunctionalPass ` to define a pass. You will need to implement the `call` method which returns a {py:class}`ir.passes.PassResult `. - -Alternatively, inherent the base class `ir.passes.PassBase ` and override the two properties `changes_input` and `in_place` to set properties of the pass. - -```{eval-rst} -.. autosummary:: - :toctree: generated - :template: classtemplate.rst - :nosignatures: - - ir.passes.PassBase - ir.passes.InPlacePass - ir.passes.FunctionalPass - ir.passes.Sequential - ir.passes.PassResult - ir.passes.PassManager -``` - -## Errors - -```{eval-rst} -.. autoexception:: onnxscript.ir.passes.InvariantError -.. autoexception:: onnxscript.ir.passes.PreconditionError -.. autoexception:: onnxscript.ir.passes.PostconditionError -.. autoexception:: onnxscript.ir.passes.PassError -``` diff --git a/docs/ir/ir_api/ir_passes_common.md b/docs/ir/ir_api/ir_passes_common.md deleted file mode 100644 index 37740160ce..0000000000 --- a/docs/ir/ir_api/ir_passes_common.md +++ /dev/null @@ -1,12 +0,0 @@ -# ir.passes.common - -Built-in passes provided by the ONNX IR - -```{eval-rst} -.. automodule:: onnxscript.ir.passes.common - :show-inheritance: - :members: - :undoc-members: - :exclude-members: call - -``` diff --git a/docs/ir/ir_api/ir_tape.md b/docs/ir/ir_api/ir_tape.md deleted file mode 100644 index bdfa83d673..0000000000 --- a/docs/ir/ir_api/ir_tape.md +++ /dev/null @@ -1,18 +0,0 @@ -# ir.tape - -```{eval-rst} -.. automodule:: onnxscript.ir.tape -.. currentmodule:: onnxscript.ir.tape -``` - -The `ir.tape` module provides utilities for recording nodes and initializers to construct computational graphs or functions. - -## The `Tape` class - -The `Tape` class is a recorder that collects nodes and initializers created during the construction of a graph or function. It supports creating nodes with single or multiple outputs and registering initializers. - -```{eval-rst} -.. autoclass:: Tape - :members: - :undoc-members: -``` diff --git a/docs/ir/ir_api/ir_traversal.md b/docs/ir/ir_api/ir_traversal.md deleted file mode 100644 index fcb1b6aac7..0000000000 --- a/docs/ir/ir_api/ir_traversal.md +++ /dev/null @@ -1,13 +0,0 @@ -# ir.traversal - -```{eval-rst} -.. automodule:: onnxscript.ir.traversal -.. currentmodule:: onnxscript.ir.traversal -``` - -```{eval-rst} -.. autoclass:: RecursiveGraphIterator - :members: - :undoc-members: - :special-members: -``` diff --git a/docs/ir/tensors.md b/docs/ir/tensors.md deleted file mode 100644 index 1f6c825a01..0000000000 --- a/docs/ir/tensors.md +++ /dev/null @@ -1,330 +0,0 @@ -# Tensor Representation in the IR - -The ONNX IR offers the {py:class}`ir.TensorProtocol ` interface for using different data structures as backing data for tensors. Besides the traditional {py:class}`onnx.TensorProto`, you can use {py:class}`np.ndarray`, {py:class}`torch.Tensor`, {py:class}`jax.Array`, and virtually anything else to represent tensors in the graph. This allows them to be accessed and serialized via the same `TensorProtocol` interface, without incurring additional copies during initialization. - -## The `TensorProtocol` - -{py:class}`ir.TensorProtocol ` defines a read-only interface for representing tensors. A tensor class implementing the interface has attributes like `name`, `shape`, `dtype`, `size`, `nbytes` and `metadata_props` to describe basic properties of the tensor. Additionally, it should implement two methods {py:meth}`numpy ` and {py:meth}`__array__ ` which will produce equivalent NumPy arrays from the backing data. - -:::{note} -When interacting with initializers, constant values and tensor attributes, it is best to assume `TensorProtocol` and only use `isinstance` to check for concrete classes when there is a need. -::: - -## Tensor Classes - -### ir.TensorProtoTensor - -We use the {py:class}`ir.TensorProtoTensor ` as a wrapper around the proto to implement the `ir.TensorProtocol` interface. You can access `shape`, `dtype` etc. as usual. A copy is incurred only when `numpy()` is called. - -:::{note} -Directly initializing an `ir.TensorProtoTensor`, as below, is possible. However, it is usually recommended to use `ir.serde.deserialize_tensor` because it handles all types of `TensorProto`s (`ir.TensorProtoTensor` doesn't handle external tensors, for example). Please refer to [From `TensorProto`s and back](#from-tensorprotos-and-back) for an example. -::: - -```{eval-rst} -.. exec_code:: - - import onnx - from onnxscript import ir - - tensor_proto = onnx.helper.make_tensor("tensor", onnx.TensorProto.INT16, (3,), [1, 2, 3]) - tensor = ir.TensorProtoTensor(tensor_proto) - print("tensor: ", tensor) # TensorProtoTensor(name='tensor') - print("shape: ", tensor.shape) # ir.Shape([3]) - print("dtype: ", tensor.dtype) # ir.DataType.INT16 - print(tensor.raw == tensor_proto) # The raw field is the exact tensor_proto provided at initialization - print("tobytes: ", tensor.tobytes()) # b'\x01\x00\x02\x00\x03\x00' - print("numpy: ", tensor.numpy()) # array([1, 2, 3], dtype=int16) -``` - -### ir.ExternalTensor - -Tensor data stored externally in the disk are typically large and will take up memory when loaded. The {py:class}`ir.ExternalTensor ` class uses memory mapping to avoid loading the tensor into memory. You are able to use the tensor as a normal NumPy array with minimal memory usage. - -Refer to {py:func}`ir.serde.deserialize_tensor ` to find an example on converting an `onnx.TensorProto` to an {py:class}`ir.ExternalTensor `. - -### ir.Tensor - -{py:class}`ir.Tensor ` is a wrapper around NumPy array compatible array objects like {py:class}`np.ndarray` and {py:class}`torch.Tensor`. It is best for creating in-memory tensors without converting it to a `TensorProto` to reduce the conversion overhead. - -:::{tip} -An array object is compatible if it defines the `__array__` method. -::: - -To create a tensor from an array, simply initialize it with an NumPy array - -```python -tensor = ir.Tensor(np.random.rand(1, 2)) -``` - -The initializer will obtain dtype and shape information from the array. - -To create a tensor from objects other than NumPy array, you need to specify the dtype: - -```{eval-rst} -.. exec_code:: - - import torch - from onnxscript import ir - - torch_tensor = torch.tensor([1, 2, 3], dtype=torch.float16) - tensor = ir.Tensor(torch_tensor, dtype=ir.DataType.FLOAT16) - print(tensor.numpy()) # array([1., 2., 3.], dtype=float16) -``` - -### String Tensor - -Use {py:class}`ir.StringTensor ` to create a string tensor. - - - -### Sparse Tensor - -Sparse tensors are not yet supported, but they are on our roadmap. - -## From `TensorProto`s and back - -In the following scenario, we show how to go from a `TensorProto` to an `ir.Tensor`, run some computation, then turn it back to an `ir.Tensor` and finally `TensorProto` - -```{eval-rst} -.. exec_code:: - - from onnxscript import ir - import onnx - import numpy as np - - # 1. Create the TensorProto - proto = onnx.helper.make_tensor( - "tensor", onnx.TensorProto.FLOAT16, [2, 3], [1, 2, 3, 4, 5, 6] - ) - - # 2. Create an IR Tensor from the Protobuf message - tensor = ir.serde.deserialize_tensor(proto) - # Note that we get a TensorProtoTensor that implements the TensorProtocol - print("tensor:", tensor) # TensorProtoTensor(name='tensor') - print("tensor.numpy():", tensor.numpy()) # [[1. 2. 3.] - # [4. 5. 6.]] - print("tensor.tobytes():", tensor.tobytes()) # b'\x00<\x00@\x00B\x00D\x00E\x00F' - - # 3. Do computation using numpy - mean = tensor.numpy().mean(axis=0) - print("mean:", mean) # array([2.5, 3.5, 4.5], dtype=float16) - - # 4. Create a Tensor from the ndarray. Note that we use ir.Tensor - tensor_mean = ir.Tensor(mean) - print("tensor_mean:", tensor_mean) # Tensor(array([2.5, 3.5, 4.5], dtype=float16), name='') - - # 5. Obtain the TensorProto from ir.Tensor - mean_tensor_proto: onnx.TensorProto = ir.serde.serialize_tensor(tensor_mean) - print("mean_tensor_proto:", mean_tensor_proto) - print( - "onnx.numpy_helper.to_array(mean_tensor_proto):", - onnx.numpy_helper.to_array(mean_tensor_proto) - # array([2.5, 3.5, 4.5], dtype=float16) - ) - - # You can obtain the bytes data as well - print("tensor_mean.tobytes():", tensor_mean.tobytes()) - print("Bytes same as proto:", mean_tensor_proto.raw_data == tensor_mean.tobytes()) - - # Explore other methods defined by TensorProtocol: - print("\n# Explore other methods defined by TensorProtocol:") - print("tensor_mean.shape:", tensor_mean.shape) - print("tensor_mean.dtype:", tensor_mean.dtype) - print("tensor_mean.name:", tensor_mean.name) - print("tensor_mean.doc_string:", tensor_mean.doc_string) - print("tensor_mean.raw:", tensor_mean.raw) - print("tensor_mean.metadata_props:", tensor_mean.metadata_props) - print("tensor_mean.size:", tensor_mean.size) - print("tensor_mean.nbytes:", tensor_mean.nbytes) - print("tensor_mean.raw:", tensor_mean.raw) -``` - -## Working with non-native NumPy dtypes: bfloat16, float8, int4 - -`ir.Tensor.numpy()` produces a NumPy array representation of the tensor's value. When the tensor has dtype `BFLOAT16`, `FLOAT8[...]` or `[U]INT4` which are not supported by NumPy, we use dtypes from the `ml_dtypes` package. - -`uint4`/`int4` is always unpacked; **`tobyte()` produces a packed representation** as expected. - -Initialization of `ir.Tensor` requires the NumPy array to follow the following typing constraints, or have a `ml_dtypes` dtype. - -- `int8` for (unpacked) int4, with the sign bit extended to 8 bits. -- `uint8` for (unpacked) uint4. -- `uint8` for 8-bit data types like float8. -- `uint16` for bfloat16. - -The following example shows how to create a `FLOAT8E4M3FN` tensor, transform its values, and create a new tensor to store the transformed values. - -```{eval-rst} -.. exec_code:: - - from onnxscript import ir - import numpy as np - - array = np.array([0b1, 0b11], dtype=np.uint8) - # The array is reinterpreted using the ml_dtypes package - tensor = ir.Tensor(array, dtype=ir.DataType.FLOAT8E4M3FN) - print(tensor) # Tensor(array([0.00195312, 0.00585938], dtype='float8_e4m3fn'), name=None) - print("tensor.numpy():", tensor.numpy()) # [0.00195312 0.00585938] - - # Compute - times_100 = tensor.numpy() * np.array(100, dtype=tensor.numpy().dtype) - print("times_100:", times_100) - - # Create a new tensor out of the new value; dtype must be specified - new_tensor = ir.Tensor(times_100.view(np.uint8), dtype=ir.DataType.FLOAT8E4M3FN) - # You can also directly create the tensor from the float8 array without specifying dtype - # new_tensor = ir.Tensor(times_100) - print("new_tensor:", new_tensor) # Tensor(array([0.1875, 0.5625], dtype='float8_e4m3fn'), name=None) - print("new_tensor == times_100", new_tensor.numpy() == times_100) # array([ True, True]) -``` - -## Advanced Usage - -### Subclass `ir.Tensor` for More Efficient Access and Broader `dtype` Support - -{py:class}`ir.Tensor` internally converts any array compatible objects into NumPy arrays to produce the byte representation in `tobytes()`. This can be inefficient due to the additional conversion. It also limits support for dtypes not supported by NumPy like bfloat16, because the `__array__` method would fail. - -To fully support arrays from other frameworks, it is usually a good idea to create specialized classes to handle them. The `TorchTensor` class below demonstrates how you can subclass `ir.Tensor` to handle PyTorch tensors: - -```{eval-rst} -.. exec_code:: - from __future__ import annotations - - import ctypes - - import numpy.typing as npt - import torch - - from onnxscript import ir - - - class TorchTensor(ir.Tensor): - def __init__( - self, tensor: torch.Tensor, name: str | None = None, doc_string: str | None = None - ): - # Pass the tensor as the raw data to ir.Tensor's constructor - - _TORCH_DTYPE_TO_ONNX: dict[torch.dtype, ir.DataType] = { - torch.bfloat16: ir.DataType.BFLOAT16, - torch.bool: ir.DataType.BOOL, - torch.complex128: ir.DataType.COMPLEX128, - torch.complex64: ir.DataType.COMPLEX64, - torch.float16: ir.DataType.FLOAT16, - torch.float32: ir.DataType.FLOAT, - torch.float64: ir.DataType.DOUBLE, - torch.float8_e4m3fn: ir.DataType.FLOAT8E4M3FN, - torch.float8_e4m3fnuz: ir.DataType.FLOAT8E4M3FNUZ, - torch.float8_e5m2: ir.DataType.FLOAT8E5M2, - torch.float8_e5m2fnuz: ir.DataType.FLOAT8E5M2FNUZ, - torch.int16: ir.DataType.INT16, - torch.int32: ir.DataType.INT32, - torch.int64: ir.DataType.INT64, - torch.int8: ir.DataType.INT8, - torch.uint8: ir.DataType.UINT8, - torch.uint16: ir.DataType.UINT16, - torch.uint32: ir.DataType.UINT32, - torch.uint64: ir.DataType.UINT64, - } - super().__init__( - tensor, dtype=_TORCH_DTYPE_TO_ONNX[tensor.dtype], name=name, doc_string=doc_string - ) - - def numpy(self) -> npt.NDArray: - self.raw: torch.Tensor - if self.dtype == ir.DataType.BFLOAT16: - return self.raw.view(torch.uint16).numpy(force=True).view(self.dtype.numpy()) - if self.dtype in { - ir.DataType.FLOAT8E4M3FN, - ir.DataType.FLOAT8E4M3FNUZ, - ir.DataType.FLOAT8E5M2, - ir.DataType.FLOAT8E5M2FNUZ, - }: - return self.raw.view(torch.uint8).numpy(force=True).view(self.dtype.numpy()) - - return self.raw.numpy(force=True) - - def __array__(self, dtype = None, copy: bool | None = None) -> npt.NDArray: - del copy # Unused, but needed for the signature - if dtype is None: - return self.numpy() - return self.numpy().__array__(dtype) - - def tobytes(self) -> bytes: - # Implement tobytes to support native PyTorch types so we can use types like bloat16 - # Reading from memory directly is also more efficient because - # it avoids copying to a NumPy array - import torch._subclasses.fake_tensor - - with torch._subclasses.fake_tensor.unset_fake_temporarily(): # pylint: disable=protected-access - # Disable any fake mode so calling detach() etc. will return a real tensor - tensor = self.raw.detach().cpu().contiguous() - - if isinstance(tensor, torch._subclasses.fake_tensor.FakeTensor): # pylint: disable=protected-access - raise TypeError( - f"Cannot take content out from the FakeTensor ('{self.name}'). Please replace the tensor " - "with a tensor backed by real data using ONNXProgram.apply_weights() " - "or save the model without initializers by setting include_initializers=False." - ) - - return bytes( - (ctypes.c_ubyte * tensor.element_size() * tensor.numel()).from_address( - tensor.data_ptr() - ) - ) - - # Test the implementation - torch_tensor = torch.tensor([1, 2, 3], dtype=torch.bfloat16) - tensor = TorchTensor(torch_tensor) - print("tensor: ", tensor) - print("numpy: ", tensor.numpy()) - print("tobytes: ", tensor.tobytes()) # b'\x80?\x00@@@' - print("nbytes: ", tensor.nbytes) # 6 -``` - -The `TorchTensor` class above implements `tobytes()` to produce the correct bytes representation for the tensor when it is serialized into an ONNX file / TensorProto. The class also implements the `__array__()` method to return the bit representation for types NumPy does not support. This way analysis passes can still perform computation on these values. - -### Computation with different Frameworks - -Since `ir.Tensor` implements the `__array__` method and `__dlpack__` methods, its content can be shared with computation frameworks without copying. For example: - -```{eval-rst} -.. exec_code:: - - from onnxscript import ir - - # We can call numpy methods directly on ir.Tensor - import numpy as np - print(np.multiply(ir.Tensor(np.array([1, 2])), 42)) # array([42., 84.]) - - # We can transfer arrays to different frameworks - import jax.numpy as jnp - import jax - import torch - - # Create ir.Tensor - jax_array = jnp.array([10., 20.]) - ir_tensor_jax = ir.Tensor(jax_array, dtype=ir.DataType.FLOAT) - torch_tensor = torch.tensor([30., 40.]) - ir_tensor_torch = ir.Tensor(torch_tensor, dtype=ir.DataType.FLOAT) - - # Use numpy for computation - print(np.multiply(ir_tensor_jax, ir_tensor_torch)) # array([300., 800.], dtype=float32) - - # Use jax for computation by calling from_dlpack to transfer the tensor data without copying when the device is the same - jax_array_from_ir = jax.dlpack.from_dlpack(ir_tensor_torch) - print(jax_array_from_ir + jax_array) # [40. 60.] - - # Use PyTorch for computation - torch_tensor_from_ir = torch.from_dlpack(ir_tensor_jax) - print(torch_tensor_from_ir - torch_tensor) # tensor([-20., -20.]) - - # They can all be serialized into TensorProto - proto = ir.serde.serialize_tensor(ir_tensor_jax) - print(type(proto)) # - print(proto) - - # The value is exactly the same as jax_array - print(ir.serde.deserialize_tensor(proto).numpy()) # [10. 20.] -``` - -This is particularly useful if you are creating passes on the graph that requires doing computation on concrete values. You are free to use your favorite frameworks to create the passes. The transformed graph that contains newly created `ir.Tensor`s will be compatible with downstream passes even if they leverage other computation frameworks. From 63b85332dc189fdcbc08e7c534ea38d0ae38ef4a Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 21 May 2025 15:34:03 -0700 Subject: [PATCH 04/16] Fix import --- onnxscript/rewriter/ort_fusions/_test_utils.py | 10 ---------- .../rewriter/ort_fusions/fuse_packed_qkv_gqa_test.py | 2 +- onnxscript/rewriter/ort_fusions/gqa_test.py | 2 +- 3 files changed, 2 insertions(+), 12 deletions(-) diff --git a/onnxscript/rewriter/ort_fusions/_test_utils.py b/onnxscript/rewriter/ort_fusions/_test_utils.py index 4181fffbf4..a626d77d83 100644 --- a/onnxscript/rewriter/ort_fusions/_test_utils.py +++ b/onnxscript/rewriter/ort_fusions/_test_utils.py @@ -3,20 +3,10 @@ from __future__ import annotations import numpy as np -import onnx import onnxruntime import packaging.version import onnxscript.ir as ir -import onnxscript.ir._io as io - - -def _save(model, modelpath): - if isinstance(model, onnx.ModelProto): - onnx.save(model, modelpath) - else: - assert isinstance(model, ir.Model) - io.save(model, modelpath) ORT_VERSION = packaging.version.Version(onnxruntime.__version__) diff --git a/onnxscript/rewriter/ort_fusions/fuse_packed_qkv_gqa_test.py b/onnxscript/rewriter/ort_fusions/fuse_packed_qkv_gqa_test.py index 9559ca1925..9e3cb54b2a 100644 --- a/onnxscript/rewriter/ort_fusions/fuse_packed_qkv_gqa_test.py +++ b/onnxscript/rewriter/ort_fusions/fuse_packed_qkv_gqa_test.py @@ -9,7 +9,7 @@ import onnxscript import onnxscript.ir as ir -import onnxscript.ir.passes.common.shape_inference as shape_inference +import onnx_ir.passes.common.shape_inference as shape_inference import onnxscript.optimizer from onnxscript import FLOAT, INT32, script from onnxscript import opset18 as op diff --git a/onnxscript/rewriter/ort_fusions/gqa_test.py b/onnxscript/rewriter/ort_fusions/gqa_test.py index 4f8f9ab8ba..4303c2d72d 100644 --- a/onnxscript/rewriter/ort_fusions/gqa_test.py +++ b/onnxscript/rewriter/ort_fusions/gqa_test.py @@ -12,7 +12,7 @@ import onnxscript import onnxscript.ir as ir -import onnxscript.ir.passes.common.shape_inference as shape_inference +import onnx_ir.passes.common.shape_inference as shape_inference import onnxscript.optimizer from onnxscript import FLOAT, script from onnxscript import opset18 as op From 7e7a139ad7dc988a5b5c3a70d1fc24cc441d2b45 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 21 May 2025 15:37:48 -0700 Subject: [PATCH 05/16] Fix tests and add dependency --- onnxscript/rewriter/ort_fusions/attention_test.py | 2 +- pyproject.toml | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/onnxscript/rewriter/ort_fusions/attention_test.py b/onnxscript/rewriter/ort_fusions/attention_test.py index b77f1c2a8e..ca991060db 100644 --- a/onnxscript/rewriter/ort_fusions/attention_test.py +++ b/onnxscript/rewriter/ort_fusions/attention_test.py @@ -172,7 +172,7 @@ def test_whisper_encoder(self): # Fuse SDPA and MHA sdpa_count = xformers.fuse_sdpa(model) self.assertGreater(sdpa_count, 0) - model = shape_inference.infer_shapes(model) + model = common_passes.ShapeInferencePass()(model).model mha_count = xformers.fuse_mha(model) self.assertGreater(mha_count, 0) fused_mha_bias_count = xformers.fuse_mha_bias(model) diff --git a/pyproject.toml b/pyproject.toml index 361ba40aa6..46d89ed768 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,9 @@ classifiers = [ "Programming Language :: Python :: 3.13", "License :: OSI Approved :: MIT License", ] -dependencies = ["numpy", "onnx>=1.16", "typing_extensions>=4.10", "ml_dtypes", "packaging"] +dependencies = [ + "numpy", "onnx>=1.16", "onnx_ir>=0.1", "typing_extensions>=4.10", "ml_dtypes", "packaging" +] [tool.setuptools.packages.find] include = ["onnxscript*"] From 5390f8f49ff564731fbcdfdd5671dd5e46de64ca Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 21 May 2025 15:38:54 -0700 Subject: [PATCH 06/16] Lint --- onnxscript/ir/_tape.py | 8 ++++---- onnxscript/ir/convenience.py | 2 ++ onnxscript/optimizer/_constant_folding.py | 2 +- onnxscript/rewriter/__init__.py | 2 +- onnxscript/rewriter/_rewrite_rule.py | 2 +- onnxscript/rewriter/ort_fusions/_test_utils.py | 1 - onnxscript/rewriter/ort_fusions/attention_test.py | 2 +- .../rewriter/ort_fusions/fuse_packed_qkv_gqa_test.py | 2 +- onnxscript/rewriter/ort_fusions/gqa_test.py | 2 +- onnxscript/rewriter/ort_fusions/mha_test.py | 2 +- 10 files changed, 13 insertions(+), 12 deletions(-) diff --git a/onnxscript/ir/_tape.py b/onnxscript/ir/_tape.py index 8626983199..ac9507436a 100644 --- a/onnxscript/ir/_tape.py +++ b/onnxscript/ir/_tape.py @@ -4,13 +4,13 @@ from __future__ import annotations -from typing import ( - Any, - Sequence, -) +from typing import TYPE_CHECKING, Any, Sequence from onnx_ir import tape +if TYPE_CHECKING: + import onnx_ir as ir + class Builder(tape.Tape): """An extension of the tape that provides a more convenient API for constructing the IR.""" diff --git a/onnxscript/ir/convenience.py b/onnxscript/ir/convenience.py index a0fff386b1..06cebc676d 100644 --- a/onnxscript/ir/convenience.py +++ b/onnxscript/ir/convenience.py @@ -1 +1,3 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. from onnx_ir.convenience import * # type: ignore # noqa: F403 diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py index b321003b83..ce7a2a85fd 100644 --- a/onnxscript/optimizer/_constant_folding.py +++ b/onnxscript/optimizer/_constant_folding.py @@ -16,8 +16,8 @@ import onnx.reference.ops import onnxscript.ir as ir -from onnxscript.ir import _tape import onnxscript.utils.utils as utils +from onnxscript.ir import _tape DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT = 1024 diff --git a/onnxscript/rewriter/__init__.py b/onnxscript/rewriter/__init__.py index c2613acf55..31f3379df5 100644 --- a/onnxscript/rewriter/__init__.py +++ b/onnxscript/rewriter/__init__.py @@ -12,8 +12,8 @@ import onnx -from onnxscript import ir import onnxscript.ir.passes.common as common_passes +from onnxscript import ir from onnxscript.rewriter import ( broadcast_to_matmul, cast_constant_of_shape, diff --git a/onnxscript/rewriter/_rewrite_rule.py b/onnxscript/rewriter/_rewrite_rule.py index 303ee6d3c7..508d1bacdd 100644 --- a/onnxscript/rewriter/_rewrite_rule.py +++ b/onnxscript/rewriter/_rewrite_rule.py @@ -18,7 +18,7 @@ import onnxscript.rewriter._matcher as _matcher import onnxscript.rewriter._pattern_ir as _pattern_ir from onnxscript import ir -from onnxscript.ir import convenience, _tape +from onnxscript.ir import _tape, convenience T = TypeVar("T") diff --git a/onnxscript/rewriter/ort_fusions/_test_utils.py b/onnxscript/rewriter/ort_fusions/_test_utils.py index a626d77d83..24a68445b7 100644 --- a/onnxscript/rewriter/ort_fusions/_test_utils.py +++ b/onnxscript/rewriter/ort_fusions/_test_utils.py @@ -8,7 +8,6 @@ import onnxscript.ir as ir - ORT_VERSION = packaging.version.Version(onnxruntime.__version__) diff --git a/onnxscript/rewriter/ort_fusions/attention_test.py b/onnxscript/rewriter/ort_fusions/attention_test.py index ca991060db..bbc4e828c4 100644 --- a/onnxscript/rewriter/ort_fusions/attention_test.py +++ b/onnxscript/rewriter/ort_fusions/attention_test.py @@ -10,11 +10,11 @@ import onnxscript import onnxscript.ir as ir +import onnxscript.ir.passes.common as common_passes import onnxscript.optimizer import onnxscript.rewriter.ort_fusions._core as xformers from onnxscript import FLOAT, script from onnxscript import opset18 as op -import onnxscript.ir.passes.common as common_passes from onnxscript.rewriter.ort_fusions._test_utils import ORT_VERSION, assert_allclose, ort_run from onnxscript.rewriter.ort_fusions.models._whisper_encoder import whisper_encoder_test diff --git a/onnxscript/rewriter/ort_fusions/fuse_packed_qkv_gqa_test.py b/onnxscript/rewriter/ort_fusions/fuse_packed_qkv_gqa_test.py index 9e3cb54b2a..12489ab531 100644 --- a/onnxscript/rewriter/ort_fusions/fuse_packed_qkv_gqa_test.py +++ b/onnxscript/rewriter/ort_fusions/fuse_packed_qkv_gqa_test.py @@ -5,11 +5,11 @@ import unittest import numpy as np +import onnx_ir.passes.common.shape_inference as shape_inference import onnxruntime as ort import onnxscript import onnxscript.ir as ir -import onnx_ir.passes.common.shape_inference as shape_inference import onnxscript.optimizer from onnxscript import FLOAT, INT32, script from onnxscript import opset18 as op diff --git a/onnxscript/rewriter/ort_fusions/gqa_test.py b/onnxscript/rewriter/ort_fusions/gqa_test.py index 4303c2d72d..a918616161 100644 --- a/onnxscript/rewriter/ort_fusions/gqa_test.py +++ b/onnxscript/rewriter/ort_fusions/gqa_test.py @@ -7,12 +7,12 @@ import numpy as np import onnx +import onnx_ir.passes.common.shape_inference as shape_inference import onnxruntime as ort import torch import onnxscript import onnxscript.ir as ir -import onnx_ir.passes.common.shape_inference as shape_inference import onnxscript.optimizer from onnxscript import FLOAT, script from onnxscript import opset18 as op diff --git a/onnxscript/rewriter/ort_fusions/mha_test.py b/onnxscript/rewriter/ort_fusions/mha_test.py index 5e68ebc4c3..540e54cde2 100644 --- a/onnxscript/rewriter/ort_fusions/mha_test.py +++ b/onnxscript/rewriter/ort_fusions/mha_test.py @@ -6,9 +6,9 @@ import packaging.version +import onnxscript.ir.passes.common as common_passes import onnxscript.optimizer import onnxscript.rewriter.ort_fusions._core as xformers -import onnxscript.ir.passes.common as common_passes from onnxscript.rewriter.ort_fusions._test_utils import ORT_VERSION, assert_allclose, ort_run from onnxscript.rewriter.ort_fusions.models._smollm_2 import smollm_test_2 from onnxscript.rewriter.ort_fusions.models._whisper_decoder import whisper_decoder_test From a8e683034278f33ce0960e88fa229a2a36b6a93f Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 21 May 2025 15:45:26 -0700 Subject: [PATCH 07/16] test --- tests/function_libs/torch_lib/ops_test_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/function_libs/torch_lib/ops_test_common.py b/tests/function_libs/torch_lib/ops_test_common.py index 8de86e3551..a8889cad6c 100644 --- a/tests/function_libs/torch_lib/ops_test_common.py +++ b/tests/function_libs/torch_lib/ops_test_common.py @@ -420,7 +420,7 @@ def add_torchlib_common_imports(model: ir.Model) -> None: is_scalar_func = ir.serde.deserialize_function(common_ops.IsScalar.to_function_proto()) model.functions[rank_func.identifier()] = rank_func model.functions[is_scalar_func.identifier()] = is_scalar_func - removal_pass = onnxscript.ir.passes.common.unused_removal.RemoveUnusedFunctionsPass() + removal_pass = onnxscript.ir.passes.common.RemoveUnusedFunctionsPass() assert removal_pass.in_place removal_pass(model) From 9beb308de057ab84eb375b9b54cea898ca31685b Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 28 May 2025 10:08:05 -0700 Subject: [PATCH 08/16] deps --- pyproject.toml | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 46d89ed768..00bd5838ca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,12 @@ classifiers = [ "License :: OSI Approved :: MIT License", ] dependencies = [ - "numpy", "onnx>=1.16", "onnx_ir>=0.1", "typing_extensions>=4.10", "ml_dtypes", "packaging" + "ml_dtypes", + "numpy", + "onnx_ir>=0.1,<2", # Expect onnx_ir to have a breaking change in 2.0. If not, extend this range. + "onnx>=1.16", + "packaging", + "typing_extensions>=4.10", ] [tool.setuptools.packages.find] From 3671c831a621e741eaf0489de30343081c32a454 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 28 May 2025 10:15:12 -0700 Subject: [PATCH 09/16] Add ONNX_IR in test --- noxfile.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/noxfile.py b/noxfile.py index 7646c6e4e0..bf4a2d1e61 100644 --- a/noxfile.py +++ b/noxfile.py @@ -42,6 +42,8 @@ "packaging", "protobuf", ) +ONNX_IR = "onnx_ir==0.1.0" +ONNX_IR_MAIN = "git+https://github.com/onnx/ir-py.git@main#egg=onnx_ir" @nox.session(tags=["build"]) @@ -59,6 +61,7 @@ def test(session): PYTORCH, TORCHVISON, ONNX, + ONNX_IR, ONNX_RUNTIME, TRANSFORMERS, ) @@ -78,6 +81,7 @@ def test_torch_nightly(session): ) session.install("-r", "requirements/ci/requirements-onnx-weekly.txt") session.install("-r", "requirements/ci/requirements-pytorch-nightly.txt") + session.install(ONNX_IR, "--no-deps") session.install(".", "--no-deps") session.run("pip", "list") session.run("pytest", "onnxscript", "--doctest-modules", *session.posargs) @@ -88,6 +92,7 @@ def test_torch_nightly(session): def test_onnx_weekly(session): """Test with ONNX weekly (preview) build.""" session.install(*COMMON_TEST_DEPENDENCIES, ONNX_RUNTIME, PYTORCH, TORCHVISON, TRANSFORMERS) + session.install(ONNX_IR, "--no-deps") session.install("-r", "requirements/ci/requirements-onnx-weekly.txt") session.install(".", "--no-deps") session.run("pip", "list") @@ -103,6 +108,7 @@ def test_ort_nightly(session): PYTORCH, TORCHVISON, ONNX, + ONNX_IR, TRANSFORMERS, *ONNX_RUNTIME_NIGHTLY_DEPENDENCIES, ) @@ -113,22 +119,19 @@ def test_ort_nightly(session): session.run("pytest", "tests", *session.posargs) -@nox.session(tags=["test-experimental-torchlib-tracing"]) -def test_experimental_torchlib_tracing(session): - """Test TorchLib with the experimental TORCHLIB_EXPERIMENTAL_PREFER_TRACING flag on.""" +@nox.session(tags=["test-onnx-ir-git"]) +def test_onnx_ir_git(session): + """Test with ONNX IR Git builds.""" session.install( *COMMON_TEST_DEPENDENCIES, PYTORCH, TORCHVISON, ONNX, + TRANSFORMERS, *ONNX_RUNTIME_NIGHTLY_DEPENDENCIES, ) - session.install("-r", "requirements/ci/requirements-ort-nightly.txt") + session.install(ONNX_IR_MAIN) session.install(".", "--no-deps") session.run("pip", "list") - session.run( - "pytest", - "tests/function_libs/torch_lib/ops_test.py", - *session.posargs, - env={"TORCHLIB_EXPERIMENTAL_PREFER_TRACING": "1"}, - ) + session.run("pytest", "onnxscript", "--doctest-modules", *session.posargs) + session.run("pytest", "tests", *session.posargs) From bf5e10628d3c171130239a6f7de08687897d123d Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 28 May 2025 10:17:07 -0700 Subject: [PATCH 10/16] test-onnx-ir-git --- .github/workflows/main.yaml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index fb71e3f944..9968cd3365 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -31,6 +31,7 @@ jobs: - py311-torch-nightly - py311-onnx-weekly - py311-ort-nightly + - py311-onnx-ir-git - py310 include: - name: py312 @@ -51,6 +52,9 @@ jobs: - name: py311-ort-nightly python-version: "3.11" nox-tag: test-ort-nightly + - name: py311-onnx-ir-git + python-version: "3.11" + nox-tag: test-onnx-ir-git runs-on: ${{ matrix.os }} steps: - uses: actions/checkout@v4 From 61b871d19b280ab5dc2b4159bb9d4495449e4689 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 28 May 2025 15:55:05 -0700 Subject: [PATCH 11/16] test --- noxfile.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/noxfile.py b/noxfile.py index bf4a2d1e61..ec786954c2 100644 --- a/noxfile.py +++ b/noxfile.py @@ -127,8 +127,8 @@ def test_onnx_ir_git(session): PYTORCH, TORCHVISON, ONNX, + ONNX_RUNTIME, TRANSFORMERS, - *ONNX_RUNTIME_NIGHTLY_DEPENDENCIES, ) session.install(ONNX_IR_MAIN) session.install(".", "--no-deps") From 8571e69f035740febb5b17f65ad5fda6b904b926 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 28 May 2025 15:57:36 -0700 Subject: [PATCH 12/16] Docs --- README.md | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/README.md b/README.md index bcf6862d7a..adfc3238d0 100644 --- a/README.md +++ b/README.md @@ -17,8 +17,6 @@ models using a subset of Python. ONNX Script is: This repo also covers: -* **ONNX IR:** an in-memory IR that supports the full ONNX spec, designed - for graph construction, analysis and transformation. * **ONNX Script Optimizer:** provides functionality to optimize an ONNX model by performing optimizations and clean-ups such as constant folding, dead code elimination, etc. @@ -152,24 +150,6 @@ result = Hardmax(v) More examples can be found in the [docs/examples](docs/examples) directory. -## ONNX IR - -An in-memory IR that supports the full ONNX spec, designed for graph construction, analysis and transformation. - -### Features - -* **Full ONNX spec support:** all valid models representable by ONNX protobuf, - and a subset of invalid models (so you can load and fix them). -* **Low memory footprint:** mmap'ed external tensors; unified interface for - ONNX TensorProto, Numpy arrays and PyTorch Tensors etc. No tensor size - limitation. Zero copies. -* **Straightforward access patterns:** Access value information and traverse the - graph topology at ease. -* **Robust mutation:** Create as many iterators as you like on the graph while mutating it. -* **Speed:** Performant graph manipulation, serialization/deserialization to Protobuf. -* **Pythonic and familiar APIs:** Classes define Pythonic apis and still map to - ONNX protobuf concepts in an intuitive way. - ## ONNX Script Tools ### ONNX Optimizer From c9eb6072d330538e50b63c16ade5636ddb31a7bf Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 28 May 2025 16:04:01 -0700 Subject: [PATCH 13/16] test --- onnxscript/rewriter/ort_fusions/mha_test.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/onnxscript/rewriter/ort_fusions/mha_test.py b/onnxscript/rewriter/ort_fusions/mha_test.py index c8c22a1731..8d1c04f970 100644 --- a/onnxscript/rewriter/ort_fusions/mha_test.py +++ b/onnxscript/rewriter/ort_fusions/mha_test.py @@ -85,8 +85,9 @@ def test_whisper_decoder(self): # Fuse SDPA and MHA sdpa_count = xformers.fuse_sdpa(model) self.assertGreater(sdpa_count, 0) - model = shape_inference.infer_shapes(model) - mha_count = xformers.fuse_mha(model) + model = common_passes.ShapeInferencePass()(model).model + mha_count = xformers.fuse_mha1(model) + mha_count += xformers.fuse_mha2(model) self.assertGreater(mha_count, 0) onnxscript.optimizer.optimize(model) From f7ffd7e8c46952040c6e3959b502792c4b385c18 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 28 May 2025 16:09:49 -0700 Subject: [PATCH 14/16] test --- .../ort_fusions/models/_test_models.py | 30 ------------------- 1 file changed, 30 deletions(-) diff --git a/onnxscript/rewriter/ort_fusions/models/_test_models.py b/onnxscript/rewriter/ort_fusions/models/_test_models.py index 64f0c396d2..51613123e1 100644 --- a/onnxscript/rewriter/ort_fusions/models/_test_models.py +++ b/onnxscript/rewriter/ort_fusions/models/_test_models.py @@ -2,17 +2,11 @@ # Licensed under the MIT License. from __future__ import annotations -import os -import tempfile - -import numpy as np -import onnxruntime import torch import transformers from transformers import LlamaConfig import onnxscript.ir as ir -import onnxscript.ir._io as io import onnxscript.optimizer # Create a LlamaConfig object with the desired parameters @@ -96,27 +90,3 @@ def get_ort_inputs(self): return { f"input{i}": input.numpy() for i, input in enumerate(inputs) if input is not None } - - -def _ort_check(model_name: str, model, inputs, expected_outputs, rtol=1e-2, atol=1e-2): - providers = ["CPUExecutionProvider"] - with tempfile.TemporaryDirectory() as temp_dir: - model_path = os.path.join(temp_dir, f"{model_name}.onnx") - io.save(model, model_path) - # Run model - session = onnxruntime.InferenceSession(model_path, providers=providers) - ort_outputs = session.run(None, inputs) - - for i, (baseline_output, optimized_output) in enumerate( - zip(expected_outputs, ort_outputs) - ): - try: - np.testing.assert_equal(baseline_output.shape, optimized_output.shape) - np.testing.assert_allclose( - baseline_output, optimized_output, rtol=rtol, atol=atol - ) - except AssertionError as e: - print( - f"Failed for model {model_name} and output {i} with rtol={rtol} and atol={atol}\n{e}" - ) - raise From 5e96cc188110be24d15ed57f4388f4efaffa6567 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 28 May 2025 18:56:46 -0700 Subject: [PATCH 15/16] pylint --- onnxscript/ir/convenience.py | 1 + 1 file changed, 1 insertion(+) diff --git a/onnxscript/ir/convenience.py b/onnxscript/ir/convenience.py index 06cebc676d..d3b157890d 100644 --- a/onnxscript/ir/convenience.py +++ b/onnxscript/ir/convenience.py @@ -1,3 +1,4 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +# pylint: disable=wildcard-import from onnx_ir.convenience import * # type: ignore # noqa: F403 From 318710fe1156ba4ec1805290024d77a63d211808 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 29 May 2025 08:27:40 -0700 Subject: [PATCH 16/16] Format --- onnxscript/ir/_tape.py | 6 +++++- onnxscript/ir/convenience.py | 2 +- onnxscript/rewriter/ort_fusions/_core.py | 2 +- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/onnxscript/ir/_tape.py b/onnxscript/ir/_tape.py index ac9507436a..79312eaefa 100644 --- a/onnxscript/ir/_tape.py +++ b/onnxscript/ir/_tape.py @@ -4,7 +4,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Sequence +from typing import TYPE_CHECKING, Any, Optional, Sequence from onnx_ir import tape @@ -12,6 +12,10 @@ import onnx_ir as ir +# A type representing the domains/versions used in creating nodes in IR. +UsedOpsets = set[tuple[str, Optional[int]]] + + class Builder(tape.Tape): """An extension of the tape that provides a more convenient API for constructing the IR.""" diff --git a/onnxscript/ir/convenience.py b/onnxscript/ir/convenience.py index d3b157890d..e248a5fa84 100644 --- a/onnxscript/ir/convenience.py +++ b/onnxscript/ir/convenience.py @@ -1,4 +1,4 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# pylint: disable=wildcard-import +# pylint: disable=wildcard-import,unused-wildcard-import from onnx_ir.convenience import * # type: ignore # noqa: F403 diff --git a/onnxscript/rewriter/ort_fusions/_core.py b/onnxscript/rewriter/ort_fusions/_core.py index 424efb008b..78a74f0e03 100644 --- a/onnxscript/rewriter/ort_fusions/_core.py +++ b/onnxscript/rewriter/ort_fusions/_core.py @@ -3,8 +3,8 @@ from __future__ import annotations import onnxscript.ir as ir -import onnxscript.rewriter.ort_fusions.shape_optimization as shape_optimization import onnxscript.ir.passes.common as common_passes +import onnxscript.rewriter.ort_fusions.shape_optimization as shape_optimization from onnxscript.optimizer import optimize from onnxscript.rewriter import rewrite from onnxscript.rewriter.ort_fusions import (