Skip to content

Commit 2c74be7

Browse files
BowenBaojustinchubyxadupregramalingamtitaiwangms
authored
Migrate onnxrewriter (#1346)
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom): * #1334 * #1340 * __->__ #1346 Squashed of the following steps: - #1328 - #1329 - #1330 - #1331 - #1332 - #1333 - #1343 - #1345 Co-authored-by: Shubham Bhokare <[email protected]> Co-authored-by: Justin Chu <[email protected]> Co-authored-by: Xavier Dupré <[email protected]> Co-authored-by: "G. Ramalingam" <[email protected]> Co-authored-by: kunal-vaishnavi <[email protected]> Co-authored-by: Ti-Tai Wang <[email protected]>
1 parent e29b43a commit 2c74be7

File tree

388 files changed

+12542
-94
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

388 files changed

+12542
-94
lines changed

.gitattributes

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
**/*.pb filter=lfs diff=lfs merge=lfs -text
2+
**/*.onnx filter=lfs diff=lfs merge=lfs -text

.github/workflows/main.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@ jobs:
6767
python-version: ${{ matrix.python-version }}
6868
- name: Install nox
6969
run: python -m pip install nox
70+
- name: Pull Test Data
71+
run: git lfs pull
7072
- name: Run tests
7173
run: nox -t ${{ matrix.nox-tag }} --forcecolor -- -v --cov=onnxscript --cov-report=xml --cov-append --cov-branch -n=auto --junit-xml pytest.xml
7274
env:

