1- """Protocols derived from onnx/onnx.proto3.
2-
3- The protocols define read-only interfaces only. Mutating methods are not exposed
4- to users.
1+ """Protocols for the ONNX IR.
2+
3+ This file defines the interfaces for tools to interact with the IR. The interfaces
4+ are designed such that tools leveraging the IR can be decoupled from the IR
5+ implementation. This allows for the implementation to evolve independently of the
6+ tools.
7+
8+ The file contains two sets of interfaces:
9+ 1. Topologically immutable interfaces:
10+ These interfaces provide a complete view of the ONNX model and allow mutation
11+ of any metadata fields like shape, type, and node attributes. However, the
12+ interfaces are topologically immutable, meaning that the structure of the graph
13+ cannot be changed. This is useful for tools that need to analyze the model
14+ without modifying how nodes are connected.
15+ 2. Mutable interfaces:
16+ These interfaces provide a mutable view of the ONNX model. They allow for
17+ modification of the graph structure. This is useful for tools that need to
18+ transform the model.
519"""
620
721from __future__ import annotations
2539 import numpy as np
2640 from typing_extensions import TypeAlias
2741
42+ # Representation of a dimension. int is a known axis, str represents a dynamic axis, None is an unnamed dynamic axis.
2843SimpleDim: TypeAlias = Union[int, str, None]
44+ # Representation of a shape. Each element is a simple dimension.
2945SimpleShape: TypeAlias = Sequence[SimpleDim]
3046
3147# An identifier that will uniquely identify an operator. E.g. (domain, op_type, overload)
3450
3551@typing.runtime_checkable
3652class ArrayCompatible(Protocol):
37- """Protocol for array-like objects."""
53+ """Protocol for array-like objects.
54+
55+ An example of an array-like object is a NumPy array or a PyTorch tensor.
56+ Read more at https://numpy.org/devdocs/user/basics.interoperability.html
57+ """
3858
3959 def __array__(self, dtype: Any) -> np.ndarray: ...
4060
4161
4262@typing.runtime_checkable
4363class DLPackCompatible(Protocol):
44- """Protocol objects that can support dlpack."""
64+ """Protocol for objects that support DLPack.
65+
66+ Computation backends can call __dlpack__ to obtain the underlying data of a
67+ tensor without copying it. This allows us to use tensors from other frameworks,
68+ such as TensorFlow, directly.
69+ """
4570
4671 def __dlpack__(self, *, stream: Any = ...) -> Any:
4772 """Return PyCapsule."""
@@ -52,13 +77,25 @@ def __dlpack__(self, *, stream: Any = ...) -> Any:
5277class TensorProtocol(ArrayCompatible, Protocol):
5378 """Concrete tensor backed by data.
5479
80+ The protocol does not specify how the data is stored. That data is exposed
81+ through the :attr:`raw` attribute for examination, but accessing :attr:`raw`
82+ is typically not needed.
83+
84+ To use the tensor as a numpy array, call :meth:`numpy`. To convert the tensor
85+ to a byte string for serialization, call :meth:`tobytes`.
86+
87+ It is recommended to check the size of the tensor first before accessing the
88+ underlying data, because accessing the data may be expensive and incur IO
89+ overhead.
90+
5591 Attributes:
5692 name: The name of the tensor.
5793 shape: The shape of the tensor.
58- dtype: The data type of the elements of the tensor.
94+ dtype: The data type of the elements of the tensor. It is an :class:`ir.DataType` enum.
5995 doc_string: Documentation string.
6096 raw: The raw data behind this tensor. It can be anything.
61- value: The tensor as a numpy array.
97+ size: The number of elements in the tensor.
98+ nbytes: The number of bytes in the tensor.
6299 """
63100
64101 name: str
@@ -67,6 +104,12 @@ class TensorProtocol(ArrayCompatible, Protocol):
67104 doc_string: str | None
68105 raw: Any
69106
107+ @property
108+ def size(self) -> int: ...
109+
110+ @property
111+ def nbytes(self) -> int: ...
112+
70113 def numpy(self) -> np.ndarray:
71114 """Return the tensor as a numpy array."""
72115 ...
@@ -82,12 +125,27 @@ def tobytes(self) -> bytes:
82125
83126@typing.runtime_checkable
84127class ValueProtocol(Protocol):
85- """Protocol for ONNX values.
128+ """Values.
129+
130+ A value is a named entity that can be used to represent an input or output of a graph,
131+ a function, or a node. The information it stores corresponds to ValueInfoProto
132+ in the ONNX specification.
133+
134+ A :class:`Value` is either not owned or owned by exactly one node. When the value is not
135+ owned, it must be an input of a graph or a function, and its def_node and def_index
136+ attributes are None.
86137
87- A value is a named entity that can be used as an input or output of an operator.
138+ When the value is owned by a node, it is an output of the node.
139+ The node that produces the value is stored in the :attr:`def_node` attribute.
140+ The index of the output of the node that produces the value is stored in the
141+ :attr:`def_index` attribute.
142+
143+ To find all the nodes that use this value as an input, call :meth:`users`.
144+
145+ To check if the value is an output of a graph, call :meth:`is_graph_output`.
88146
89147 Attributes:
90- name: The name of the value.
148+ name: The name of the value. A value is always named when it is part of a graph.
91149 def_node: The node that produces this value.
92150 def_index: The index of the output of the node that produces this value.
93151 shape: The shape of the value.
@@ -113,39 +171,68 @@ def is_graph_output(self) -> bool:
113171
114172@typing.runtime_checkable
115173class NodeProtocol(Protocol):
116- """Protocol for ONNX nodes.
174+ """Nodes.
175+
176+ A node represents an invocation of an operation on the :class:`Value` s in
177+ the computational graph.
117178
118- A node represents an operation in the computation graph.
179+ A node can be optionally named. A name should typically be assigned when the
180+ node is added to a graph.
181+
182+ :attr:`domain`, :attr:`op_type`, and :attr:`overload` together uniquely identify
183+ the operator, and are always strings. For ONNX operators, :attr:`domain` and :attr:`overload`
184+ are both empty strings.
185+
186+ :attr:`inputs` and :attr:`outputs` are the input and output values of the node.
187+
188+ :attr:`attributes` are the attributes of the node. The attributes are stored in an
189+ ordered dictionary to preserve the order of the attributes. This is a deviation from
190+ the current ONNX spec where attributes are unordered, but it is helpful for tools
191+ that rely on the order of the attributes, e.g. those converting to and from Python
192+ function keyword arguments.
193+
194+ :attr:`version` is unique to the IR and is not specified in the ONNX spec. This
195+ allows the IR to represent a graph with mixed opset versions. Deserializers
196+ should decide how to reconcile the different versions within the graph. A typical
197+ graph will have a single version, declared in the :class:`Graph` object and
198+ the nodes will have ``None`` as the version.
119199
120200 Attributes:
121- domain: The domain of the operator. E.g. "" for ONNX operators.
122- version: The version of the operator.
201+ domain: The domain of the operator. E.g. ``""`` for ONNX operators.
123202 op_type: The operator name.
124203 overload: The overload name when the node is invoking a function.
125204 inputs: Input values.
126205 outputs: Output values.
127206 attributes: The attributes of the operator.
207+ version: The version of the operator.
128208 doc_string: Documentation string.
129209 metadata_props: Metadata.
130210 """
131211
132212 name: str | None
133213 domain: str
134- version: int | None
135214 op_type: str
136215 overload: str
137216 inputs: Sequence[ValueProtocol]
138217 outputs: Sequence[ValueProtocol]
139218 attributes: OrderedDict[str, AttributeProtocol | ReferenceAttributeProtocol]
219+ version: int | None
140220 doc_string: str | None
141221 metadata_props: Mapping[str, str]
142222
143223
144224@typing.runtime_checkable
145225class GraphProtocol(Protocol):
146- """Protocol for ONNX graphs.
226+ """Graphs.
227+
228+ Graph represents a computation graph. In addition to the fields specified in
229+ the ONNX specification, it also contains a mapping of :attr:`opset_imports`. This
230+ allows different subgraphs to import different opsets. It is the responsibility
231+ of the deserializer to reconcile the different opsets.
147232
148- Graph represents a computation graph.
233+ The :attr:`nodes` are not guaranteed to be topologically sorted, but the
234+ iteration order should be deterministic across different runs. It is the
235+ responsibility of the user to maintain a topological order of the nodes.
149236
150237 Attributes:
151238 name: The name of the graph.
@@ -158,7 +245,7 @@ class GraphProtocol(Protocol):
158245 metadata_props: Metadata.
159246 """
160247
161- # TODO(justinchuby): Support quantization_annotation and metadata_props
248+ # TODO(justinchuby): Support quantization_annotation
162249 name: str | None
163250 inputs: Sequence[ValueProtocol]
164251 outputs: Sequence[ValueProtocol]
@@ -175,9 +262,10 @@ def topologically_sorted_nodes(self) -> Sequence[NodeProtocol]:
175262
176263@typing.runtime_checkable
177264class ModelProtocol(Protocol):
178- """Protocol for ONNX models.
265+ """Models.
179266
180- A model is a container for a graph and metadata.
267+ A model is a container for a graph and metadata. It is the top-level object
268+ that represents an ONNX model.
181269
182270 Attributes:
183271 graph: The graph of the model.
@@ -225,6 +313,8 @@ class AttributeProtocol(Protocol):
225313class ReferenceAttributeProtocol(Protocol):
226314 """Protocol for a reference attribute.
227315
316+ A reference attribute can only appear inside the definition body of a function.
317+
228318 Attributes:
229319 name: The name of the attribute.
230320 ref_attr_name: The name of the attribute definition this attribute refers to.
@@ -312,6 +402,11 @@ def __eq__(self, __value: object) -> bool: ...
312402
313403@typing.runtime_checkable
314404class MapTypeProtocol(Protocol):
405+ """Protocol for ONNX map types.
406+
407+ TODO: This protocol is not yet implemented in the ONNX IR.
408+ """
409+
315410 key_type: typing.Literal[
316411 _enums.DataType.STRING,
317412 _enums.DataType.INT64,
@@ -330,6 +425,9 @@ class MapTypeProtocol(Protocol):
330425class FunctionProtocol(Protocol):
331426 """Protocol for ONNX functions.
332427
428+ Like a graph, a function can have nodes that are not topologically sorted. It is
429+ the responsibility of the user to maintain a topological order of the nodes.
430+
333431 Attributes:
334432 name: The function name.
335433 domain: The domain this function is defined in.
@@ -350,10 +448,6 @@ class FunctionProtocol(Protocol):
350448 attributes: OrderedDict[str, AttributeProtocol]
351449 outputs: Sequence[ValueProtocol]
352450 doc_string: str
353- # opset_import is stored in a model, not a graph. However,
354- # In ONNX IR we store it in a graph to unify it with
355- # the function. This way a materialized function can still
356- # be used as a subgraph even if it imports a different opset.
357451 opset_imports: Mapping[str, int]
358452 nodes: Sequence[NodeProtocol]
359453 metadata_props: Mapping[str, str]
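
The TensorProtocol docstring in this diff recommends checking the size of a tensor before touching its data, and the protocols are all marked `@typing.runtime_checkable`. The sketch below illustrates both ideas with stand-ins only: `SizedTensorLike`, `LazyBytesTensor`, and `read_if_small` are hypothetical names that are not part of the ONNX IR, and the protocol is reduced to the `size`, `nbytes`, and `tobytes` members so that the example runs with the standard library alone.

```python
from __future__ import annotations

import typing
from typing import Protocol


@typing.runtime_checkable
class SizedTensorLike(Protocol):
    """Hypothetical, reduced stand-in for TensorProtocol (not the real interface)."""

    @property
    def size(self) -> int: ...

    @property
    def nbytes(self) -> int: ...

    def tobytes(self) -> bytes: ...


class LazyBytesTensor:
    """Hypothetical tensor whose payload is only produced on demand.

    It does not inherit from SizedTensorLike; it satisfies the protocol
    structurally, which is what decouples tools from the implementation.
    """

    def __init__(self, num_elements: int, itemsize: int) -> None:
        self._num_elements = num_elements
        self._itemsize = itemsize
        self.loads = 0  # counts how often the payload was materialized

    @property
    def size(self) -> int:
        return self._num_elements

    @property
    def nbytes(self) -> int:
        return self._num_elements * self._itemsize

    def tobytes(self) -> bytes:
        # Stands in for a potentially expensive read (e.g. from disk).
        self.loads += 1
        return bytes(self.nbytes)


def read_if_small(tensor: SizedTensorLike, max_bytes: int) -> bytes | None:
    """Return the payload only when the cheap nbytes check says it is small."""
    if tensor.nbytes > max_bytes:
        return None
    return tensor.tobytes()


small = LazyBytesTensor(num_elements=4, itemsize=4)
big = LazyBytesTensor(num_elements=1_000_000, itemsize=4)

# runtime_checkable protocols only verify that the members exist, not their types.
print(isinstance(small, SizedTensorLike))
print(read_if_small(small, max_bytes=1024) is not None, small.loads)
print(read_if_small(big, max_bytes=1024) is None, big.loads)
```

Because the size check reads only metadata, the large tensor's payload is never materialized. A real tool would make the same check against `TensorProtocol.size` or `TensorProtocol.nbytes` before calling `numpy()` or `tobytes()`.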