1
- """Protocols derived from onnx/onnx.proto3.
2
-
3
- The protocols define read-only interfaces only. Mutating methods are not exposed
4
- to users.
1
+ """Protocols for the ONNX IR.
2
+
3
+ This file defines the interfaces for tools to interact with the IR. The interfaces
4
+ are designed such that tools leveraging the IR can be decoupled from the IR
5
+ implementation. This allows for the implementation to evolve independently of the
6
+ tools.
7
+
8
+ The file contains two sets of interfaces:
9
+ 1. Topologically immutable interfaces:
10
+ These interfaces provide a complete view of the ONNX model and allow mutation
11
+ of any metadata fields like shape, type, and node attributes. However, the
12
+ interfaces are topologically immutable, meaning that the structure of the graph
13
+ cannot be changed. This is useful for tools that need to analyze the model
14
+ without modifying how nodes are connected.
15
+ 2. Mutable interfaces:
16
+ These interfaces provide a mutable view of the ONNX model. They allow for
17
+ modification of the graph structure. This is useful for tools that need to
18
+ transform the model.
5
19
"""
6
20
7
21
from __future__ import annotations
25
39
import numpy as np
26
40
from typing_extensions import TypeAlias
27
41
42
+ # Representation of a dimension. int is a known axis, str is a named dynamic axis, None is an unnamed dynamic axis.
28
43
SimpleDim : TypeAlias = Union [int , str , None ]
44
+ # Representation of a shape. Each element is a simple dimension.
29
45
SimpleShape : TypeAlias = Sequence [SimpleDim ]
30
46
31
47
# An identifier that will uniquely identify an operator. E.g. (domain, op_type, overload)
34
50
35
51
@typing .runtime_checkable
36
52
class ArrayCompatible (Protocol ):
37
- """Protocol for array-like objects."""
53
+ """Protocol for array-like objects.
54
+
55
+ An example of an array-like object is a numpy array or a PyTorch tensor.
56
+ Read more at https://numpy.org/devdocs/user/basics.interoperability.html
57
+ """
38
58
39
59
def __array__ (self , dtype : Any ) -> np .ndarray : ...
40
60
41
61
42
62
@typing .runtime_checkable
43
63
class DLPackCompatible (Protocol ):
44
- """Protocol objects that can support dlpack."""
64
+ """Protocol for objects that can support dlpack.
65
+
66
+ Computation backends can call __dlpack__ to obtain the underlying data in a
67
+ tensor without copying the data. This allows us to use tensorflow tensors etc.
68
+ without copying the data.
69
+ """
45
70
46
71
def __dlpack__ (self , * , stream : Any = ...) -> Any :
47
72
"""Return PyCapsule."""
@@ -52,13 +77,25 @@ def __dlpack__(self, *, stream: Any = ...) -> Any:
52
77
class TensorProtocol (ArrayCompatible , Protocol ):
53
78
"""Concrete tensor backed by data.
54
79
80
+ The protocol does not specify how the data is stored. That data is exposed
81
+ through the :attr:`raw` attribute for examination, but accessing :attr:`raw`
82
+ is typically not needed.
83
+
84
+ To use the tensor as a numpy array, call :meth:`numpy`. To convert the tensor
85
+ to a byte string for serialization, call :meth:`tobytes`.
86
+
87
+ It is recommended to check the size of the tensor first before accessing the
88
+ underlying data, because accessing the data may be expensive and incur IO
89
+ overhead.
90
+
55
91
Attributes:
56
92
name: The name of the tensor.
57
93
shape: The shape of the tensor.
58
- dtype: The data type of the elements of the tensor.
94
+ dtype: The data type of the elements of the tensor. It is an :class:`ir.DataType` enum.
59
95
doc_string: Documentation string.
60
96
raw: The raw data behind this tensor. It can be anything.
61
- value: The tensor as a numpy array.
97
+ size: The number of elements in the tensor.
98
+ nbytes: The number of bytes in the tensor.
62
99
"""
63
100
64
101
name : str
@@ -67,6 +104,12 @@ class TensorProtocol(ArrayCompatible, Protocol):
67
104
doc_string : str | None
68
105
raw : Any
69
106
107
+ @property
108
+ def size (self ) -> int : ...
109
+
110
+ @property
111
+ def nbytes (self ) -> int : ...
112
+
70
113
def numpy (self ) -> np .ndarray :
71
114
"""Return the tensor as a numpy array."""
72
115
...
@@ -82,12 +125,27 @@ def tobytes(self) -> bytes:
82
125
83
126
@typing .runtime_checkable
84
127
class ValueProtocol (Protocol ):
85
- """Protocol for ONNX values.
128
+ """Values.
129
+
130
+ A value is a named entity that can be used to represent an input or output of a graph,
131
+ a function, or a node. The information it stores corresponds to ValueInfoProto
132
+ in the ONNX specification.
133
+
134
+ A :class:`Value` is either not owned or owned by exactly one node. When the value is not
135
+ owned, it must be an input of a graph or a function. The def_node and def_index
136
+ attributes are None.
86
137
87
- A value is a named entity that can be used as an input or output of an operator.
138
+ When the value is owned by a node, it is an output of the node.
139
+ The node that produces the value is stored in the :attr:`def_node` attribute.
140
+ The index of the output of the node that produces the value is stored in the
141
+ :attr:`def_index` attribute.
142
+
143
+ To find all the nodes that use this value as an input, call :meth:`users`.
144
+
145
+ To check if the value is an output of a graph, call :meth:`is_graph_output`.
88
146
89
147
Attributes:
90
- name: The name of the value.
148
+ name: The name of the value. A value is always named when it is part of a graph.
91
149
def_node: The node that produces this value.
92
150
def_index: The index of the output of the node that produces this value.
93
151
shape: The shape of the value.
@@ -113,39 +171,68 @@ def is_graph_output(self) -> bool:
113
171
114
172
@typing .runtime_checkable
115
173
class NodeProtocol (Protocol ):
116
- """Protocol for ONNX nodes.
174
+ """Nodes.
175
+
176
+ A node represents an invocation of an operation on the :class:`Value` s in
177
+ the computational graph.
117
178
118
- A node represents an operation in the computation graph.
179
+ A node can be optionally named. A name should typically be assigned when the
180
+ node is added to a graph.
181
+
182
+ :attr:`domain`, :attr:`op_type`, and :attr:`overload` together uniquely identify
183
+ the operator, and are always strings. For ONNX operators, :attr:`domain` and :attr:`overload`
184
+ are both empty strings.
185
+
186
+ :attr:`inputs` and :attr:`outputs` are the input and output values of the node.
187
+
188
+ :attr:`attributes` are the attributes of the node. The attributes are stored in an
189
+ ordered dictionary to preserve the order of the attributes. This is a deviation from
190
+ the current ONNX spec where attributes are unordered, but it is helpful for tools
191
+ that rely on the order of the attributes, e.g. those converting to and from Python
192
+ function keyword arguments.
193
+
194
+ :attr:`version` is unique to the IR and is not specified in the ONNX spec. This
195
+ allows the IR to represent a graph with mixed opset versions. Deserializers
196
+ should decide how to reconcile the different versions within the graph. A typical
197
+ graph will have a single version, declared in the :class:`Graph` object and
198
+ the nodes will have ``None`` as the version.
119
199
120
200
Attributes:
121
- domain: The domain of the operator. E.g. "" for ONNX operators.
122
- version: The version of the operator.
201
+ domain: The domain of the operator. E.g. ``""`` for ONNX operators.
123
202
op_type: The operator name.
124
203
overload: The overload name when the node is invoking a function.
125
204
inputs: Input values.
126
205
outputs: Output values.
127
206
attributes: The attributes of the operator.
207
+ version: The version of the operator.
128
208
doc_string: Documentation string.
129
209
metadata_props: Metadata.
130
210
"""
131
211
132
212
name : str | None
133
213
domain : str
134
- version : int | None
135
214
op_type : str
136
215
overload : str
137
216
inputs : Sequence [ValueProtocol ]
138
217
outputs : Sequence [ValueProtocol ]
139
218
attributes : OrderedDict [str , AttributeProtocol | ReferenceAttributeProtocol ]
219
+ version : int | None
140
220
doc_string : str | None
141
221
metadata_props : Mapping [str , str ]
142
222
143
223
144
224
@typing .runtime_checkable
145
225
class GraphProtocol (Protocol ):
146
- """Protocol for ONNX graphs.
226
+ """Graphs.
227
+
228
+ Graph represents a computation graph. In addition to the fields specified in
229
+ the ONNX specification, it also contains a mapping of :attr:`opset_imports`. This
230
+ allows different subgraphs to import different opsets. It is the responsibility
231
+ of the deserializer to reconcile the different opsets.
147
232
148
- Graph represents a computation graph.
233
+ The :attr:`nodes` are not guaranteed to be topologically sorted. But the
234
+ iteration order should be deterministic across different runs. It is the
235
+ responsibility of the user to maintain a topological order of the nodes.
149
236
150
237
Attributes:
151
238
name: The name of the graph.
@@ -158,7 +245,7 @@ class GraphProtocol(Protocol):
158
245
metadata_props: Metadata.
159
246
"""
160
247
161
- # TODO(justinchuby): Support quantization_annotation and metadata_props
248
+ # TODO(justinchuby): Support quantization_annotation
162
249
name : str | None
163
250
inputs : Sequence [ValueProtocol ]
164
251
outputs : Sequence [ValueProtocol ]
@@ -175,9 +262,10 @@ def topologically_sorted_nodes(self) -> Sequence[NodeProtocol]:
175
262
176
263
@typing .runtime_checkable
177
264
class ModelProtocol (Protocol ):
178
- """Protocol for ONNX models .
265
+ """Models.
179
266
180
- A model is a container for a graph and metadata.
267
+ A model is a container for a graph and metadata. It is the top-level object
268
+ that represents an ONNX model.
181
269
182
270
Attributes:
183
271
graph: The graph of the model.
@@ -225,6 +313,8 @@ class AttributeProtocol(Protocol):
225
313
class ReferenceAttributeProtocol (Protocol ):
226
314
"""Protocol for a reference attribute.
227
315
316
+ A reference attribute can only appear inside the definition body of a function.
317
+
228
318
Attributes:
229
319
name: The name of the attribute.
230
320
ref_attr_name: The name of the attribute definition this attribute refers to.
@@ -312,6 +402,11 @@ def __eq__(self, __value: object) -> bool: ...
312
402
313
403
@typing .runtime_checkable
314
404
class MapTypeProtocol (Protocol ):
405
+ """Protocol for ONNX map types.
406
+
407
+ TODO: This protocol is not yet implemented in the ONNX IR.
408
+ """
409
+
315
410
key_type : typing .Literal [
316
411
_enums .DataType .STRING ,
317
412
_enums .DataType .INT64 ,
@@ -330,6 +425,9 @@ class MapTypeProtocol(Protocol):
330
425
class FunctionProtocol (Protocol ):
331
426
"""Protocol for ONNX functions.
332
427
428
+ Like a graph, a function can have nodes that are not topologically sorted. It is
429
+ the responsibility of the user to maintain a topological order of the nodes.
430
+
333
431
Attributes:
334
432
name: The function name.
335
433
domain: The domain this function is defined in.
@@ -350,10 +448,6 @@ class FunctionProtocol(Protocol):
350
448
attributes : OrderedDict [str , AttributeProtocol ]
351
449
outputs : Sequence [ValueProtocol ]
352
450
doc_string : str
353
- # opset_import is stored in a model, not a graph. However,
354
- # In ONNX IR we store it in a graph to unify it with
355
- # the function. This way a materialized function can still
356
- # be used as a subgraph even if it imports a different opset.
357
451
opset_imports : Mapping [str , int ]
358
452
nodes : Sequence [NodeProtocol ]
359
453
metadata_props : Mapping [str , str ]
0 commit comments