-
Notifications
You must be signed in to change notification settings - Fork 72
Returning choice values in patterns #2284
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
8d913fe
Handle returning choice values
gramalingam 0e7554c
Minor fix
gramalingam 72d41c9
Address PR feedback, add documentation
gramalingam e6e4506
Run lint
gramalingam e53708b
Merge branch 'main' into rama/pattern-output-fix
gramalingam de368a2
Address PR feedback
gramalingam dc0efc2
Lint issues
gramalingam 042be9f
Merge branch 'main' into rama/pattern-output-fix
gramalingam File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# Specifying attributes in the pattern | ||
|
||
This section demonstrates the use of attribute values in pattern-based rewriting. | ||
First, write a target pattern and replacement pattern in a similar way to the previous examples. | ||
The example pattern below will match successfully only against Dropout nodes with the | ||
attribute value `training_mode` set to `False`. | ||
The `_allow_other_attributes` option allows the pattern to match nodes that have additional attributes | ||
not specified in the pattern. If it is set to `False`, then the node must have only the specified | ||
attribute values, and no other attributes, for a successful match. The default value for this | ||
option is `True`. | ||
|
||
```{literalinclude} examples/allow_other_attributes.py | ||
:pyobject: add_pattern | ||
``` | ||
|
||
```{literalinclude} examples/allow_other_attributes.py | ||
:pyobject: add_replacement | ||
``` | ||
|
||
```{literalinclude} examples/allow_other_attributes.py | ||
:pyobject: apply_rewrite | ||
``` |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
(heading-target-commute)= | ||
# Utilizing `commute` parameter for pattern-matching | ||
Extending the previous [simple example](heading-target-simple), assumming a scenario where we have a graph with the following structure. | ||
|
||
{align=center width=500px} | ||
|
||
In this graph, there exist two node pattern that constitute a `GELU` op. However, there is a subtle difference between the two. Focusing on the parent `Mul` nodes in either patterns, the order of the input values being multiplied is switched. | ||
|
||
{width=330px align=left} {width=330px align=center} | ||
|
||
|
||
If we utilize the same `target_pattern` created for the earlier [simple example](heading-target-simple) (shown below), only one of two `GELU` pattern will be matched. | ||
|
||
```{literalinclude} examples/erfgelu.py | ||
:pyobject: erf_gelu_pattern | ||
``` | ||
|
||
```{image} examples/img/erfgelu_06_commute.png | ||
:alt: The resulting graph after matching. | ||
:width: 400px | ||
:align: center | ||
``` | ||
|
||
Only one of the patterns has been successfully matched and replaced by a `GELU` node. In order to rewrite both the existing patterns in the graph, there are two methods. | ||
|
||
(heading-target-commute-ruleset)= | ||
|
||
## 1. Creating a rule-set with different patterns. | ||
|
||
This method requires creating two separate rules and packing them into either a sequence of `PatternRewriteRule`s or a `RewriteRuleSet`. Creating a `RewriteRuleSet` is the preferable option but either can be used. In order to create a `RewriteRuleSet` with multiple rules `rule1` and `rule2` for example: | ||
|
||
```python | ||
from onnxscript.rewriter import pattern | ||
rewrite_rule_set = pattern.RewriteRuleSet(rules=[rule1, rule2]) | ||
``` | ||
|
||
In order to apply this method to the example above, first create the two separate target patterns as follows: | ||
|
||
```{literalinclude} examples/erfgelu.py | ||
:pyobject: erf_gelu_pattern | ||
``` | ||
```{literalinclude} examples/erfgelu.py | ||
:pyobject: erf_gelu_pattern_2 | ||
``` | ||
|
||
:::{note} | ||
:name: rule-application-order-matters | ||
|
||
When you pass multiple rules in `pattern_rewrite_rules`, the **order in which they appear is important**. | ||
This is because some rules may depend on patterns created or modified by earlier rules. For example, if `rule2` can only match after `rule1` has made a specific change in the model, then `rule1` must come **before** `rule2` in the list. | ||
If you're not seeing expected results, try adjusting the order or applying the rule set in a loop until no more changes occur. | ||
::: | ||
|
||
|
||
Then, create two separate `PatternRewriteRule`s, one for each target pattern. Pack these rules into a `RewriteRuleSet` object and apply rewrites by passing the created `RewriteRuleSet` for the `pattern_rewrite_rules` parameter. | ||
|
||
```{literalinclude} examples/erfgelu.py | ||
:pyobject: apply_rewrite_with_ruleset | ||
``` | ||
|
||
## 2. Using the `commute` parameter while creating a rule. | ||
|
||
Creating multiple target patterns for similar patterns can be tedious. In order to avoid this, the `commute` parameter can be utilized while creating the `RewriteRuleSet`. Simply set `commute=True` in order to avoid creating multiple target pattern for cases where patterns are different due to commutativity. Multiple rules with the different patterns emerging due to satisfying the commutativity property are automatically packed into a `RewriteRuleSet` object. Then apply rewrites by passing the created `RewriteRuleSet` for the `pattern_rewrite_rules` parameter. | ||
|
||
```{literalinclude} examples/erfgelu.py | ||
:pyobject: apply_rewrite_with_commute | ||
``` | ||
|
||
For the both of the aforementioned methods, the final graph with both rewrites applied should look as follows: | ||
|
||
{align=center width=300px} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
# Using the `match_condition` parameter for pattern-matching | ||
|
||
This section talks about how to utilize the `match_condition` parameter. The `match_condition` parameter checks if the pattern matches the target pattern with certain constraints in consideration. | ||
|
||
Let us consider a model which consists of the following pattern. | ||
|
||
{align=center} | ||
|
||
Based on the [ONNX Matmul spec](https://github.com/onnx/onnx/blob/main/docs/Operators.md#MatMul), onnx `Matmul` behaves like `numpy.matmul` and also follows numpy broadcasting. So in this particular pattern if matmul broadcasting is enough, then we don't need the reshapes. To validate this, we need to check the following: | ||
|
||
1. Input shapes check: `input_a` and `input_b` should be broadcastable | ||
2. Output shape check: `shape_c` should be the same as the output shape from the `matmul(input_a, input_b)` | ||
|
||
If the above are true, then we don't need the reshapes and we can eliminate them using a pattern based rewrite. | ||
|
||
First, write a target pattern and replacement pattern in a similar way to the first example. | ||
|
||
```{literalinclude} examples/broadcast_matmul.py | ||
:pyobject: two_reshapes_matmul_reshape_pattern | ||
``` | ||
|
||
```{literalinclude} examples/broadcast_matmul.py | ||
:pyobject: matmul_pattern | ||
``` | ||
|
||
:::{note} | ||
:name: omitting inputs in signature | ||
|
||
The target pattern in this case has 5 inputs `input_a`, `input_b`, `shape_a`, `shape_b`, `shape_c`. However, the replacement pattern only utilizes `input_a` and `input_b`. To avoid referencing all the unused parameters in the replacement pattern signature, pass only `input_a` and `input_b` and use `**_` to represent all the unused parameters. | ||
|
||
Similarly for writing the condition checking function, we require only `input_a`, `input_b` and `shape_c`. Use `**_` to represent all the unused parameters in the condition matching function signature. | ||
::: | ||
|
||
In order to validate whether matmul broadcast is sufficient, we write a condition checking function as follows: | ||
|
||
```{literalinclude} examples/broadcast_matmul.py | ||
:pyobject: check_if_not_need_reshape | ||
``` | ||
|
||
With all the necessary components in place, the pattern rewrite rule with the `match_condition` function is created and then the `rewriter.rewrite` is called to apply the rewrite. | ||
|
||
```{literalinclude} examples/broadcast_matmul.py | ||
:pyobject: apply_rewrite | ||
``` | ||
|
||
The final graph with the applied rewrite looks as follows: | ||
|
||
{align=center} | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
|
||
# Licensed under the MIT License. | ||
|
||
"""OR-patterns. | ||
|
||
This script shows how to define a rewriting rule based on OR-patterns. | ||
""" | ||
|
||
import onnx | ||
|
||
import onnxscript | ||
from onnxscript import FLOAT, opset18, script | ||
from onnxscript.rewriter import pattern | ||
|
||
#################################### | ||
# The target pattern | ||
# ===================== | ||
|
||
|
||
def scaled_matmul(op, x, y, factor): | ||
xy = op.MatMul(x, y) | ||
choice1 = op.Mul(xy, factor) | ||
choice2 = op.Div(xy, factor) | ||
scaled_xy = pattern.OrValue( | ||
[choice1, choice2], tag_var="op_type", tag_values=["Mul", "Div"] | ||
) | ||
return op.Relu(scaled_xy) | ||
|
||
|
||
#################################### | ||
# The replacement pattern | ||
# ===================== | ||
|
||
|
||
def scaled_matmul_replacement(op, x, y, factor, op_type): | ||
if op_type == "Mul": | ||
return op.MatMulMulRelu(x, y, factor, _domain="some.domain") | ||
elif op_type == "Div": | ||
return op.MatMulDivRelu(x, y, factor, _domain="some.domain") | ||
else: | ||
raise ValueError(f"Unknown operation type: {op_type}") | ||
|
||
|
||
#################################### | ||
# Rewrite Rule | ||
# ===================== | ||
def apply_rewrite(model): | ||
rule = pattern.RewriteRule( | ||
scaled_matmul, # target pattern | ||
scaled_matmul_replacement, # replacement pattern | ||
) | ||
# Create a Rewrite Rule Set | ||
rewrite_rule_set = pattern.RewriteRuleSet([rule]) | ||
return onnxscript.rewriter.rewrite( | ||
model, | ||
pattern_rewrite_rules=rewrite_rule_set, | ||
) | ||
|
||
|
||
@script() | ||
def original_model1(A: FLOAT[2, 2], B: FLOAT[2, 2]) -> FLOAT[2, 2]: | ||
t1 = opset18.MatMul(A, B) | ||
c = opset18.Constant(value_float=2.0) | ||
t2 = opset18.Mul(t1, c) | ||
t3 = opset18.Relu(t2) | ||
return t3 | ||
|
||
|
||
_model = original_model1.to_model_proto() | ||
onnx.checker.check_model(_model) | ||
|
||
_model_with_rewrite = apply_rewrite(_model) | ||
onnx.checker.check_model(_model_with_rewrite) | ||
|
||
assert [n.op_type for n in _model_with_rewrite.graph.node] == ["Constant", "MatMulMulRelu"] | ||
|
||
|
||
@script() | ||
def original_model2(A: FLOAT[2, 2], B: FLOAT[2, 2]) -> FLOAT[2, 2]: | ||
t1 = opset18.MatMul(A, B) | ||
c = opset18.Constant(value_float=2.0) | ||
t2 = opset18.Div(t1, c) | ||
t3 = opset18.Relu(t2) | ||
return t3 | ||
|
||
|
||
_model = original_model2.to_model_proto() | ||
onnx.checker.check_model(_model) | ||
|
||
_model_with_rewrite = apply_rewrite(_model) | ||
onnx.checker.check_model(_model_with_rewrite) | ||
|
||
assert [n.op_type for n in _model_with_rewrite.graph.node] == ["Constant", "MatMulDivRelu"] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
# Rewriter Tutorials | ||
# Rewriter Tutorial | ||
|
||
```{toctree} | ||
rewrite_patterns | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
# OR Patterns | ||
|
||
|
||
*Note* : This feature is work-in-progress. | ||
|
||
Consider the following pattern: | ||
|
||
```{literalinclude} examples/or_pattern.py | ||
:pyobject: scaled_matmul | ||
``` | ||
|
||
This pattern will successfully match against the sequence "MatMul => Mul => Relu" as | ||
well as the sequence "MatMul => Div => Relu". The matcher will bind the variable | ||
specified in `tag_var` (`op_type` in the above example) to a value from those | ||
listed in `tag_values` to indicate which of the alternatives was used for a | ||
successful match. We can use this in the rewrite function to determine how | ||
we want to rewrite the matched sub-graph, as illustrated by the following code: | ||
|
||
```{literalinclude} examples/or_pattern.py | ||
:pyobject: scaled_matmul_replacement | ||
``` |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.