.lintrunner.toml

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ include_patterns = [
88
'**/*.pyi',
99
]
1010
exclude_patterns = [
11-
'onnxscript/tests/models/**',
11+
'tests/models/**',
1212
]
1313
command = [
1414
'python',
@@ -43,9 +43,26 @@ exclude_patterns = [
4343
'onnxscript/evaluator_test.py',
4444
'onnxscript/evaluator.py',
4545
'onnxscript/onnx_types.py',
46-
'onnxscript/tests/**', # Skip linting test files for speed
46+
'tests/**', # Skip linting test files for speed
4747
'onnxscript/**/*_test.py', # Skip linting test files for speed
4848
'onnxscript/function_libs/torch_lib/ops/**', # Operators typing do not play well with mypy
49+
'onnxscript/optimizer/evaluator.py', # FIXME
50+
'onnxscript/optimizer/constant_folding.py', # FIXME
51+
'onnxscript/_legacy_ir/__init__.py', # FIXME
52+
'onnxscript/rewriter/onnxruntime/transformers/fastgelu.py', # FIXME
53+
'onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py', # FIXME
54+
'onnxscript/rewriter/function_rule.py', # FIXME
55+
'onnxscript/_legacy_ir/irbuilder.py', # FIXME
56+
'onnxscript/optimizer/fold_constants_v0.py', # FIXME
57+
'onnxscript/rewriter/pattern.py', # FIXME
58+
'onnxscript/rewriter/onnxruntime/transformers/multihead_attention.py', # FIXME
59+
'onnxscript/tools/function_unittest_producer.py', # FIXME
60+
'onnxscript/_legacy_ir/visitor.py', # FIXME
61+
'onnxscript/_legacy_ir/protobuilder.py', # FIXME
62+
'onnxscript/rewriter/onnxruntime/transformers/layernorm.py', # FIXME
63+
'onnxscript/ir/serde.py', # FIXME
64+
'onnxscript/rewriter/generic_pattern_test.py', # FIXME
65+
'onnxscript/rewriter/generic_pattern.py', # FIXME
4966
]
5067
command = [
5168
'python',
@@ -74,7 +91,7 @@ include_patterns = [
7491
'**/*.py',
7592
]
7693
exclude_patterns = [
77-
'onnxscript/tests/onnx_backend_test_code/**',
94+
'tests/onnx_backend_test_code/**',
7895
]
7996
command = [
8097
'python',
@@ -102,12 +119,16 @@ include_patterns = [
102119
'**/*.py',
103120
]
104121
exclude_patterns = [
122+
'examples/**', # TODO: Merge with docs/examples
105123
'docs/examples/**',
106124
'docs/tutorial/examples/**',
107125
'onnxscript/converter_test.py',
108-
'onnxscript/tests/functions/**',
109-
'onnxscript/tests/models/**',
110-
'onnxscript/tests/onnx_backend_test_code/**',
126+
'tests/functions/**',
127+
'tests/models/**',
128+
'tests/onnx_backend_test_code/**',
129+
'onnxscript/optimizer/**', # FIXME
130+
'onnxscript/rewriter/**', # FIXME
131+
'onnxscript/_legacy_ir/**', # FIXME
111132
]
112133
command = [
113134
'python',

examples/pattern_rewriting.py

Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
"""Onnx Pattern Rewriting.
2+
3+
This script shows how to define a rewriting rule based on patterns.
4+
The objective is to replace some nodes in an onnx model into another
5+
sequence of nodes but more efficient.
6+
7+
First a dummy model
8+
===================
9+
"""
10+
11+
import numpy as np
12+
import onnx
13+
import onnx.helper as oh
14+
import onnx.numpy_helper as onh
15+
16+
import onnxscript
17+
import onnxscript._legacy_ir as oir
18+
import onnxscript.rewriter.generic_pattern as org
19+
20+
21+
def get_rotary_model(bad_model=False):
22+
inputs = [
23+
oh.make_tensor_value_info("x", onnx.TensorProto.INT64, shape=[]),
24+
oh.make_tensor_value_info("pos_ids", onnx.TensorProto.FLOAT, shape=[]),
25+
oh.make_tensor_value_info("axis", onnx.TensorProto.INT64, shape=[]),
26+
]
27+
nodes = [
28+
oh.make_node("Unsqueeze", ["x", "axis"], ["_onx_unsqueeze0"]),
29+
oh.make_node("Cast", ["_onx_unsqueeze0"], ["_onx_cast0"], to=1),
30+
oh.make_node("MatMul", ["pos_ids", "_onx_cast0"], ["_onx_matmul0"]),
31+
oh.make_node("Transpose", ["_onx_matmul0"], ["_onx_transpose0"]),
32+
oh.make_node(
33+
"ConcatTrainingBad" if bad_model else "ConcatTraining",
34+
["_onx_transpose0", "_onx_transpose0"],
35+
["_onx_concattraining0", "_onx_concattraining1"],
36+
domain="com.microsoft",
37+
),
38+
oh.make_node("Sin", ["_onx_concattraining0"], ["_onx_sin0"]),
39+
oh.make_node("Cast", ["_onx_sin0"], ["_onx_cast02"], to=1),
40+
oh.make_node("Cos", ["_onx_concattraining0"], ["_onx_cos0"]),
41+
oh.make_node("Cast", ["_onx_cos0"], ["_onx_cast03"], to=1),
42+
]
43+
outputs = [
44+
oh.make_tensor_value_info("_onx_cast02", onnx.TensorProto.UNDEFINED, []),
45+
oh.make_tensor_value_info("_onx_cast03", onnx.TensorProto.UNDEFINED, []),
46+
]
47+
model = oh.make_model(
48+
oh.make_graph(
49+
nodes,
50+
"experiment",
51+
inputs,
52+
outputs,
53+
),
54+
opset_imports=[
55+
oh.make_opsetid("", 18),
56+
oh.make_opsetid("com.microsoft", 18),
57+
],
58+
)
59+
return model
60+
61+
62+
model = get_rotary_model()
63+
ir_model = oir.irbuilder.build_ir(model)
64+
65+
66+
####################################
67+
# The rewriting pattern
68+
# =====================
69+
70+
op = onnxscript.opset18
71+
msft_op = onnxscript.values.Opset("com.microsoft", 1)
72+
73+
74+
def rotary_match_pattern(x, pos_ids, axis):
75+
"""The pattern to match."""
76+
unsqueeze = op.Unsqueeze(x, axis)
77+
cast = op.Cast(unsqueeze, to=onnx.TensorProto.FLOAT)
78+
79+
matmul = op.MatMul(pos_ids, cast)
80+
transpose = op.Transpose(matmul)
81+
output, length = msft_op.ConcatTraining(transpose, transpose)
82+
83+
sin = op.Sin(output)
84+
cast1 = op.Cast(sin, to=onnx.TensorProto.FLOAT)
85+
cos = op.Cos(output)
86+
cast2 = op.Cast(cos, to=onnx.TensorProto.FLOAT)
87+
return cast1, cast2
88+
89+
90+
def validate_rotary_mapping(g, matched_nodes, added_nodes) -> bool:
91+
"""The validation post matching.
92+
93+
Returns True to validate the replacement,
94+
False not to apply it.
95+
96+
:param g: model
97+
:param matched_nodes: matched nodes
98+
:param added_nodes: nodes replacing the matched nodes
99+
"""
100+
del g
101+
del matched_nodes
102+
del added_nodes
103+
return True
104+
105+
106+
def rotary_apply_pattern(x, pos_ids, axis):
107+
"""The replacement pattern."""
108+
cos_cache = op.Constant(value=onh.from_array(np.random.rand(256, 256).astype(np.float16)))
109+
sin_cache = op.Constant(value=onh.from_array(np.random.rand(256, 256).astype(np.float16)))
110+
part1, part2 = msft_op.RotaryEmbedding(x, pos_ids, cos_cache, sin_cache)
111+
return part1, part2
112+
113+
114+
###########################
115+
# The rule
116+
# ========
117+
#
118+
# The rule is easy to create.
119+
120+
121+
rule = org.make_pattern_rule(
122+
rotary_match_pattern,
123+
rotary_apply_pattern,
124+
validate_rotary_mapping,
125+
)
126+
127+
################################
128+
# ``validate_rotary_mapping`` always return True.
129+
# This argument can be ignored in that case.
130+
131+
rule = org.make_pattern_rule(rotary_match_pattern, rotary_apply_pattern)
132+
133+
##########################
134+
# Let's apply it.
135+
rule.apply_to_model(ir_model)
136+
137+
138+
########################
139+
# And finally, we can generate the model.
140+
141+
opt_onx = oir.protobuilder.build_model_proto(ir_model)
142+
143+
########################
144+
# Let's see what it looks like.
145+
146+
for node in opt_onx.graph.node:
147+
print(f"{node.op_type}({', '.join(node.input)}) -> {', '.join(node.output)}")
148+
149+
#############################
150+
# What if it fails?
151+
# =================
152+
153+
154+
model = get_rotary_model(True)
155+
ir_model = oir.irbuilder.build_ir(model)
156+
157+
rule.apply_to_model(ir_model)
158+
opt_onx = oir.protobuilder.build_model_proto(ir_model)
159+
160+
print([n.op_type for n in opt_onx.graph.node])
161+
162+
################################
163+
# The match did not happen.
164+
# Let's increase the verbosity.
165+
166+
rule = org.make_pattern_rule(rotary_match_pattern, rotary_apply_pattern, verbose=10)
167+
168+
rule.apply_to_model(ir_model)
169+
170+
######################################
171+
# The logs shows every time the algorithm rejected a pattern.
172+
# We can see the following:
173+
#
174+
# ::
175+
#
176+
# [OnnxGenericPattern.match] NONE - line: 673:onnxscript.rewriter.generic_pattern, op_type=Cast
177+
# --hint--: BACKWARD: different node types
178+
# --pattern
179+
# ConcatTraining(transpose, transpose) -> (output, length)
180+
# -- model
181+
# ConcatTrainingBad(_onx_transpose0, _onx_transpose0) -> (_onx_concattraining0, _onx_concattraining1)
182+
# iteration=1
183+
# --marked-- #2
184+
# Cast(_onx_cos0) ~ Cast(cos) [140186194226496-140186194222320]
185+
# Cos(_onx_concattraining0) ~ Cos(output) [140186194230816-140186194223472]
186+
# len(stacked)=0:[]
187+
#
188+
# Line 673 in file `generic_pattern.py`, the match was rejected.
189+
# It says while comparing two nodes in the backward direction,
190+
# node types do not match.
191+
# It also says that two nodes were actually matched.

0 commit comments

Comments
 (0)