
Commit 125ce9f

[Migration][DO NOT MERGE] Support for multi output pattern matching
Source PR: https://github.com/microsoft/onnx-rewriter/pull/288

Co-authored-by: Xavier Dupré <xadupre@users.noreply.github.com>
Co-authored-by: Ti-Tai Wang <titaiwang@microsoft.com>

ghstack-source-id: 85ea860
Pull Request resolved: #1343
1 parent 4afe3bf commit 125ce9f

File tree

8 files changed: +2075, -45 lines changed


.lintrunner.toml

Lines changed: 1 addition & 0 deletions
@@ -61,6 +61,7 @@ exclude_patterns = [
     'onnxscript/_legacy_ir/protobuilder.py', # FIXME
     'onnxscript/rewriter/onnxruntime/transformers/layernorm.py', # FIXME
     'onnxscript/ir/serde.py', # FIXME
+    'onnxrewriter/rewriter/pattern/generic_pattern_test.py', # FIXME
 ]
 command = [
     'python',

examples/pattern_rewriting.py

Lines changed: 191 additions & 0 deletions
@@ -0,0 +1,191 @@
"""Onnx Pattern Rewriting.

This script shows how to define a rewriting rule based on patterns.
The objective is to replace some nodes in an onnx model with another,
more efficient sequence of nodes.

First a dummy model
===================
"""

import numpy as np
import onnx
import onnx.helper as oh
import onnx.numpy_helper as onh

import onnxscript
import onnxscript._legacy_ir as oir
import onnxscript.rewriter.generic_pattern as org


def get_rotary_model(bad_model=False):
    inputs = [
        oh.make_tensor_value_info("x", onnx.TensorProto.INT64, shape=[]),
        oh.make_tensor_value_info("pos_ids", onnx.TensorProto.FLOAT, shape=[]),
        oh.make_tensor_value_info("axis", onnx.TensorProto.INT64, shape=[]),
    ]
    nodes = [
        oh.make_node("Unsqueeze", ["x", "axis"], ["_onx_unsqueeze0"]),
        oh.make_node("Cast", ["_onx_unsqueeze0"], ["_onx_cast0"], to=1),
        oh.make_node("MatMul", ["pos_ids", "_onx_cast0"], ["_onx_matmul0"]),
        oh.make_node("Transpose", ["_onx_matmul0"], ["_onx_transpose0"]),
        oh.make_node(
            "ConcatTrainingBad" if bad_model else "ConcatTraining",
            ["_onx_transpose0", "_onx_transpose0"],
            ["_onx_concattraining0", "_onx_concattraining1"],
            domain="com.microsoft",
        ),
        oh.make_node("Sin", ["_onx_concattraining0"], ["_onx_sin0"]),
        oh.make_node("Cast", ["_onx_sin0"], ["_onx_cast02"], to=1),
        oh.make_node("Cos", ["_onx_concattraining0"], ["_onx_cos0"]),
        oh.make_node("Cast", ["_onx_cos0"], ["_onx_cast03"], to=1),
    ]
    outputs = [
        oh.make_tensor_value_info("_onx_cast02", onnx.TensorProto.UNDEFINED, []),
        oh.make_tensor_value_info("_onx_cast03", onnx.TensorProto.UNDEFINED, []),
    ]
    model = oh.make_model(
        oh.make_graph(
            nodes,
            "experiment",
            inputs,
            outputs,
        ),
        opset_imports=[
            oh.make_opsetid("", 18),
            oh.make_opsetid("com.microsoft", 18),
        ],
    )
    return model


model = get_rotary_model()
ir_model = oir.irbuilder.build_ir(model)


####################################
# The rewriting pattern
# =====================

op = onnxscript.opset18
msft_op = onnxscript.values.Opset("com.microsoft", 1)


def rotary_match_pattern(x, pos_ids, axis):
    """The pattern to match."""
    unsqueeze = op.Unsqueeze(x, axis)
    cast = op.Cast(unsqueeze, to=onnx.TensorProto.FLOAT)

    matmul = op.MatMul(pos_ids, cast)
    transpose = op.Transpose(matmul)
    output, length = msft_op.ConcatTraining(transpose, transpose)

    sin = op.Sin(output)
    cast1 = op.Cast(sin, to=onnx.TensorProto.FLOAT)
    cos = op.Cos(output)
    cast2 = op.Cast(cos, to=onnx.TensorProto.FLOAT)
    return cast1, cast2


def validate_rotary_mapping(g, matched_nodes, added_nodes) -> bool:
    """The validation run after matching.

    Returns True to validate the replacement,
    False not to apply it.

    :param g: model
    :param matched_nodes: matched nodes
    :param added_nodes: nodes replacing the matched nodes
    """
    del g
    del matched_nodes
    del added_nodes
    return True


def rotary_apply_pattern(x, pos_ids, axis):
    """The replacement pattern."""
    cos_cache = op.Constant(value=onh.from_array(np.random.rand(256, 256).astype(np.float16)))
    sin_cache = op.Constant(value=onh.from_array(np.random.rand(256, 256).astype(np.float16)))
    part1, part2 = msft_op.RotaryEmbedding(x, pos_ids, cos_cache, sin_cache)
    return part1, part2


###########################
# The rule
# ========
#
# The rule is easy to create.


rule = org.make_pattern_rule(
    rotary_match_pattern,
    rotary_apply_pattern,
    validate_rotary_mapping,
)

################################
# ``validate_rotary_mapping`` always returns True.
# In that case, this argument can be omitted.

rule = org.make_pattern_rule(rotary_match_pattern, rotary_apply_pattern)

##########################
# Let's apply it.
rule.apply_to_model(ir_model)


########################
# And finally, we can generate the model.

opt_onx = oir.protobuilder.build_model_proto(ir_model)

########################
# Let's see what it looks like.

for node in opt_onx.graph.node:
    print(f"{node.op_type}({', '.join(node.input)}) -> {', '.join(node.output)}")

#############################
# What if it fails?
# =================


model = get_rotary_model(True)
ir_model = oir.irbuilder.build_ir(model)

rule.apply_to_model(ir_model)
opt_onx = oir.protobuilder.build_model_proto(ir_model)

print([n.op_type for n in opt_onx.graph.node])

################################
# The match did not happen.
# Let's increase the verbosity.

rule = org.make_pattern_rule(rotary_match_pattern, rotary_apply_pattern, verbose=10)

rule.apply_to_model(ir_model)

######################################
# The logs show every time the algorithm rejected a pattern.
# We can see the following:
#
# ::
#
#   [OnnxGenericPattern.match] NONE - line: 673:onnxscript.rewriter.generic_pattern, op_type=Cast
#       --hint--: BACKWARD: different node types
#       --pattern
#       ConcatTraining(transpose, transpose) -> (output, length)
#       -- model
#       ConcatTrainingBad(_onx_transpose0, _onx_transpose0) -> (_onx_concattraining0, _onx_concattraining1)
#       iteration=1
#       --marked-- #2
#       Cast(_onx_cos0) ~ Cast(cos) [140186194226496-140186194222320]
#       Cos(_onx_concattraining0) ~ Cos(output) [140186194230816-140186194223472]
#       len(stacked)=0:[]
#
# At line 673 in `generic_pattern.py`, the match was rejected.
# It says that, while comparing two nodes in the backward direction,
# the node types do not match.
# It also says that two nodes had already been matched at that point.
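
As a quick sanity check (an illustrative addition, not part of the committed example), one can scan the op types of the rebuilt proto: with the good model the fused RotaryEmbedding node should appear, while with the bad model the original nodes remain.

op_types = [node.op_type for node in opt_onx.graph.node]
if "RotaryEmbedding" in op_types:
    print("rewrite applied:", op_types)
else:
    print("rewrite did not apply:", op_types)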

onnxscript/_legacy_ir/__init__.py

Lines changed: 41 additions & 3 deletions
@@ -224,6 +224,14 @@ def __str__(self) -> str:
             ]
         )
 
+    @property
+    def input_names(self) -> list[str]:
+        return [_.name for _ in self.original_graph_proto.input]
+
+    @property
+    def output_names(self) -> list[str]:
+        return [_.name for _ in self.original_graph_proto.output]
+
 
 class Function:
     def __init__(self, function_proto: onnx.FunctionProto):
@@ -272,15 +280,45 @@ def to_proto(self) -> onnx.AttributeProto:
 
 
 class Node:
-    def __init__(self, node_proto: onnx.NodeProto) -> None:
+    def __init__(
+        self,
+        node_proto: onnx.NodeProto,
+        populate_io: bool = False,
+    ) -> None:
         self.original_node_proto = node_proto
         self.domain: str = node_proto.domain
         self.version: int | None = None
         self.op_type: str = node_proto.op_type
-        self.inputs: list[Value | None] = []
-        self.outputs: list[Value | None] = []
+        if populate_io:
+            self.inputs: list[Value | None] = [Value(i) for i in node_proto.input]
+            self.outputs: list[Value | None] = [Value(i) for i in node_proto.output]
+        else:
+            self.inputs: list[Value | None] = []
+            self.outputs: list[Value | None] = []
         self.attributes: dict[str, int | float | RefAttr | Graph | list[Graph]] = {}
 
+    def __repr__(self) -> str:
+        return (
+            f"{self.op_type}({','.join(self.original_node_proto.input)})"
+            f"->{','.join(self.original_node_proto.output)}"
+        )
+
+    @property
+    def name(self) -> str:
+        return self.original_node_proto.name
+
+    @property
+    def input_names(self):
+        return self.original_node_proto.input
+
+    @property
+    def output_names(self):
+        return self.original_node_proto.output
+
+    @property
+    def attribute(self):
+        return self.original_node_proto.attribute
+
     def get_attribute(self, name: str) -> int | float | None:
         return self.attributes.get(name, None)
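
For illustration only (not part of this commit's diff), a minimal sketch of how the new Node conveniences might be exercised; it assumes Node is importable from onnxscript._legacy_ir exactly as defined in the hunk above.

import onnx.helper as oh
import onnxscript._legacy_ir as oir

proto = oh.make_node("Cast", ["_onx_sin0"], ["_onx_cast02"], to=1)
node = oir.Node(proto)            # populate_io=True would additionally pre-create Value objects
print(repr(node))                 # Cast(_onx_sin0)->_onx_cast02
print(list(node.input_names))     # ['_onx_sin0']
print(list(node.output_names))    # ['_onx_cast02']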

onnxscript/_legacy_ir/irbuilder.py

Lines changed: 3 additions & 3 deletions
@@ -123,14 +123,14 @@ def process_initializer(self, init: onnx.TensorProto):
     def process_node(self, node):
         node_ir = ir.Node(node)
         self.current_graph_or_function.nodes.append(node_ir)
-        for input in node.input:
-            value = self.lookup(input)
+        for name in node.input:
+            value = self.lookup(name)
             node_ir.inputs.append(value)
             if value is not None:
                 value.uses.append(node_ir)
             else:
                 # TODO(titaiwang): Do something more than warnings?
-                warnings.warn(f"Use of undefined variable '{input}'.", stacklevel=1)
+                warnings.warn(f"Use of undefined variable {name!r}.", stacklevel=1)
         for index, output in enumerate(node.output):
            newvalue = ir.Value(name=output, node=node_ir, output_index=index)
            if self._current_function is not None:
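
As a side note (an illustrative sketch, assuming undefined names are only warned about and do not abort the build), the warning path above can be exercised by building IR for a graph whose node reads a name that is never produced.

import warnings

import onnx
import onnx.helper as oh

import onnxscript._legacy_ir as oir

# "missing" is neither a graph input nor another node's output, so the
# builder should emit the "Use of undefined variable" warning shown above.
graph = oh.make_graph(
    [oh.make_node("Identity", ["missing"], ["y"])],
    "demo",
    [oh.make_tensor_value_info("x", onnx.TensorProto.FLOAT, [1])],
    [oh.make_tensor_value_info("y", onnx.TensorProto.FLOAT, [1])],
)
model = oh.make_model(graph, opset_imports=[oh.make_opsetid("", 18)])
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    oir.irbuilder.build_ir(model)
print([str(w.message) for w in caught])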
