Skip to content

Commit 9f86b53

Browse files
justinchubyCopilot
authored andcommitted
[IR] Convenience constructor for Node (microsoft#2126)
Create a convenience constructor for `Node`. Refactor the constructors to a separate module. ## Motivation Currently users when interacting with the IR needs to use the raw `ir.Node` constructor for creating nodes. This constructor is designed for performance and not ease-of-use. For users I created a new `ir.node` that exposes a more natural calling style that supports plain python values as attributes and an optional `domain` argument. --------- Co-authored-by: Copilot <[email protected]>
1 parent a879613 commit 9f86b53

File tree

5 files changed

+192
-100
lines changed

5 files changed

+192
-100
lines changed

onnxscript/ir/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,9 @@
7070
# Conversion functions
7171
"from_proto",
7272
"to_proto",
73-
# IR Tensor initializer
73+
# Convenience constructors
7474
"tensor",
75+
"node",
7576
# Pass infrastructure
7677
"passes",
7778
# IO
@@ -80,7 +81,7 @@
8081
]
8182

8283
from onnxscript.ir import convenience, external_data, passes, serde, traversal
83-
from onnxscript.ir._convenience import tensor
84+
from onnxscript.ir._convenience._constructors import node, tensor
8485
from onnxscript.ir._core import (
8586
Attr,
8687
AttrFloat32,

onnxscript/ir/_convenience.py renamed to onnxscript/ir/_convenience/__init__.py

Lines changed: 3 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,15 @@
1212
"convert_attribute",
1313
"convert_attributes",
1414
"replace_all_uses_with",
15+
"create_value_mapping",
16+
"replace_nodes_and_values",
1517
]
1618

17-
import typing
1819
from typing import Mapping, Sequence, Union
1920

20-
import numpy as np
2121
import onnx
2222

23-
from onnxscript.ir import _core, _enums, _protocols, serde, tensor_adapters
24-
25-
if typing.TYPE_CHECKING:
26-
import numpy.typing as npt
23+
from onnxscript.ir import _core, _enums, _protocols, serde
2724

