|
| 1 | +"""Onnx Pattern Rewriting. |
| 2 | +
|
| 3 | +This script shows how to define a rewriting rule based on patterns. |
| 4 | +The objective is to replace some nodes in an onnx model into another |
| 5 | +sequence of nodes but more efficient. |
| 6 | +
|
| 7 | +First a dummy model |
| 8 | +=================== |
| 9 | +""" |
| 10 | + |
| 11 | +import numpy as np |
| 12 | +import onnx |
| 13 | +import onnx.helper as oh |
| 14 | +import onnx.numpy_helper as onh |
| 15 | + |
| 16 | +import onnxscript |
| 17 | +import onnxscript._legacy_ir as oir |
| 18 | +import onnxscript.rewriter.generic_pattern as org |
| 19 | + |
| 20 | + |
| 21 | +def get_rotary_model(bad_model=False): |
| 22 | + inputs = [ |
| 23 | + oh.make_tensor_value_info("x", onnx.TensorProto.INT64, shape=[]), |
| 24 | + oh.make_tensor_value_info("pos_ids", onnx.TensorProto.FLOAT, shape=[]), |
| 25 | + oh.make_tensor_value_info("axis", onnx.TensorProto.INT64, shape=[]), |
| 26 | + ] |
| 27 | + nodes = [ |
| 28 | + oh.make_node("Unsqueeze", ["x", "axis"], ["_onx_unsqueeze0"]), |
| 29 | + oh.make_node("Cast", ["_onx_unsqueeze0"], ["_onx_cast0"], to=1), |
| 30 | + oh.make_node("MatMul", ["pos_ids", "_onx_cast0"], ["_onx_matmul0"]), |
| 31 | + oh.make_node("Transpose", ["_onx_matmul0"], ["_onx_transpose0"]), |
| 32 | + oh.make_node( |
| 33 | + "ConcatTrainingBad" if bad_model else "ConcatTraining", |
| 34 | + ["_onx_transpose0", "_onx_transpose0"], |
| 35 | + ["_onx_concattraining0", "_onx_concattraining1"], |
| 36 | + domain="com.microsoft", |
| 37 | + ), |
| 38 | + oh.make_node("Sin", ["_onx_concattraining0"], ["_onx_sin0"]), |
| 39 | + oh.make_node("Cast", ["_onx_sin0"], ["_onx_cast02"], to=1), |
| 40 | + oh.make_node("Cos", ["_onx_concattraining0"], ["_onx_cos0"]), |
| 41 | + oh.make_node("Cast", ["_onx_cos0"], ["_onx_cast03"], to=1), |
| 42 | + ] |
| 43 | + outputs = [ |
| 44 | + oh.make_tensor_value_info("_onx_cast02", onnx.TensorProto.UNDEFINED, []), |
| 45 | + oh.make_tensor_value_info("_onx_cast03", onnx.TensorProto.UNDEFINED, []), |
| 46 | + ] |
| 47 | + model = oh.make_model( |
| 48 | + oh.make_graph( |
| 49 | + nodes, |
| 50 | + "experiment", |
| 51 | + inputs, |
| 52 | + outputs, |
| 53 | + ), |
| 54 | + opset_imports=[ |
| 55 | + oh.make_opsetid("", 18), |
| 56 | + oh.make_opsetid("com.microsoft", 18), |
| 57 | + ], |
| 58 | + ) |
| 59 | + return model |
| 60 | + |
| 61 | + |
| 62 | +model = get_rotary_model() |
| 63 | +ir_model = oir.irbuilder.build_ir(model) |
| 64 | + |
| 65 | + |
| 66 | +#################################### |
| 67 | +# The rewriting pattern |
| 68 | +# ===================== |
| 69 | + |
| 70 | +op = onnxscript.opset18 |
| 71 | +msft_op = onnxscript.values.Opset("com.microsoft", 1) |
| 72 | + |
| 73 | + |
| 74 | +def rotary_match_pattern(x, pos_ids, axis): |
| 75 | + """The pattern to match.""" |
| 76 | + unsqueeze = op.Unsqueeze(x, axis) |
| 77 | + cast = op.Cast(unsqueeze, to=onnx.TensorProto.FLOAT) |
| 78 | + |
| 79 | + matmul = op.MatMul(pos_ids, cast) |
| 80 | + transpose = op.Transpose(matmul) |
| 81 | + output, length = msft_op.ConcatTraining(transpose, transpose) |
| 82 | + |
| 83 | + sin = op.Sin(output) |
| 84 | + cast1 = op.Cast(sin, to=onnx.TensorProto.FLOAT) |
| 85 | + cos = op.Cos(output) |
| 86 | + cast2 = op.Cast(cos, to=onnx.TensorProto.FLOAT) |
| 87 | + return cast1, cast2 |
| 88 | + |
| 89 | + |
| 90 | +def validate_rotary_mapping(g, matched_nodes, added_nodes) -> bool: |
| 91 | + """The validation post matching. |
| 92 | +
|
| 93 | + Returns True to validate the replacement, |
| 94 | + False not to apply it. |
| 95 | +
|
| 96 | + :param g: model |
| 97 | + :param matched_nodes: matched nodes |
| 98 | + :param added_nodes: nodes replacing the matched nodes |
| 99 | + """ |
| 100 | + del g |
| 101 | + del matched_nodes |
| 102 | + del added_nodes |
| 103 | + return True |
| 104 | + |
| 105 | + |
| 106 | +def rotary_apply_pattern(x, pos_ids, axis): |
| 107 | + """The replacement pattern.""" |
| 108 | + cos_cache = op.Constant(value=onh.from_array(np.random.rand(256, 256).astype(np.float16))) |
| 109 | + sin_cache = op.Constant(value=onh.from_array(np.random.rand(256, 256).astype(np.float16))) |
| 110 | + part1, part2 = msft_op.RotaryEmbedding(x, pos_ids, cos_cache, sin_cache) |
| 111 | + return part1, part2 |
| 112 | + |
| 113 | + |
| 114 | +########################### |
| 115 | +# The rule |
| 116 | +# ======== |
| 117 | +# |
| 118 | +# The rule is easy to create. |
| 119 | + |
| 120 | + |
| 121 | +rule = org.make_pattern_rule( |
| 122 | + rotary_match_pattern, |
| 123 | + rotary_apply_pattern, |
| 124 | + validate_rotary_mapping, |
| 125 | +) |
| 126 | + |
| 127 | +################################ |
| 128 | +# ``validate_rotary_mapping`` always return True. |
| 129 | +# This argument can be ignored in that case. |
| 130 | + |
| 131 | +rule = org.make_pattern_rule(rotary_match_pattern, rotary_apply_pattern) |
| 132 | + |
| 133 | +########################## |
| 134 | +# Let's apply it. |
| 135 | +rule.apply_to_model(ir_model) |
| 136 | + |
| 137 | + |
| 138 | +######################## |
| 139 | +# And finally, we can generate the model. |
| 140 | + |
| 141 | +opt_onx = oir.protobuilder.build_model_proto(ir_model) |
| 142 | + |
| 143 | +######################## |
| 144 | +# Let's see what it looks like. |
| 145 | + |
| 146 | +for node in opt_onx.graph.node: |
| 147 | + print(f"{node.op_type}({', '.join(node.input)}) -> {', '.join(node.output)}") |
| 148 | + |
| 149 | +############################# |
| 150 | +# What if it fails? |
| 151 | +# ================= |
| 152 | + |
| 153 | + |
| 154 | +model = get_rotary_model(True) |
| 155 | +ir_model = oir.irbuilder.build_ir(model) |
| 156 | + |
| 157 | +rule.apply_to_model(ir_model) |
| 158 | +opt_onx = oir.protobuilder.build_model_proto(ir_model) |
| 159 | + |
| 160 | +print([n.op_type for n in opt_onx.graph.node]) |
| 161 | + |
| 162 | +################################ |
| 163 | +# The match did not happen. |
| 164 | +# Let's increase the verbosity. |
| 165 | + |
| 166 | +rule = org.make_pattern_rule(rotary_match_pattern, rotary_apply_pattern, verbose=10) |
| 167 | + |
| 168 | +rule.apply_to_model(ir_model) |
| 169 | + |
| 170 | +###################################### |
| 171 | +# The logs shows every time the algorithm rejected a pattern. |
| 172 | +# We can see the following: |
| 173 | +# |
| 174 | +# :: |
| 175 | +# |
| 176 | +# [OnnxGenericPattern.match] NONE - line: 673:onnxscript.rewriter.generic_pattern, op_type=Cast |
| 177 | +# --hint--: BACKWARD: different node types |
| 178 | +# --pattern |
| 179 | +# ConcatTraining(transpose, transpose) -> (output, length) |
| 180 | +# -- model |
| 181 | +# ConcatTrainingBad(_onx_transpose0, _onx_transpose0) -> (_onx_concattraining0, _onx_concattraining1) |
| 182 | +# iteration=1 |
| 183 | +# --marked-- #2 |
| 184 | +# Cast(_onx_cos0) ~ Cast(cos) [140186194226496-140186194222320] |
| 185 | +# Cos(_onx_concattraining0) ~ Cos(output) [140186194230816-140186194223472] |
| 186 | +# len(stacked)=0:[] |
| 187 | +# |
| 188 | +# Line 673 in file `generic_pattern.py`, the match was rejected. |
| 189 | +# It says while comparing two nodes in the backward direction, |
| 190 | +# node types do not match. |
| 191 | +# It also says that two nodes were actually matched. |
0 commit comments