|
| 1 | +# Copyright (c) Microsoft Corporation. |
| 2 | +# Licensed under the MIT License. |
| 3 | +"""Convenience constructors for IR objects.""" |
| 4 | + |
| 5 | +from __future__ import annotations |
| 6 | + |
| 7 | +__all__ = [ |
| 8 | + "tensor", |
| 9 | + "node", |
| 10 | +] |
| 11 | + |
| 12 | +import typing |
| 13 | +from typing import Mapping, Sequence |
| 14 | + |
| 15 | +import numpy as np |
| 16 | +import onnx |
| 17 | + |
| 18 | +from onnxscript.ir import _convenience, _core, _enums, _protocols, serde, tensor_adapters |
| 19 | + |
| 20 | +if typing.TYPE_CHECKING: |
| 21 | + import numpy.typing as npt |
| 22 | + |
| 23 | + from onnxscript import ir |
| 24 | + |
| 25 | + |
| 26 | +def tensor( |
| 27 | + value: npt.ArrayLike | onnx.TensorProto | ir.DLPackCompatible | ir.ArrayCompatible, |
| 28 | + dtype: _enums.DataType | None = None, |
| 29 | + name: str | None = None, |
| 30 | + doc_string: str | None = None, |
| 31 | +) -> _protocols.TensorProtocol: |
| 32 | + """Create a tensor value from an ArrayLike object or a TensorProto. |
| 33 | +
|
| 34 | + The dtype must match the value. Reinterpretation of the value is |
| 35 | + not supported, unless if the value is a plain Python object, in which case |
| 36 | + it is converted to a numpy array with the given dtype. |
| 37 | +
|
| 38 | + ``value`` can be a numpy array, a plain Python object, or a TensorProto. |
| 39 | +
|
| 40 | + Example:: |
| 41 | +
|
| 42 | + >>> from onnxscript import ir |
| 43 | + >>> import numpy as np |
| 44 | + >>> import ml_dtypes |
| 45 | + >>> import onnx |
| 46 | + >>> ir.tensor(np.array([1, 2, 3], dtype=np.int16)) |
| 47 | + Tensor<INT16,[3]>(array([1, 2, 3], dtype=int16), name=None) |
| 48 | + >>> ir.tensor([1, 2, 3], dtype=ir.DataType.BFLOAT16) |
| 49 | + Tensor<BFLOAT16,[3]>(array([1, 2, 3], dtype=bfloat16), name=None) |
| 50 | + >>> tp_tensor = ir.tensor(onnx.helper.make_tensor("tensor", onnx.TensorProto.FLOAT, dims=[], vals=[0.5])) |
| 51 | + >>> tp_tensor.numpy() |
| 52 | + array(0.5, dtype=float32) |
| 53 | + >>> import torch |
| 54 | + >>> ir.tensor(torch.tensor([1.0, 2.0]), name="torch_tensor") |
| 55 | + TorchTensor<FLOAT,[2]>(tensor([1., 2.]), name='torch_tensor') |
| 56 | +
|
| 57 | + Args: |
| 58 | + value: The numpy array to create the tensor from. |
| 59 | + dtype: The data type of the tensor. |
| 60 | + name: The name of the tensor. |
| 61 | + doc_string: The documentation string of the tensor. |
| 62 | +
|
| 63 | + Returns: |
| 64 | + A tensor value. |
| 65 | +
|
| 66 | + Raises: |
| 67 | + ValueError: If the dtype does not match the value when value is not a plain Python |
| 68 | + object like ``list[int]``. |
| 69 | + """ |
| 70 | + if isinstance(value, _protocols.TensorProtocol): |
| 71 | + if dtype is not None and dtype != value.dtype: |
| 72 | + raise ValueError( |
| 73 | + f"The dtype must match the value when value is a Tensor. dtype={dtype}, value.dtype={value.dtype}. " |
| 74 | + "You do not have to specify the dtype when value is a Tensor." |
| 75 | + ) |
| 76 | + return value |
| 77 | + if isinstance(value, onnx.TensorProto): |
| 78 | + tensor_ = serde.deserialize_tensor(value) |
| 79 | + if name is not None: |
| 80 | + tensor_.name = name |
| 81 | + if doc_string is not None: |
| 82 | + tensor_.doc_string = doc_string |
| 83 | + if dtype is not None and dtype != tensor_.dtype: |
| 84 | + raise ValueError( |
| 85 | + f"The dtype must match the value when value is a TensorProto. dtype={dtype}, value.data_type={tensor_.dtype}" |
| 86 | + "You do not have to specify the dtype when value is a TensorProto." |
| 87 | + ) |
| 88 | + return tensor_ |
| 89 | + elif str(type(value)) == "<class 'torch.Tensor'>": |
| 90 | + # NOTE: We use str(type(...)) and do not import torch for type checking |
| 91 | + # as it creates overhead during import |
| 92 | + return tensor_adapters.TorchTensor(value, name=name, doc_string=doc_string) # type: ignore[arg-type] |
| 93 | + elif isinstance(value, (_protocols.DLPackCompatible, _protocols.ArrayCompatible)): |
| 94 | + return _core.Tensor(value, dtype=dtype, name=name, doc_string=doc_string) |
| 95 | + # Plain Python object |
| 96 | + if dtype is not None: |
| 97 | + numpy_dtype = dtype.numpy() |
| 98 | + else: |
| 99 | + numpy_dtype = None |
| 100 | + array = np.array(value, dtype=numpy_dtype) |
| 101 | + return _core.Tensor( |
| 102 | + array, |
| 103 | + dtype=dtype, |
| 104 | + shape=_core.Shape(array.shape), |
| 105 | + name=name, |
| 106 | + doc_string=doc_string, |
| 107 | + ) |
| 108 | + |
| 109 | + |
| 110 | +def node( |
| 111 | + op_type: str, |
| 112 | + inputs: Sequence[ir.Value], |
| 113 | + attributes: Mapping[str, _convenience.SupportedAttrTypes] | None = None, |
| 114 | + *, |
| 115 | + domain: str = "", |
| 116 | + overload: str = "", |
| 117 | + num_outputs: int | None = None, |
| 118 | + outputs: Sequence[ir.Value] | None = None, |
| 119 | + version: int | None = None, |
| 120 | + graph: ir.Graph | None = None, |
| 121 | + name: str | None = None, |
| 122 | + doc_string: str | None = None, |
| 123 | + metadata_props: dict[str, str] | None = None, |
| 124 | +) -> ir.Node: |
| 125 | + """Create an :class:`ir.Node`. |
| 126 | +
|
| 127 | + This is a convenience constructor for creating a Node that supports Python |
| 128 | + objects as attributes. |
| 129 | +
|
| 130 | + Example:: |
| 131 | +
|
| 132 | + >>> from onnxscript import ir |
| 133 | + >>> input_a = ir.Input("A", shape=ir.Shape([1, 2]), type=ir.TensorType(ir.DataType.INT32)) |
| 134 | + >>> input_b = ir.Input("B", shape=ir.Shape([1, 2]), type=ir.TensorType(ir.DataType.INT32)) |
| 135 | + >>> node = ir.node( |
| 136 | + ... "SomeOp", |
| 137 | + ... inputs=[input_a, input_b], |
| 138 | + ... attributes={"alpha": 1.0, "some_list": [1, 2, 3]}, |
| 139 | + ... domain="some.domain", |
| 140 | + ... name="node_name" |
| 141 | + ... ) |
| 142 | + >>> node.op_type |
| 143 | + 'SomeOp' |
| 144 | +
|
| 145 | + Args: |
| 146 | + op_type: The name of the operator. |
| 147 | + inputs: The input values. When an input is None, it is an empty input. |
| 148 | + attributes: The attributes. RefAttr can be used only when the node is defined in a Function. |
| 149 | + overload: The overload name when the node is invoking a function. |
| 150 | + domain: The domain of the operator. For onnx operators, this is an empty string. |
| 151 | + num_outputs: The number of outputs of the node. If not specified, the number is 1. |
| 152 | + outputs: The output values. If None, the outputs are created during initialization. |
| 153 | + version: The version of the operator. If None, the version is unspecified and will follow that of the graph. |
| 154 | + graph: The graph that the node belongs to. If None, the node is not added to any graph. |
| 155 | + A `Node` must belong to zero or one graph. |
| 156 | + name: The name of the node. If None, the node is anonymous. |
| 157 | + doc_string: The documentation string. |
| 158 | + metadata_props: The metadata properties. |
| 159 | +
|
| 160 | + Returns: |
| 161 | + A node with the given op_type and inputs. |
| 162 | + """ |
| 163 | + if attributes is None: |
| 164 | + attrs: Sequence[ir.Attr | ir.RefAttr] = () |
| 165 | + else: |
| 166 | + attrs = _convenience.convert_attributes(attributes) |
| 167 | + return _core.Node( |
| 168 | + domain=domain, |
| 169 | + op_type=op_type, |
| 170 | + inputs=inputs, |
| 171 | + attributes=attrs, |
| 172 | + overload=overload, |
| 173 | + num_outputs=num_outputs, |
| 174 | + outputs=outputs, |
| 175 | + version=version, |
| 176 | + graph=graph, |
| 177 | + name=name, |
| 178 | + doc_string=doc_string, |
| 179 | + metadata_props=metadata_props, |
| 180 | + ) |
0 commit comments