Commit 1ceeeca

[Migration][DO NOT MERGE] Separate old ir into _legacy_ir folder
All tests except for linter are expected to pass.

ghstack-source-id: 3423e36
Pull Request resolved: #1332
1 parent 81e1110 commit 1ceeeca

22 files changed: 316 additions, 19 deletions

onnxscript/_legacy_ir/__init__.py

Lines changed: 297 additions & 0 deletions
@@ -0,0 +1,297 @@
from __future__ import annotations

import dataclasses
from collections import deque
from typing import List, Tuple, Union

import numpy as np
import onnx


class Unknown:
    """A special value used to indicate that a value is not a statically known constant.

    We use this instead of None because None is a valid constant value (since ONNX
    supports the Optional type).
    """

    instance = None

    def __init__(self) -> None:
        if Unknown.instance is not None:
            raise ValueError("Unknown.instance is already set")
        Unknown.instance = self


# Singleton instance of Unknown
unknown = Unknown()
NotConstant = unknown

# ConcreteValue: This type represents constant values that an ONNX variable can take.
# TODO: Extend this to a recursive type to handle lists of tensors, etc., support optionals,
# maps, etc.
# TODO (rama): The value is sometimes stored as a numpy array, and sometimes as an ONNX TensorProto.
# A uniform representation would be helpful, but we should avoid unnecessary conversions for
# large tensors. Should be cleaned up in the new IR.
ConcreteValue = Union[onnx.TensorProto, np.ndarray, Unknown, None]

# SymbolicValue: This information is used to enable partial-evaluation and specialization
# of sequence operations, as well as elimination of redundant Identity ops.
# The symbolic value of a variable X can be:
# - a string with the value "Y", indicating that "X" is a copy of "Y"
# - a list of strings, indicating that "X" is a list of tensors, with their symbolic values
# Eg., the symbolic value ["A", "B", "C"] indicates that the value of X is equal to
# "SequenceConstruct(A, B, C)".
# TODO: Technically, SymbolicValue should be a recursive type to handle lists of lists of
# tensors, etc. However, we currently only handle lists of tensors.

SymbolicValue = Union[str, List[str]]

FunctionId = Tuple[str, str, str]


def get_function_id(function: onnx.FunctionProto) -> FunctionId:
    return (function.domain, function.name, getattr(function, "overload", ""))


def get_function_id_from_node(node: onnx.NodeProto) -> FunctionId:
    return (node.domain, node.op_type, getattr(node, "overload", ""))


@dataclasses.dataclass
class StaticValueInfo:
    name: str
    value: ConcreteValue = NotConstant
    type: onnx.TypeProto | None = None
    symbolic_value: SymbolicValue | None = None

    def is_copy(self) -> bool:
        return isinstance(self.symbolic_value, str)

    def tensor_shape_proto(self) -> onnx.TensorShapeProto | None:
        """Returns the shape of a tensor or None.

        A return value of None could mean that the type is unknown or that the type is not a tensor
        or that the tensor shape (that is, even the rank) is unknown.
        """
        type = self.type
        if type and type.HasField("tensor_type") and type.tensor_type.HasField("shape"):
            return type.tensor_type.shape
        return None

    @property
    def shape(self) -> list[str | int | None] | None:
        """Returns the shape in a list.

        Str means that the shape is dynamic.
        """
        type = self.type
        if type and type.HasField("tensor_type") and type.tensor_type.HasField("shape"):
            dims = []
            for dim in type.tensor_type.shape.dim:
                if dim.HasField("dim_param"):
                    dims.append(dim.dim_param)
                elif dim.HasField("dim_value"):
                    dims.append(dim.dim_value)
                else:
                    dims.append(None)
            return dims
        if self.value_as_np_array is not None:
            return list(self.value_as_np_array.shape)
        return None

    @property
    def element_type(self) -> int | None:
        """Returns the element type of a tensor, or None if type is not known or is not a tensor."""
        type = self.type
        if type and type.HasField("tensor_type"):
            return type.tensor_type.elem_type
        return None

    def identity_merge_from(self, other: StaticValueInfo) -> None:
        """Merge the value of other into self.

        This models the effect of an identity (copy) operation.
        This will update static-analysis information based on incoming value.
        """
        if not isinstance(other, StaticValueInfo):
            raise TypeError(f"Cannot merge {other} into {self}.")
        if other.value is not NotConstant:
            self.value = other.value
        # TODO: merge and combine best shape information from both types.
        if other.tensor_shape_proto() is not None and other.element_type is not None:
            self.type = other.type
        # We cannot copy symbolic value across different scopes.

    # WIP: Extensions towards new IR: Note that the default construction of StaticValueInfo
    # does not fill in the following fields. These fields are filled in by the IRBuilder
    # which constructs the IR from the ONNX model.
    node: Node | None = None
    uses: list[Node] = dataclasses.field(default_factory=list)
    output_index: int | None = None
    is_output: bool = False

    @property
    def const_value(self) -> ConcreteValue:
        return self.value

    @property
    def value_as_np_array(self) -> np.ndarray | None:
        if isinstance(self.value, np.ndarray):
            return self.value
        if isinstance(self.value, onnx.TensorProto):
            return onnx.numpy_helper.to_array(self.value)
        return None

    def def_node(self) -> Node | None:
        return self.node

    def def_index(self) -> int:
        return self.output_index

    def is_same_as(self, other: StaticValueInfo) -> bool:
        """Returns true if this value represents the same IR object as the other value.

        This is *not* value-equality, but rather object-equality.
        """
        return self is other

    def __str__(self) -> str:
        shape = self.shape
        if shape is not None:
            shape = [str(dim) for dim in shape]
            shape_str = f"[{', '.join(shape)}]"
        else:
            shape_str = "None"
        return (
            f"StaticValueInfo({self.name}, shape:{shape_str}, dtype:{self.element_type}, "
            f"{'has const value' if self.value is not unknown else 'no const value'}.)"
        )


Value = StaticValueInfo


class Model:
    def __init__(self) -> None:
        self.gen_var_counter: int = 0

    def set(
        self,
        model_proto: onnx.ModelProto,
        graph: Graph,
        functions: list[Function],
        version_map: dict[str, int],
    ) -> None:
        """TODO. This is a temporary patch."""
        self.original_model_proto = model_proto
        self.graph = graph
        self.functions = functions
        self.version_map = version_map

    def make_new_name(self):
        # Temporary hack.
        self.gen_var_counter += 1
        return f"_gen_{self.gen_var_counter}"

    def __str__(self) -> str:
        # TODO: Naive string representation for debugging. Need to improve this.
        return "\n".join(
            [
                f"ModelGraph: {self.graph}",
                f"Functions: {self.functions}",
                f"VersionMap: {self.version_map}",
            ]
        )


class Graph:
    def __init__(self, graph_proto: onnx.GraphProto):
        self.original_graph_proto = graph_proto
        self.nodes: deque[Node] = deque()
        self.values: dict[str, Value] = {}

    @property
    def name(self) -> str:
        return self.original_graph_proto.name

    def __str__(self) -> str:
        return "\n".join(
            [
                "Graph",
                f"Nodes: {[str(n) for n in self.nodes]}",
                f"Values: {[str(v) for v in self.values]}",
            ]
        )


class Function:
    def __init__(self, function_proto: onnx.FunctionProto):
        self.original_function_proto = function_proto
        self.nodes = deque()
        self.values = {}

    @property
    def id(self) -> FunctionId:
        return (self.domain, self.name, self.overload)

    @property
    def domain(self) -> str:
        return self.original_function_proto.domain

    @property
    def name(self) -> str:
        return self.original_function_proto.name

    @property
    def overload(self) -> str:
        return getattr(self.original_function_proto, "overload", "")

    def __str__(self) -> str:
        return "\n".join(
            [
                "Function",
                f"Nodes: {[str(n) for n in self.nodes]}",
                f"Values: {[str(v) for v in self.values]}",
            ]
        )


class RefAttr:
    def __init__(self, name: str, ref_attr_name: str, type) -> None:
        self.name = name
        self.ref_attr_name = ref_attr_name
        self.type = type

    def to_proto(self) -> onnx.AttributeProto:
        attr_proto = onnx.AttributeProto()
        attr_proto.name = self.name
        attr_proto.ref_attr_name = self.ref_attr_name
        attr_proto.type = self.type
        return attr_proto


class Node:
    def __init__(self, node_proto: onnx.NodeProto) -> None:
        self.original_node_proto = node_proto
        self.domain: str = node_proto.domain
        self.version: int | None = None
        self.op_type: str = node_proto.op_type
        self.inputs: list[Value | None] = []
        self.outputs: list[Value | None] = []
        self.attributes: dict[str, int | float | RefAttr | Graph | list[Graph]] = {}

    def get_attribute(self, name: str) -> int | float | None:
        return self.attributes.get(name, None)

    def __str__(self) -> str:
        return "\n".join(
            [
                "Node",
                f"OpType: {self.op_type}",
                f"Inputs: {self.inputs}",
                f"Outputs: {self.outputs}",
                f"Attributes: {self.attributes}",
            ]
        )
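
The comments above spell out the intended semantics of Unknown/NotConstant, ConcreteValue, and SymbolicValue. As a minimal illustration (not part of this commit), the sketch below exercises only the definitions in this file; the value names and tensor contents are made up, and it assumes the post-migration import path onnxscript._legacy_ir:

import numpy as np

from onnxscript._legacy_ir import NotConstant, StaticValueInfo, unknown

# "X" carries a statically known constant; "Y" starts out unknown.
x = StaticValueInfo("X", value=np.array([1.0, 2.0], dtype=np.float32))
y = StaticValueInfo("Y")  # value defaults to NotConstant (the Unknown singleton)
assert y.value is NotConstant is unknown

# Model Y = Identity(X): Y inherits X's constant value (and type, when one is set).
y.identity_merge_from(x)
assert y.value_as_np_array is not None
assert y.shape == [2]  # falls back to the numpy value's shape when no type is set

# Symbolic values: W is a plain copy of X; Z behaves like SequenceConstruct(A, B, C).
w = StaticValueInfo("W", symbolic_value="X")
z = StaticValueInfo("Z", symbolic_value=["A", "B", "C"])
assert w.is_copy() and not z.is_copy()

print(y)  # StaticValueInfo(Y, shape:[2], dtype:None, has const value.)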

onnxscript/ir/irbuilder.py renamed to onnxscript/_legacy_ir/irbuilder.py

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@
import onnx

from onnxscript import ir
-from onnxscript.ir import visitor
+from onnxscript._legacy_ir import visitor
from onnxscript.utils import utils

""" NOTE: IRBuilder and function visiting

onnxscript/ir/irbuilder_test.py renamed to onnxscript/_legacy_ir/irbuilder_test.py

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@

import onnx.parser

-from onnxscript.ir import irbuilder
+from onnxscript._legacy_ir import irbuilder


class IRBuilderTest(unittest.TestCase):
File renamed without changes.

onnxscript/ir/protobuilder_test.py renamed to onnxscript/_legacy_ir/protobuilder_test.py

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@
import onnx.checker
import onnx.parser

-from onnxscript.ir import irbuilder, protobuilder
+from onnxscript._legacy_ir import irbuilder, protobuilder
from onnxscript.rewriter import pattern
from onnxscript.rewriter.onnxruntime import instance_to_group_normalization

File renamed without changes.

onnxscript/ir/visitor_test.py renamed to onnxscript/_legacy_ir/visitor_test.py

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@

import onnx

-from onnxscript.ir import visitor
+from onnxscript._legacy_ir import visitor


class FunctionCallsiteProtoTransformerTest(unittest.TestCase):

onnxscript/optimizer/constant_folding.py

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@
import onnx.reference.ops

from onnxscript import ir
-from onnxscript.ir import visitor
+from onnxscript._legacy_ir import visitor
from onnxscript.optimizer import evaluator
from onnxscript.utils.utils import (
    is_control_flow_op,

onnxscript/optimizer/copy_propagation.py

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@
import onnx

import onnxscript.optimizer.remove_unused
-from onnxscript.ir import visitor
+from onnxscript._legacy_ir import visitor
from onnxscript.utils.utils import is_onnx_op

onnxscript/optimizer/simple_function_folding.py

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@
import onnx

from onnxscript import ir
-from onnxscript.ir import visitor
+from onnxscript._legacy_ir import visitor
from onnxscript.optimizer import remove_unused

logger = logging.getLogger(__name__)

onnxscript/rewriter/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@

import onnx

-from onnxscript.ir import irbuilder, protobuilder
+from onnxscript._legacy_ir import irbuilder, protobuilder
from onnxscript.rewriter import function_rule, pattern

PatternRewriteRule = pattern.RewriteRule

onnxscript/rewriter/broadcast_to_matmul_test.py

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@

import onnx.parser

-from onnxscript.ir import irbuilder
+from onnxscript._legacy_ir import irbuilder
from onnxscript.rewriter import broadcast_to_matmul

onnxscript/rewriter/cast_constant_of_shape_test.py

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@

import onnx.parser

-from onnxscript.ir import irbuilder
+from onnxscript._legacy_ir import irbuilder
from onnxscript.rewriter import cast_constant_of_shape

onnxscript/rewriter/function_rule.py

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@
from packaging import version

from onnxscript import ir
-from onnxscript.ir import visitor
+from onnxscript._legacy_ir import visitor
from onnxscript.rewriter import pattern

logger = logging.getLogger(__name__)
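
Every hunk in this commit is the same one-line change: imports of visitor, irbuilder, and protobuilder move from onnxscript.ir to onnxscript._legacy_ir. For reference, a hedged sketch of what a downstream caller of the legacy classes looks like after the rename; the NodeProto built here is purely illustrative:

import onnx

from onnxscript import _legacy_ir  # new location; previously onnxscript.ir

# Wrap a hand-built NodeProto in the legacy Node class defined in _legacy_ir/__init__.py.
node_proto = onnx.helper.make_node("Identity", inputs=["X"], outputs=["Y"])
node = _legacy_ir.Node(node_proto)

print(node.op_type)                                      # Identity
print(_legacy_ir.get_function_id_from_node(node_proto))  # ('', 'Identity', '')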