2825
SupportedAttrTypes = Union[
2926
str,
@@ -291,94 +288,6 @@ def replace_all_uses_with(
291288
user_node.replace_input_with(index, replacement)
292289

293290

294-
def tensor(
295-
value: npt.ArrayLike
296-
| onnx.TensorProto
297-
| _protocols.DLPackCompatible
298-
| _protocols.ArrayCompatible,
299-
dtype: _enums.DataType | None = None,
300-
name: str | None = None,
301-
doc_string: str | None = None,
302-
) -> _protocols.TensorProtocol:
303-
"""Create a tensor value from an ArrayLike object or a TensorProto.
304-
305-
The dtype must match the value. Reinterpretation of the value is
306-
not supported, unless if the value is a plain Python object, in which case
307-
it is converted to a numpy array with the given dtype.
308-
309-
:param:`value` can be a numpy array, a plain Python object, or a TensorProto.
310-
311-
Example::
312-
313-
>>> from onnxscript import ir
314-
>>> import numpy as np
315-
>>> import ml_dtypes
316-
>>> import onnx
317-
>>> ir.tensor(np.array([1, 2, 3], dtype=np.int16))
318-
Tensor<INT16,[3]>(array([1, 2, 3], dtype=int16), name=None)
319-
>>> ir.tensor([1, 2, 3], dtype=ir.DataType.BFLOAT16)
320-
Tensor<BFLOAT16,[3]>(array([1, 2, 3], dtype=bfloat16), name=None)
321-
>>> tp_tensor = ir.tensor(onnx.helper.make_tensor("tensor", onnx.TensorProto.FLOAT, dims=[], vals=[0.5]))
322-
>>> tp_tensor.numpy()
323-
array(0.5, dtype=float32)
324-
>>> import torch
325-
>>> ir.tensor(torch.tensor([1.0, 2.0]), name="torch_tensor")
326-
TorchTensor<FLOAT,[2]>(tensor([1., 2.]), name='torch_tensor')
327-
328-
Args:
329-
value: The numpy array to create the tensor from.
330-
dtype: The data type of the tensor.
331-
name: The name of the tensor.
332-
doc_string: The documentation string of the tensor.
333-
334-
Returns:
335-
A tensor value.
336-
337-
Raises:
338-
ValueError: If the dtype does not match the value when value is not a plain Python
339-
object like ``list[int]``.
340-
"""
341-
if isinstance(value, _protocols.TensorProtocol):
342-
if dtype is not None and dtype != value.dtype:
343-
raise ValueError(
344-
f"The dtype must match the value when value is a Tensor. dtype={dtype}, value.dtype={value.dtype}. "
345-
"You do not have to specify the dtype when value is a Tensor."
346-
)
347-
return value
348-
if isinstance(value, onnx.TensorProto):
349-
tensor_ = serde.deserialize_tensor(value)
350-
if name is not None:
351-
tensor_.name = name
352-
if doc_string is not None:
353-
tensor_.doc_string = doc_string
354-
if dtype is not None and dtype != tensor_.dtype:
355-
raise ValueError(
356-
f"The dtype must match the value when value is a TensorProto. dtype={dtype}, value.data_type={tensor_.dtype}"
357-
"You do not have to specify the dtype when value is a TensorProto."
358-
)
359-
return tensor_
360-
elif str(type(value)) == "<class 'torch.Tensor'>":
361-
# NOTE: We use str(type(...)) and do not import torch for type checking
362-
# as it creates overhead during import
363-
return tensor_adapters.TorchTensor(value, name=name, doc_string=doc_string) # type: ignore[arg-type]
364-
elif isinstance(value, (_protocols.DLPackCompatible, _protocols.ArrayCompatible)):
365-
return _core.Tensor(value, dtype=dtype, name=name, doc_string=name)
366-
367-
# Plain Python object
368-
if dtype is not None:
369-
numpy_dtype = dtype.numpy()
370-
else:
371-
numpy_dtype = None
372-
array = np.array(value, dtype=numpy_dtype)
373-
return _core.Tensor(
374-
array,
375-
dtype=dtype,
376-
shape=_core.Shape(array.shape),
377-
name=name,
378-
doc_string=name,
379-
)
380-
381-
382291
def create_value_mapping(graph: _core.Graph) -> dict[str, _core.Value]:
383292
"""Return a dictionary mapping names to values in the graph.
384293
Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
"""Convenience constructors for IR objects."""
4+
5+
from __future__ import annotations
6+
7+
__all__ = [
8+
"tensor",
9+
"node",
10+
]
11+
12+
import typing
13+
from typing import Mapping, Sequence
14+
15+
import numpy as np
16+
import onnx
17+
18+
from onnxscript.ir import _convenience, _core, _enums, _protocols, serde, tensor_adapters
19+
20+
if typing.TYPE_CHECKING:
21+
import numpy.typing as npt
22+
23+
from onnxscript import ir
24+
25+
26+
def tensor(
27+
value: npt.ArrayLike | onnx.TensorProto | ir.DLPackCompatible | ir.ArrayCompatible,
28+
dtype: _enums.DataType | None = None,
29+
name: str | None = None,
30+
doc_string: str | None = None,
31+
) -> _protocols.TensorProtocol:
32+
"""Create a tensor value from an ArrayLike object or a TensorProto.
33+
34+
The dtype must match the value. Reinterpretation of the value is
35+
not supported, unless if the value is a plain Python object, in which case
36+
it is converted to a numpy array with the given dtype.
37+
38+
``value`` can be a numpy array, a plain Python object, or a TensorProto.
39+
40+
Example::
41+
42+
>>> from onnxscript import ir
43+
>>> import numpy as np
44+
>>> import ml_dtypes
45+
>>> import onnx
46+
>>> ir.tensor(np.array([1, 2, 3], dtype=np.int16))
47+
Tensor<INT16,[3]>(array([1, 2, 3], dtype=int16), name=None)
48+
>>> ir.tensor([1, 2, 3], dtype=ir.DataType.BFLOAT16)
49+
Tensor<BFLOAT16,[3]>(array([1, 2, 3], dtype=bfloat16), name=None)
50+
>>> tp_tensor = ir.tensor(onnx.helper.make_tensor("tensor", onnx.TensorProto.FLOAT, dims=[], vals=[0.5]))
51+
>>> tp_tensor.numpy()
52+
array(0.5, dtype=float32)
53+
>>> import torch
54+
>>> ir.tensor(torch.tensor([1.0, 2.0]), name="torch_tensor")
55+
TorchTensor<FLOAT,[2]>(tensor([1., 2.]), name='torch_tensor')
56+
57+
Args:
58+
value: The numpy array to create the tensor from.
59+
dtype: The data type of the tensor.
60+
name: The name of the tensor.
61+
doc_string: The documentation string of the tensor.
62+
63+
Returns:
64+
A tensor value.
65+
66+
Raises:
67+
ValueError: If the dtype does not match the value when value is not a plain Python
68+
object like ``list[int]``.
69+
"""
70+
if isinstance(value, _protocols.TensorProtocol):
71+
if dtype is not None and dtype != value.dtype:
72+
raise ValueError(
73+
f"The dtype must match the value when value is a Tensor. dtype={dtype}, value.dtype={value.dtype}. "
74+
"You do not have to specify the dtype when value is a Tensor."
75+
)
76+
return value
77+
if isinstance(value, onnx.TensorProto):
78+
tensor_ = serde.deserialize_tensor(value)
79+
if name is not None:
80+
tensor_.name = name
81+
if doc_string is not None:
82+
tensor_.doc_string = doc_string
83+
if dtype is not None and dtype != tensor_.dtype:
84+
raise ValueError(
85+
f"The dtype must match the value when value is a TensorProto. dtype={dtype}, value.data_type={tensor_.dtype}"
86+
"You do not have to specify the dtype when value is a TensorProto."
87+
)
88+
return tensor_
89+
elif str(type(value)) == "<class 'torch.Tensor'>":
90+
# NOTE: We use str(type(...)) and do not import torch for type checking
91+
# as it creates overhead during import
92+
return tensor_adapters.TorchTensor(value, name=name, doc_string=doc_string) # type: ignore[arg-type]
93+
elif isinstance(value, (_protocols.DLPackCompatible, _protocols.ArrayCompatible)):
94+
return _core.Tensor(value, dtype=dtype, name=name, doc_string=doc_string)
95+
# Plain Python object
96+
if dtype is not None:
97+
numpy_dtype = dtype.numpy()
98+
else:
99+
numpy_dtype = None
100+
array = np.array(value, dtype=numpy_dtype)
101+
return _core.Tensor(
102+
array,
103+
dtype=dtype,
104+
shape=_core.Shape(array.shape),
105+
name=name,
106+
doc_string=doc_string,
107+
)
108+
109+
110+
def node(
111+
op_type: str,
112+
inputs: Sequence[ir.Value],
113+
attributes: Mapping[str, _convenience.SupportedAttrTypes] | None = None,
114+
*,
115+
domain: str = "",
116+
overload: str = "",
117+
num_outputs: int | None = None,
118+
outputs: Sequence[ir.Value] | None = None,
119+
version: int | None = None,
120+
graph: ir.Graph | None = None,
121+
name: str | None = None,
122+
doc_string: str | None = None,
123+
metadata_props: dict[str, str] | None = None,
124+
) -> ir.Node:
125+
"""Create an :class:`ir.Node`.
126+
127+
This is a convenience constructor for creating a Node that supports Python
128+
objects as attributes.
129+
130+
Example::
131+
132+
>>> from onnxscript import ir
133+
>>> input_a = ir.Input("A", shape=ir.Shape([1, 2]), type=ir.TensorType(ir.DataType.INT32))
134+
>>> input_b = ir.Input("B", shape=ir.Shape([1, 2]), type=ir.TensorType(ir.DataType.INT32))
135+
>>> node = ir.node(
136+
... "SomeOp",
137+
... inputs=[input_a, input_b],
138+
... attributes={"alpha": 1.0, "some_list": [1, 2, 3]},
139+
... domain="some.domain",
140+
... name="node_name"
141+
... )
142+
>>> node.op_type
143+
'SomeOp'
144+
145+
Args:
146+
op_type: The name of the operator.
147+
inputs: The input values. When an input is None, it is an empty input.
148+
attributes: The attributes. RefAttr can be used only when the node is defined in a Function.
149+
overload: The overload name when the node is invoking a function.
150+
domain: The domain of the operator. For onnx operators, this is an empty string.
151+
num_outputs: The number of outputs of the node. If not specified, the number is 1.
152+
outputs: The output values. If None, the outputs are created during initialization.
153+
version: The version of the operator. If None, the version is unspecified and will follow that of the graph.
154+
graph: The graph that the node belongs to. If None, the node is not added to any graph.
155+
A `Node` must belong to zero or one graph.
156+
name: The name of the node. If None, the node is anonymous.
157+
doc_string: The documentation string.
158+
metadata_props: The metadata properties.
159+
160+
Returns:
161+
A node with the given op_type and inputs.
162+
"""
163+
if attributes is None:
164+
attrs: Sequence[ir.Attr | ir.RefAttr] = ()
165+
else:
166+
attrs = _convenience.convert_attributes(attributes)
167+
return _core.Node(
168+
domain=domain,
169+
op_type=op_type,
170+
inputs=inputs,
171+
attributes=attrs,
172+
overload=overload,
173+
num_outputs=num_outputs,
174+
outputs=outputs,
175+
version=version,
176+
graph=graph,
177+
name=name,
178+
doc_string=doc_string,
179+
metadata_props=metadata_props,
180+
)

onnxscript/ir/_convenience_test.py renamed to onnxscript/ir/_convenience/_constructors_test.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,20 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the MIT License.
3-
"""Unit tests for the _convenience module."""
3+
"""Unit tests for the _constructors module."""
44

55
import unittest
66

77
import numpy as np
88

9-
from onnxscript.ir import _convenience
9+
from onnxscript.ir._convenience import _constructors
1010

1111

12-
class ConvenienceTest(unittest.TestCase):
12+
class ConstructorsTest(unittest.TestCase):
1313
def test_tensor_accepts_torch_tensor(self):
1414
import torch as some_random_name # pylint: disable=import-outside-toplevel
1515

1616
torch_tensor = some_random_name.tensor([1, 2, 3])
17-
tensor = _convenience.tensor(torch_tensor)
17+
tensor = _constructors.tensor(torch_tensor)
1818
np.testing.assert_array_equal(tensor, torch_tensor.numpy())
1919

2020

onnxscript/ir/convenience.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,13 @@
99
"convert_attributes",
1010
"replace_all_uses_with",
1111
"replace_nodes_and_values",
12+
"create_value_mapping",
1213
]
1314

1415
from onnxscript.ir._convenience import (
1516
convert_attribute,
1617
convert_attributes,
18+
create_value_mapping,
1719
replace_all_uses_with,
1820
replace_nodes_and_values,
1921
)

0 commit comments

Comments
 (0)