Skip to content

Commit 7b89760

Browse files
authored
Use onnx_ir common passes (#2420)
Update the imports to use onnx_ir instead of the alias Signed-off-by: Justin Chu <[email protected]>
1 parent b8a831e commit 7b89760

39 files changed

+95
-89
lines changed

docs/tutorial/rewriter/conditional_rewrite.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ Similarly for writing the condition checking function, we require only `input_a`
3232
:::
3333

3434
In order to validate whether matmul broadcast is sufficient, we write a condition checking function as below.
35-
Note that the relevant inputs passed to the check function are all instances of :class:`onnx_ir.Value`. These represent
35+
Note that the relevant inputs passed to the check function are all instances of {py:class}`onnx_ir.Value`. These represent
3636
the values in the input graph IR that matched against the corresponding _pattern variables_ in the target
3737
pattern. Please see documentation of the [IR API](https://onnx.ai/ir-py/) for more details on how to use it, for example to identify
3838
the type or shape or rank of these values.
@@ -50,4 +50,3 @@ With all the necessary components in place, the pattern rewrite rule with the `m
5050
The final graph with the applied rewrite looks as follows:
5151

5252
![broadcast_rewrite](examples/img/broadcast_02.png){align=center}
53-

docs/tutorial/rewriter/simple_example.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ After this, create a replacement pattern that consists of the GELU onnxscript op
3333
:::{note}
3434
:name: type annotate ir.Value
3535

36-
The inputs to the replacement pattern are of type `ir.Value`. For detailed usage of `ir.Value` refer to the {py:class}`ir.Value <onnxscript.ir._core.Value>` class.
36+
The inputs to the replacement pattern are of type `ir.Value`. For detailed usage of `ir.Value` refer to the {py:class}`ir.Value <onnx_ir.Value>` class.
3737
:::
3838

3939

onnxscript/ir/_schemas_test.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import parameterized
99

1010
import onnxscript
11-
import onnxscript.testing
1211
from onnxscript import FLOAT, INT64, ir
1312
from onnxscript.ir import _schemas
1413

onnxscript/onnx_types.py

Lines changed: 25 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,9 @@
77
from typing import ClassVar, Optional, Tuple, Union
88

99
import onnx
10+
import onnx_ir as ir
1011

11-
import onnxscript.ir
12-
13-
_DType = onnxscript.ir.DataType
12+
_DType = ir.DataType
1413
_DimType = Union[int, str, type(None)]
1514
_ShapeType = Union[Tuple[_DimType, ...], _DimType, type(Ellipsis)]
1615

@@ -105,95 +104,95 @@ def to_string(cls) -> str:
105104
return f"tensor({cls.__name__.lower()})"
106105

107106

108-
class FLOAT(TensorType, dtype=onnxscript.ir.DataType.FLOAT):
107+
class FLOAT(TensorType, dtype=ir.DataType.FLOAT):
109108
pass
110109

111110

112-
class UINT8(TensorType, dtype=onnxscript.ir.DataType.UINT8):
111+
class UINT8(TensorType, dtype=ir.DataType.UINT8):
113112
pass
114113

115114

116-
class INT8(TensorType, dtype=onnxscript.ir.DataType.INT8):
115+
class INT8(TensorType, dtype=ir.DataType.INT8):
117116
pass
118117

119118

120-
class UINT16(TensorType, dtype=onnxscript.ir.DataType.UINT16):
119+
class UINT16(TensorType, dtype=ir.DataType.UINT16):
121120
pass
122121

123122

124-
class INT16(TensorType, dtype=onnxscript.ir.DataType.INT16):
123+
class INT16(TensorType, dtype=ir.DataType.INT16):
125124
pass
126125

127126

128-
class INT32(TensorType, dtype=onnxscript.ir.DataType.INT32):
127+
class INT32(TensorType, dtype=ir.DataType.INT32):
129128
pass
130129

131130

132-
class INT64(TensorType, dtype=onnxscript.ir.DataType.INT64):
131+
class INT64(TensorType, dtype=ir.DataType.INT64):
133132
pass
134133

135134

136-
class STRING(TensorType, dtype=onnxscript.ir.DataType.STRING):
135+
class STRING(TensorType, dtype=ir.DataType.STRING):
137136
pass
138137

139138

140-
class BOOL(TensorType, dtype=onnxscript.ir.DataType.BOOL):
139+
class BOOL(TensorType, dtype=ir.DataType.BOOL):
141140
pass
142141

143142

144-
class FLOAT16(TensorType, dtype=onnxscript.ir.DataType.FLOAT16):
143+
class FLOAT16(TensorType, dtype=ir.DataType.FLOAT16):
145144
pass
146145

147146

148-
class DOUBLE(TensorType, dtype=onnxscript.ir.DataType.DOUBLE):
147+
class DOUBLE(TensorType, dtype=ir.DataType.DOUBLE):
149148
pass
150149

151150

152-
class UINT32(TensorType, dtype=onnxscript.ir.DataType.UINT32):
151+
class UINT32(TensorType, dtype=ir.DataType.UINT32):
153152
pass
154153

155154

156-
class UINT64(TensorType, dtype=onnxscript.ir.DataType.UINT64):
155+
class UINT64(TensorType, dtype=ir.DataType.UINT64):
157156
pass
158157

159158

160-
class COMPLEX64(TensorType, dtype=onnxscript.ir.DataType.COMPLEX64):
159+
class COMPLEX64(TensorType, dtype=ir.DataType.COMPLEX64):
161160
pass
162161

163162

164-
class COMPLEX128(TensorType, dtype=onnxscript.ir.DataType.COMPLEX128):
163+
class COMPLEX128(TensorType, dtype=ir.DataType.COMPLEX128):
165164
pass
166165

167166

168-
class BFLOAT16(TensorType, dtype=onnxscript.ir.DataType.BFLOAT16):
167+
class BFLOAT16(TensorType, dtype=ir.DataType.BFLOAT16):
169168
pass
170169

171170

172-
class FLOAT8E4M3FN(TensorType, dtype=onnxscript.ir.DataType.FLOAT8E4M3FN):
171+
class FLOAT8E4M3FN(TensorType, dtype=ir.DataType.FLOAT8E4M3FN):
173172
pass
174173

175174

176-
class FLOAT8E4M3FNUZ(TensorType, dtype=onnxscript.ir.DataType.FLOAT8E4M3FNUZ):
175+
class FLOAT8E4M3FNUZ(TensorType, dtype=ir.DataType.FLOAT8E4M3FNUZ):
177176
pass
178177

179178

180-
class FLOAT8E5M2(TensorType, dtype=onnxscript.ir.DataType.FLOAT8E5M2):
179+
class FLOAT8E5M2(TensorType, dtype=ir.DataType.FLOAT8E5M2):
181180
pass
182181

183182

184-
class FLOAT8E5M2FNUZ(TensorType, dtype=onnxscript.ir.DataType.FLOAT8E5M2FNUZ):
183+
class FLOAT8E5M2FNUZ(TensorType, dtype=ir.DataType.FLOAT8E5M2FNUZ):
185184
pass
186185

187186

188-
class INT4(TensorType, dtype=onnxscript.ir.DataType.INT4):
187+
class INT4(TensorType, dtype=ir.DataType.INT4):
189188
pass
190189

191190

192-
class UINT4(TensorType, dtype=onnxscript.ir.DataType.UINT4):
191+
class UINT4(TensorType, dtype=ir.DataType.UINT4):
193192
pass
194193

195194

196-
class FLOAT4E2M1(TensorType, dtype=onnxscript.ir.DataType.FLOAT4E2M1):
195+
class FLOAT4E2M1(TensorType, dtype=ir.DataType.FLOAT4E2M1):
197196
pass
198197

199198

onnxscript/optimizer/__init__.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
]
1616

1717
import onnx
18+
import onnx_ir.passes.common as common_passes
1819

19-
import onnxscript.ir.passes.common
2020
import onnxscript.optimizer._constant_folding as constant_folding
2121
from onnxscript import ir
2222
from onnxscript.optimizer._constant_folding import (
@@ -90,7 +90,7 @@ def optimize(
9090
def inline(model: ir.Model) -> None:
9191
"""Inline all function calls (recursively) in the model."""
9292
if model.functions:
93-
onnxscript.ir.passes.common.InlinePass()(model)
93+
common_passes.InlinePass()(model)
9494

9595

9696
def fold_constants(
@@ -114,10 +114,10 @@ def fold_constants(
114114
def remove_unused_nodes(model: ir.Model | onnx.ModelProto) -> None:
115115
"""Removes unused nodes from a model inplace."""
116116
if isinstance(model, ir.Model):
117-
onnxscript.ir.passes.common.RemoveUnusedNodesPass()(model)
117+
common_passes.RemoveUnusedNodesPass()(model)
118118
else:
119119
model_ir = ir.serde.deserialize_model(model)
120-
model_ir = onnxscript.ir.passes.common.RemoveUnusedNodesPass()(model_ir).model
120+
model_ir = common_passes.RemoveUnusedNodesPass()(model_ir).model
121121
new_proto = ir.serde.serialize_model(model_ir)
122122
model.Clear()
123123
model.CopyFrom(new_proto)
@@ -126,10 +126,10 @@ def remove_unused_nodes(model: ir.Model | onnx.ModelProto) -> None:
126126
def remove_unused_functions(model: ir.Model | onnx.ModelProto) -> None:
127127
"""Removes unused functions from a model inplace."""
128128
if isinstance(model, ir.Model):
129-
onnxscript.ir.passes.common.RemoveUnusedFunctionsPass()(model)
129+
common_passes.RemoveUnusedFunctionsPass()(model)
130130
else:
131131
model_ir = ir.serde.deserialize_model(model)
132-
model_ir = onnxscript.ir.passes.common.RemoveUnusedFunctionsPass()(model_ir).model
132+
model_ir = common_passes.RemoveUnusedFunctionsPass()(model_ir).model
133133
new_proto = ir.serde.serialize_model(model_ir)
134134
model.Clear()
135135
model.CopyFrom(new_proto)

onnxscript/optimizer/_constant_folding.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414
import numpy as np
1515
import onnx
1616
import onnx.reference.ops
17+
import onnx_ir as ir
1718

18-
import onnxscript.ir as ir
1919
import onnxscript.utils.utils as utils
2020
from onnxscript.ir import _tape
2121

onnxscript/optimizer/_optimizer.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44

55
import logging
66

7-
import onnxscript.ir.passes.common
7+
import onnx_ir.passes.common as common_passes
8+
89
from onnxscript import ir, rewriter
910
from onnxscript.optimizer import _constant_folding
1011

@@ -43,21 +44,21 @@ def optimize_ir(
4344
output_size_limit=output_size_limit,
4445
),
4546
rewriter.RewritePass(rewriter._DEFAULT_REWRITE_RULES),
46-
onnxscript.ir.passes.common.RemoveUnusedNodesPass(),
47-
onnxscript.ir.passes.common.RemoveUnusedFunctionsPass(),
48-
onnxscript.ir.passes.common.RemoveUnusedOpsetsPass(),
47+
common_passes.RemoveUnusedNodesPass(),
48+
common_passes.RemoveUnusedFunctionsPass(),
49+
common_passes.RemoveUnusedOpsetsPass(),
4950
],
5051
steps=num_iterations,
5152
early_stop=stop_if_no_change,
5253
),
53-
onnxscript.ir.passes.common.RemoveUnusedNodesPass(),
54-
onnxscript.ir.passes.common.CommonSubexpressionEliminationPass(),
55-
onnxscript.ir.passes.common.LiftConstantsToInitializersPass(),
56-
onnxscript.ir.passes.common.LiftSubgraphInitializersToMainGraphPass(),
54+
common_passes.RemoveUnusedNodesPass(),
55+
common_passes.CommonSubexpressionEliminationPass(),
56+
common_passes.LiftConstantsToInitializersPass(),
57+
common_passes.LiftSubgraphInitializersToMainGraphPass(),
5758
]
5859
if inline:
5960
# Inline all functions first before optimizing
60-
passes = [onnxscript.ir.passes.common.InlinePass(), *passes]
61+
passes = [common_passes.InlinePass(), *passes]
6162
optimizer_pass = ir.passes.Sequential(*passes)
6263
assert optimizer_pass.in_place
6364
result = optimizer_pass(model)

onnxscript/optimizer/_optimizer_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
import unittest
55

66
import onnx
7+
import onnx_ir as ir
78

8-
import onnxscript.ir as ir
99
import onnxscript.optimizer as optimizer
1010

1111

onnxscript/rewriter/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
]
1212

1313
import onnx
14+
import onnx_ir.passes.common as common_passes
1415

15-
import onnxscript.ir.passes.common as common_passes
1616
from onnxscript import ir
1717
from onnxscript.rewriter import (
1818
basic_rules,

onnxscript/rewriter/_fusion_utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44

55
from typing import Callable, Sequence, Union
66

7-
import onnxscript.ir as ir
8-
import onnxscript.ir.passes.common as common_passes
7+
import onnx_ir as ir
8+
import onnx_ir.passes.common as common_passes
9+
910
from onnxscript.rewriter import pattern
1011
from onnxscript.rewriter._basics import MatchFailureError
1112

0 commit comments

Comments
 (0)