diff --git a/.gitignore b/.gitignore index 23ce89a464..3344aa7659 100644 --- a/.gitignore +++ b/.gitignore @@ -45,6 +45,7 @@ test-output.xml # Sphinx documentation docs/_build/ +docs/sg_execution_times.rst # Jupyter Notebook .ipynb_checkpoints diff --git a/docs/ir/tensors.md b/docs/ir/tensors.md index 7b46ac2094..4e1130ba3b 100644 --- a/docs/ir/tensors.md +++ b/docs/ir/tensors.md @@ -188,7 +188,7 @@ To fully support arrays from other frameworks, it is usually a good idea to crea ```{eval-rst} .. exec_code:: - + from __future__ import annotations import ctypes from typing import Any diff --git a/docs/tutorial/rewriter/attributes.md b/docs/tutorial/rewriter/attributes.md new file mode 100644 index 0000000000..12f1834241 --- /dev/null +++ b/docs/tutorial/rewriter/attributes.md @@ -0,0 +1,22 @@ +# Specifying attributes in the pattern + +This section demonstrates the use of attribute values in pattern-based rewriting. +First, write a target pattern and replacement pattern in a similar way to the previous examples. +The example pattern below will match successfully only against Dropout nodes with the +attribute value `training_mode` set to `False`. +The `_allow_other_attributes` option allows the pattern to match nodes that have additional attributes +not specified in the pattern. If it is set to `False`, then the node must have only the specified +attribute values, and no other attributes, for a successful match. The default value for this +option is `True`. 
+ +```{literalinclude} examples/allow_other_attributes.py +:pyobject: add_pattern +``` + +```{literalinclude} examples/allow_other_attributes.py +:pyobject: add_replacement +``` + +```{literalinclude} examples/allow_other_attributes.py +:pyobject: apply_rewrite +``` diff --git a/docs/tutorial/rewriter/commute.md b/docs/tutorial/rewriter/commute.md new file mode 100644 index 0000000000..38b4b178aa --- /dev/null +++ b/docs/tutorial/rewriter/commute.md @@ -0,0 +1,71 @@ +(heading-target-commute)= +# Utilizing `commute` parameter for pattern-matching +Extending the previous [simple example](heading-target-simple), assume a scenario where we have a graph with the following structure. + +![commute](examples/img/erfgelu_03_commute.png){align=center width=500px} + +In this graph, there exist two node patterns that constitute a `GELU` op. However, there is a subtle difference between the two. Focusing on the parent `Mul` nodes in either pattern, the order of the input values being multiplied is switched. + +![gelu_pattern_1](examples/img/erfgelu_04_commute.png){width=330px align=left} ![gelu_pattern_2](examples/img/erfgelu_05_commute.png){width=330px align=center} + + +If we utilize the same `target_pattern` created for the earlier [simple example](heading-target-simple) (shown below), only one of the two `GELU` patterns will be matched. + +```{literalinclude} examples/erfgelu.py +:pyobject: erf_gelu_pattern +``` + +```{image} examples/img/erfgelu_06_commute.png +:alt: The resulting graph after matching. +:width: 400px +:align: center +``` + +Only one of the patterns has been successfully matched and replaced by a `GELU` node. In order to rewrite both the existing patterns in the graph, there are two methods. + +(heading-target-commute-ruleset)= + +## 1. Creating a rule-set with different patterns. + +This method requires creating two separate rules and packing them into either a sequence of `PatternRewriteRule`s or a `RewriteRuleSet`. 
Creating a `RewriteRuleSet` is the preferable option but either can be used. In order to create a `RewriteRuleSet` with multiple rules `rule1` and `rule2` for example: + +```python +from onnxscript.rewriter import pattern +rewrite_rule_set = pattern.RewriteRuleSet(rules=[rule1, rule2]) +``` + +In order to apply this method to the example above, first create the two separate target patterns as follows: + +```{literalinclude} examples/erfgelu.py +:pyobject: erf_gelu_pattern +``` +```{literalinclude} examples/erfgelu.py +:pyobject: erf_gelu_pattern_2 +``` + +:::{note} +:name: rule-application-order-matters + +When you pass multiple rules in `pattern_rewrite_rules`, the **order in which they appear is important**. +This is because some rules may depend on patterns created or modified by earlier rules. For example, if `rule2` can only match after `rule1` has made a specific change in the model, then `rule1` must come **before** `rule2` in the list. +If you're not seeing expected results, try adjusting the order or applying the rule set in a loop until no more changes occur. +::: + + +Then, create two separate `PatternRewriteRule`s, one for each target pattern. Pack these rules into a `RewriteRuleSet` object and apply rewrites by passing the created `RewriteRuleSet` for the `pattern_rewrite_rules` parameter. + +```{literalinclude} examples/erfgelu.py +:pyobject: apply_rewrite_with_ruleset +``` + +## 2. Using the `commute` parameter while creating a rule. + +Creating multiple target patterns for similar patterns can be tedious. In order to avoid this, the `commute` parameter can be utilized while creating the `RewriteRuleSet`. Simply set `commute=True` in order to avoid creating multiple target patterns for cases where patterns are different due to commutativity. Multiple rules with the different patterns emerging due to satisfying the commutativity property are automatically packed into a `RewriteRuleSet` object. 
Then apply rewrites by passing the created `RewriteRuleSet` for the `pattern_rewrite_rules` parameter. + +```{literalinclude} examples/erfgelu.py +:pyobject: apply_rewrite_with_commute +``` + +For both of the aforementioned methods, the final graph with both rewrites applied should look as follows: + +![commute](examples/img/erfgelu_07_commute.png){align=center width=300px} diff --git a/docs/tutorial/rewriter/conditional_rewrite.md b/docs/tutorial/rewriter/conditional_rewrite.md new file mode 100644 index 0000000000..07dc7793c9 --- /dev/null +++ b/docs/tutorial/rewriter/conditional_rewrite.md @@ -0,0 +1,49 @@ +# Using the `match_condition` parameter for pattern-matching + +This section talks about how to utilize the `match_condition` parameter. The `match_condition` parameter checks if the pattern matches the target pattern with certain constraints in consideration. + +Let us consider a model which consists of the following pattern. + +![target_pattern](examples/img/broadcast_01.png){align=center} + +Based on the [ONNX Matmul spec](https://github.com/onnx/onnx/blob/main/docs/Operators.md#MatMul), onnx `Matmul` behaves like `numpy.matmul` and also follows numpy broadcasting. So in this particular pattern if matmul broadcasting is enough, then we don't need the reshapes. To validate this, we need to check the following: + +1. Input shapes check: `input_a` and `input_b` should be broadcastable +2. Output shape check: `shape_c` should be the same as the output shape from the `matmul(input_a, input_b)` + +If the above are true, then we don't need the reshapes and we can eliminate them using a pattern based rewrite. + +First, write a target pattern and replacement pattern in a similar way to the first example. 
+ +```{literalinclude} examples/broadcast_matmul.py +:pyobject: two_reshapes_matmul_reshape_pattern +``` + +```{literalinclude} examples/broadcast_matmul.py +:pyobject: matmul_pattern +``` + +:::{note} +:name: omitting inputs in signature + +The target pattern in this case has 5 inputs `input_a`, `input_b`, `shape_a`, `shape_b`, `shape_c`. However, the replacement pattern only utilizes `input_a` and `input_b`. To avoid referencing all the unused parameters in the replacement pattern signature, pass only `input_a` and `input_b` and use `**_` to represent all the unused parameters. + +Similarly for writing the condition checking function, we require only `input_a`, `input_b` and `shape_c`. Use `**_` to represent all the unused parameters in the condition matching function signature. +::: + +In order to validate whether matmul broadcast is sufficient, we write a condition checking function as follows: + +```{literalinclude} examples/broadcast_matmul.py +:pyobject: check_if_not_need_reshape +``` + +With all the necessary components in place, the pattern rewrite rule with the `match_condition` function is created and then the `rewriter.rewrite` is called to apply the rewrite. + +```{literalinclude} examples/broadcast_matmul.py +:pyobject: apply_rewrite +``` + +The final graph with the applied rewrite looks as follows: + +![broadcast_rewrite](examples/img/broadcast_02.png){align=center} + diff --git a/docs/tutorial/rewriter/examples/or_pattern.py b/docs/tutorial/rewriter/examples/or_pattern.py new file mode 100644 index 0000000000..0e9231cc1f --- /dev/null +++ b/docs/tutorial/rewriter/examples/or_pattern.py @@ -0,0 +1,93 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""OR-patterns. + +This script shows how to define a rewriting rule based on OR-patterns. 
+""" + +import onnx + +import onnxscript +from onnxscript import FLOAT, opset18, script +from onnxscript.rewriter import pattern + +#################################### +# The target pattern +# ===================== + + +def scaled_matmul(op, x, y, factor): + xy = op.MatMul(x, y) + choice1 = op.Mul(xy, factor) + choice2 = op.Div(xy, factor) + scaled_xy = pattern.OrValue( + [choice1, choice2], tag_var="op_type", tag_values=["Mul", "Div"] + ) + return op.Relu(scaled_xy) + + +#################################### +# The replacement pattern +# ===================== + + +def scaled_matmul_replacement(op, x, y, factor, op_type): + if op_type == "Mul": + return op.MatMulMulRelu(x, y, factor, _domain="some.domain") + elif op_type == "Div": + return op.MatMulDivRelu(x, y, factor, _domain="some.domain") + else: + raise ValueError(f"Unknown operation type: {op_type}") + + +#################################### +# Rewrite Rule +# ===================== +def apply_rewrite(model): + rule = pattern.RewriteRule( + scaled_matmul, # target pattern + scaled_matmul_replacement, # replacement pattern + ) + # Create a Rewrite Rule Set + rewrite_rule_set = pattern.RewriteRuleSet([rule]) + return onnxscript.rewriter.rewrite( + model, + pattern_rewrite_rules=rewrite_rule_set, + ) + + +@script() +def original_model1(A: FLOAT[2, 2], B: FLOAT[2, 2]) -> FLOAT[2, 2]: + t1 = opset18.MatMul(A, B) + c = opset18.Constant(value_float=2.0) + t2 = opset18.Mul(t1, c) + t3 = opset18.Relu(t2) + return t3 + + +_model = original_model1.to_model_proto() +onnx.checker.check_model(_model) + +_model_with_rewrite = apply_rewrite(_model) +onnx.checker.check_model(_model_with_rewrite) + +assert [n.op_type for n in _model_with_rewrite.graph.node] == ["Constant", "MatMulMulRelu"] + + +@script() +def original_model2(A: FLOAT[2, 2], B: FLOAT[2, 2]) -> FLOAT[2, 2]: + t1 = opset18.MatMul(A, B) + c = opset18.Constant(value_float=2.0) + t2 = opset18.Div(t1, c) + t3 = opset18.Relu(t2) + return t3 + + +_model = 
original_model2.to_model_proto() +onnx.checker.check_model(_model) + +_model_with_rewrite = apply_rewrite(_model) +onnx.checker.check_model(_model_with_rewrite) + +assert [n.op_type for n in _model_with_rewrite.graph.node] == ["Constant", "MatMulDivRelu"] diff --git a/docs/tutorial/rewriter/index.md b/docs/tutorial/rewriter/index.md index 3b4e01e149..d86ae9a474 100644 --- a/docs/tutorial/rewriter/index.md +++ b/docs/tutorial/rewriter/index.md @@ -1,4 +1,4 @@ -# Rewriter Tutorials +# Rewriter Tutorial ```{toctree} rewrite_patterns diff --git a/docs/tutorial/rewriter/or_pattern.md b/docs/tutorial/rewriter/or_pattern.md new file mode 100644 index 0000000000..6c42112467 --- /dev/null +++ b/docs/tutorial/rewriter/or_pattern.md @@ -0,0 +1,20 @@ +# OR Patterns + +*Note* : This feature is work-in-progress. + +Consider the following pattern: + +```{literalinclude} examples/or_pattern.py +:pyobject: scaled_matmul +``` + +This pattern will successfully match against the sequence "MatMul => Mul => Relu" as +well as the sequence "MatMul => Div => Relu". The matcher will bind the variable +specified in `tag_var` (`op_type` in the above example) to a value from those +listed in `tag_values` to indicate which of the alternatives was used for a +successful match. We can use this in the rewrite function to determine how +we want to rewrite the matched sub-graph, as illustrated by the following code: + +```{literalinclude} examples/or_pattern.py +:pyobject: scaled_matmul_replacement +``` diff --git a/docs/tutorial/rewriter/rewrite_patterns.md b/docs/tutorial/rewriter/rewrite_patterns.md index d84d6b0f40..9627dc9a39 100644 --- a/docs/tutorial/rewriter/rewrite_patterns.md +++ b/docs/tutorial/rewriter/rewrite_patterns.md @@ -1,10 +1,8 @@ -# Pattern-based Rewrite Using Rules - -## Introduction +# Introduction The ONNX Rewriter tool provides the user with the functionality to replace certain patterns in an ONNX graph with another pattern based on rewrite rules provided by the user. 
-## Usage +# Usage There are three main components needed when rewriting patterns in the graph: @@ -12,220 +10,17 @@ There are three main components needed when rewriting patterns in the graph: 2. `replacement_pattern` : Pattern to replace the original pattern with. This pattern is also written as a function using ONNXScript-like operators. 3. `match_condition` (optional) : Pattern rewrite will occur only if the match condition is satisfied. -(heading-target-simple)= -## A Simple Example - -An simple example demonstrating the usage of this functionality using the `GELU` activation function: - -`GELU` activation function can be computed using a Gauss Error Function using the given formula: - -```{math} -\text{GELU} = x\Phi(x) = x \cdot \frac{1}{2} [1 + \text{erf}(x / \sqrt{2})] -``` - -We will show how we can find a subgraph matching this computation and replace it by a call to the function. - -Firstly, include all the rewriter relevant imports. - -```python -from onnxscript.rewriter import pattern -from onnxscript import ir - -``` - -Then create a target pattern that needs to be replaced using onnxscript operators. - -```{literalinclude} examples/erfgelu.py -:pyobject: erf_gelu_pattern -``` - -After this, create a replacement pattern that consists of the GELU onnxscript operator. - -```{literalinclude} examples/erfgelu.py -:pyobject: gelu -``` -:::{note} -:name: type annotate ir.Value - -The inputs to the replacement pattern are of type `ir.Value`. For detailed usage of `ir.Value` refer to the {py:class}`ir.Value ` class. -::: - - -For this example, we do not require a `match_condition` so that option is skipped for now. Then the rewrite rule is created using the `RewriteRule` function. - -```python -rule = pattern.RewriteRule( - erf_gelu_pattern, # Target Pattern - gelu, # Replacement Pattern -) -``` - -Now that the rewrite rule has been created, the next step is to apply these pattern-based rewrite rules. 
The `rewriter.rewrite` call consists of three main components: - -1. `model` : The original model on which the pattern rewrite rules are to be applied. This is of type `onnx.ModelProto`. -2. `function_rewrite_rules` : `(Optional)` This parameter is used to pass rewrite rules based on function names. Steps on how to use this parameter will be covered in a different tutorial. This parameter is of type `Sequence[type[FunctionRewriteRule]]` -3. `pattern_rewrite_rules` : `(Optional)` This parameter is used to pass rewrite rules based on a provided replacement pattern. For the purpose of this tutorial, we will be using only this parameter in conjunction with `model`. This parameter is of either one of these types: - - `Sequence[PatternRewriteRule]` - - `RewriteRuleSet` - -:::{note} -:name: pattern_rewrite_rules input formatting - -`pattern_rewrite_rules` takes a sequence of `PatternRewriteRule` types or a RewriteRuleSet which is also essentially a rule set created using a sequence of `PatternRewriteRule` types, so if only a singular rewrite rule is to be passed, it needs to passed as part of a sequence. For steps on how to create and use Rule-sets, refer to the example in the section [Creating a rule-set with different patterns](#heading-target-commute-ruleset). -::: - -The snippet below below demonstrates how to use the `rewriter.rewrite` call for the rewrite rule created above: - -```{literalinclude} examples/erfgelu.py -:pyobject: apply_rewrite -``` - -The graph (on the left) consists of the target pattern before the rewrite rule is applied. Once the rewrite rule is applied, the graph (on the right) shows that the target pattern has been successfully replaced by a GELU node as intended. - -![target_pattern](examples/img/erfgelu_01.png) ![replacement_pattern](examples/img/erfgelu_02.png) - -## Specifying attributes in the pattern - -This section demonstrates the use of attribute values in pattern-based rewriting. 
-First, write a target pattern and replacement pattern in a similar way to the previous examples. -The example pattern below will match successfully only against Dropout nodes with the -attribute value `training_mode` set to `False`. -The `_allow_other_attributes` option allows the pattern to match nodes that have additional attributes -not specified in the pattern. If it is set to `False`, then the node must have only the specified -attribute values, and no other attributes, for a successful match. The default value for this -option is `True`. - -```{literalinclude} examples/allow_other_attributes.py -:pyobject: add_pattern -``` - -```{literalinclude} examples/allow_other_attributes.py -:pyobject: add_replacement -``` - -```{literalinclude} examples/allow_other_attributes.py -:pyobject: apply_rewrite -``` - - -(heading-target-commute)= -## Utilizing `commute` parameter for pattern-matching -Extending the previous [simple example](heading-target-simple), assumming a scenario where we have a graph with the following structure. - -![commute](examples/img/erfgelu_03_commute.png){align=center width=500px} - -In this graph, there exist two node pattern that constitute a `GELU` op. However, there is a subtle difference between the two. Focusing on the parent `Mul` nodes in either patterns, the order of the input values being multiplied is switched. - -![gelu_pattern_1](examples/img/erfgelu_04_commute.png){width=330px align=left} ![gelu_pattern_2](examples/img/erfgelu_05_commute.png){width=330px align=center} - - -If we utilize the same `target_pattern` created for the earlier [simple example](heading-target-simple) (shown below), only one of two `GELU` pattern will be matched. - -```{literalinclude} examples/erfgelu.py -:pyobject: erf_gelu_pattern +```{include} simple_example.md ``` -```{image} examples/img/erfgelu_06_commute.png -:alt: The resulting graph after matching. 
-:width: 400px -:align: center +```{include} attributes.md ``` -Only one of the patterns has been successfully matched and replaced by a `GELU` node. In order to rewrite both the existing patterns in the graph, there are two methods. - -(heading-target-commute-ruleset)= -### 1. Creating a rule-set with different patterns. - -This method requires creating two separate rules and packing them into either a sequence of `PatternRewriteRule`s or a `RewriteRuleSet`. Creating a `RewriteRuleSet` is the preferable option but either can be used. In order to create a `RewriteRuleSet` with multiple rules `rule1` and `rule2` for example: - -```python -from onnxscript.rewriter import pattern -rewrite_rule_set = pattern.RewriteRuleSet(rules=[rule1, rule2]) +```{include} conditional_rewrite.md ``` -In order to apply this method to the example above, first create the two separate target patterns as follows: - -```{literalinclude} examples/erfgelu.py -:pyobject: erf_gelu_pattern -``` -```{literalinclude} examples/erfgelu.py -:pyobject: erf_gelu_pattern_2 -``` - -:::{note} -:name: rule-application-order-matters - -When you pass multiple rules in `pattern_rewrite_rules`, the **order in which they appear is important**. -This is because some rules may depend on patterns created or modified by earlier rules. For example, if `rule2` can only match after `rule1` has made a specific change in the model, then `rule1` must come **before** `rule2` in the list. -If you're not seeing expected results, try adjusting the order or applying the rule set in a loop until no more changes occur. -::: - - -Then, create two separate `PatternRewriteRule`s, one for each target pattern. Pack these rules into a `RewriteRuleSet` object and apply rewrites by passing the created `RewriteRuleSet` for the `pattern_rewrite_rules` parameter. - -```{literalinclude} examples/erfgelu.py -:pyobject: apply_rewrite_with_ruleset +```{include} or_pattern.md ``` - -### 2. Using the `commute` parameter while creating a rule. 
- -Creating multiple target patterns for similar patterns can be tedious. In order to avoid this, the `commute` parameter can be utilized while creating the `RewriteRuleSet`. Simply set `commute=True` in order to avoid creating multiple target pattern for cases where patterns are different due to commutativity. Multiple rules with the different patterns emerging due to satisfying the commutativity property are automatically packed into a `RewriteRuleSet` object. Then apply rewrites by passing the created `RewriteRuleSet` for the `pattern_rewrite_rules` parameter. - -```{literalinclude} examples/erfgelu.py -:pyobject: apply_rewrite_with_commute +```{include} commute.md ``` - -For the both of the aforementioned methods, the final graph with both rewrites applied should look as follows: - -![commute](examples/img/erfgelu_07_commute.png){align=center width=300px} - -## Using the `match_condition` parameter for pattern-matching - -This section talks about how to utilize the `match_condition` parameter. The `match_condition` parameter checks if the pattern matches the target pattern with certain constraints in consideration. - -Let us consider a model which consists of the following pattern. - -![target_pattern](examples/img/broadcast_01.png){align=center} - -Based on the [ONNX Matmul spec](https://github.com/onnx/onnx/blob/main/docs/Operators.md#MatMul), onnx `Matmul` behaves like `numpy.matmul` and also follows numpy broadcasting. So in this particular pattern if matmul broadcasting is enough, then we don't need the reshapes. To validate this, we need to check the following: - -1. Input shapes check: `input_a` and `input_b` should be broadcastable -2. Output shape check: `shape_c` should be the same as the output shape from the `matmul(input_a, input_b)` - -If the above are true, then we don't need the reshapes and we can eliminate them using a pattern based rewrite. - -First, write a target pattern and replacement pattern in a similar way to the first example. 
- -```{literalinclude} examples/broadcast_matmul.py -:pyobject: two_reshapes_matmul_reshape_pattern -``` - -```{literalinclude} examples/broadcast_matmul.py -:pyobject: matmul_pattern -``` - -:::{note} -:name: omitting inputs in signature - -The target pattern in this case has 5 inputs `input_a`, `input_b`, `shape_a`, `shape_b`, `shape_c`. However, the replacement pattern only utilizes `input_a` and `input_b`. To avoid referencing all the unused parameters in the replacement pattern signature, pass only `input_a` and `input_b` and use `**_` to represent all the unused parameters. - -Similarly for writing the condition checking function, we require only `input_a`, `input_b` and `shape_c`. Use `**_` to represent all the unused parameters in the condition matching function signature. -::: - -In order to validate whether matmul broadcast is sufficient, we write a condition checking function as follows: - -```{literalinclude} examples/broadcast_matmul.py -:pyobject: check_if_not_need_reshape -``` - -With all the necessary components in place, the pattern rewrite rule with the `match_condition` function is created and then the `rewriter.rewrite` is called to apply the rewrite. 
- -```{literalinclude} examples/broadcast_matmul.py -:pyobject: apply_rewrite -``` - -The final graph with the applied rewrite looks as follows: - -![broadcast_rewrite](examples/img/broadcast_02.png){align=center} - diff --git a/docs/tutorial/rewriter/simple_example.md b/docs/tutorial/rewriter/simple_example.md new file mode 100644 index 0000000000..942f0ad48f --- /dev/null +++ b/docs/tutorial/rewriter/simple_example.md @@ -0,0 +1,71 @@ +(heading-target-simple)= +# A Simple Example + +A simple example demonstrating the usage of this functionality using the `GELU` activation function: + +`GELU` activation function can be computed using a Gauss Error Function using the given formula: + +```{math} +\text{GELU} = x\Phi(x) = x \cdot \frac{1}{2} [1 + \text{erf}(x / \sqrt{2})] +``` + +We will show how we can find a subgraph matching this computation and replace it by a call to the function. + +Firstly, include all the rewriter relevant imports. + +```python +from onnxscript.rewriter import pattern +from onnxscript import ir + +``` + +Then create a target pattern that needs to be replaced using onnxscript operators. + +```{literalinclude} examples/erfgelu.py +:pyobject: erf_gelu_pattern +``` + +After this, create a replacement pattern that consists of the GELU onnxscript operator. + +```{literalinclude} examples/erfgelu.py +:pyobject: gelu +``` +:::{note} +:name: type annotate ir.Value + +The inputs to the replacement pattern are of type `ir.Value`. For detailed usage of `ir.Value` refer to the {py:class}`ir.Value <onnxscript.ir.Value>` class. +::: + + +For this example, we do not require a `match_condition` so that option is skipped for now. Then the rewrite rule is created using the `RewriteRule` function. + +```python +rule = pattern.RewriteRule( + erf_gelu_pattern, # Target Pattern + gelu, # Replacement Pattern +) +``` + +Now that the rewrite rule has been created, the next step is to apply these pattern-based rewrite rules. 
The `rewriter.rewrite` call consists of three main components: + +1. `model` : The original model on which the pattern rewrite rules are to be applied. This is of type `onnx.ModelProto`. +2. `function_rewrite_rules` : `(Optional)` This parameter is used to pass rewrite rules based on function names. Steps on how to use this parameter will be covered in a different tutorial. This parameter is of type `Sequence[type[FunctionRewriteRule]]` +3. `pattern_rewrite_rules` : `(Optional)` This parameter is used to pass rewrite rules based on a provided replacement pattern. For the purpose of this tutorial, we will be using only this parameter in conjunction with `model`. This parameter is of either one of these types: + - `Sequence[PatternRewriteRule]` + - `RewriteRuleSet` + +:::{note} +:name: pattern_rewrite_rules input formatting + +`pattern_rewrite_rules` takes a sequence of `PatternRewriteRule` types or a RewriteRuleSet which is also essentially a rule set created using a sequence of `PatternRewriteRule` types, so if only a singular rewrite rule is to be passed, it needs to be passed as part of a sequence. For steps on how to create and use Rule-sets, refer to the example in the section [Creating a rule-set with different patterns](#heading-target-commute-ruleset). +::: + +The snippet below demonstrates how to use the `rewriter.rewrite` call for the rewrite rule created above: + +```{literalinclude} examples/erfgelu.py +:pyobject: apply_rewrite +``` + +The graph (on the left) consists of the target pattern before the rewrite rule is applied. Once the rewrite rule is applied, the graph (on the right) shows that the target pattern has been successfully replaced by a GELU node as intended. 
+ +![target_pattern](examples/img/erfgelu_01.png) ![replacement_pattern](examples/img/erfgelu_02.png) diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index 4815e0a2b4..b78ba367ea 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -379,6 +379,26 @@ def add_node(self, node: ir.Node) -> None: """Adds a node to the list of matched nodes.""" self._current_match.add_node(node) + def bind_value(self, pattern_value: ValuePattern, value: Any) -> bool: + var_name = pattern_value.name + # TODO(rama): Simplify the following. We currently bind values to + # pattern variables in two different ways: via their name, or via the + # pattern-value itself. + if var_name is None: + for match in self._partial_matches: + if pattern_value in match.value_bindings: + # TODO(rama): Use appropriate equality-check here. + if match.value_bindings[pattern_value] == value: + return True + self._current_match.fail( + f"Binding failure: {pattern_value} bound to two different values.", + [match.value_bindings[pattern_value], value], + ) + return False + self._current_match.value_bindings[pattern_value] = value + return True + return self.bind(var_name, value) + def bind(self, var: str, value: Any) -> bool: for match in self._partial_matches: if var in match.bindings: @@ -400,6 +420,13 @@ def bindings(self) -> dict[str, Any]: raise ValueError("Bindings can be accessed only at the top-level match.") return self._current_match.bindings + @property + def value_bindings(self) -> dict[ValuePattern, ir.Value]: + """Returns the bindings for the value variables.""" + if len(self._partial_matches) > 1: + raise ValueError("Value bindings can be accessed only at the top-level match.") + return self._current_match.value_bindings + @property def outputs(self) -> MutableSequence[ir.Value]: """Returns the list of output values that matched the pattern.""" @@ -437,7 +464,9 @@ def __init__(self) -> None: # For a successful match, bindings is a dictionary 
of mapping pattern-variable-names # to values. self._bindings: dict[str, Any] = {} + self._value_bindings: dict[ValuePattern, ir.Value] = {} self._node_bindings: dict[NodePattern, ir.Node] = {} + self._outputs: list[ir.Value] = [] # For a failed match, _reason is a string that describes the reason for the failure. self._reason: str = "" @@ -472,24 +501,14 @@ def add_node(self, node: ir.Node) -> None: """Adds a node to the list of matched nodes.""" self._matched_nodes.append(node) - def bind(self, var: str, value: Any) -> bool: - """Binds a pattern variable name to a value from the matched IR. - - Returns True if the binding is successful, False otherwise (when the binding is inconsistent). - """ - if var in self._bindings: - # TODO(rama): Use appropriate equality-check here. - if self._bindings[var] == value: - return True - self._success = False - return False - self._bindings[var] = value - return True - @property def bindings(self) -> dict[str, Any]: return self._bindings + @property + def value_bindings(self) -> dict[ValuePattern, ir.Value]: + return self._value_bindings + @property def outputs(self) -> MutableSequence[ir.Value]: return self._outputs @@ -954,7 +973,11 @@ def visit(value_patterns: Sequence[ValuePattern | None]) -> None: return node_patterns -def _add_backward_slice(node: NodePattern, backward_slice: set[NodePattern]) -> None: +def _add_backward_slice( + node: NodePattern, + backward_slice: set[NodePattern], + backward_slice_values: set[ValuePattern], +) -> None: """Adds all nodes in the backward slice of given node to the set `backward_slice`. 
The backward slice of a node is the set of all nodes that are reachable from the node @@ -965,7 +988,11 @@ def _add_backward_slice(node: NodePattern, backward_slice: set[NodePattern]) -> backward_slice.add(node) for value_pattern in node.inputs: if isinstance(value_pattern, NodeOutputPattern): - _add_backward_slice(value_pattern.producer(), backward_slice) + _add_backward_slice( + value_pattern.producer(), backward_slice, backward_slice_values + ) + elif isinstance(value_pattern, (_OpIdDispatchOr, _BacktrackingOr)): + backward_slice_values.add(value_pattern) class GraphPattern: @@ -987,20 +1014,26 @@ def __init__( # whose backward-slices cover the entire pattern. output_nodes: set[NodePattern] = set() covered: set[NodePattern] = set() + choice_values_returned: set[ValuePattern] = set() + covered_choice_values: set[ValuePattern] = set() for value_pattern in outputs: if not isinstance(value_pattern, ValuePattern): raise TypeError( f"Invalid type {type(value_pattern)} for graph pattern output." ) - if isinstance(value_pattern, Constant): - raise NotImplementedError( - "Constant values are not allowed as graph pattern outputs." - ) if isinstance(value_pattern, NodeOutputPattern): candidate = value_pattern.producer() if candidate not in covered: output_nodes.add(candidate) - _add_backward_slice(candidate, covered) + _add_backward_slice(candidate, covered, covered_choice_values) + elif isinstance(value_pattern, (_OpIdDispatchOr, _BacktrackingOr)): + choice_values_returned.add(value_pattern) + + # check if all choice_values_returned are contained in covered_choice_values: + # We don't yet support the use of a choice-value as a "root" of the search. + # This is a limitation of the current implementation, and will be fixed in the future. 
+ if not (choice_values_returned <= covered_choice_values): + raise NotImplementedError("Returning uncovered choice-values is not supported.") self.output_nodes: list[NodePattern] = list(output_nodes) @@ -1322,23 +1355,17 @@ def _match_node(self, pattern_node: NodePattern, node: ir.Node) -> bool: return False for i, output_value_pattern in enumerate(pattern_node.outputs): - if not self._bind_value(output_value_pattern, node.outputs[i]): + if not self._match.bind_value(output_value_pattern, node.outputs[i]): return False return True - def _bind_value(self, pattern_value: ValuePattern, value: ir.Value | None) -> bool: - """Bind a ValuePattern var to ir Value.""" - if pattern_value.name is not None: - return self._match.bind(pattern_value.name, value) - return True - def _match_value(self, pattern_value: ValuePattern, value: ir.Value | None) -> bool: """Match an IR value against a ValuePattern instance.""" if isinstance(pattern_value, AnyValue): return True - if not self._bind_value(pattern_value, value): + if not self._match.bind_value(pattern_value, value): return False if isinstance(pattern_value, NodeOutputPattern): @@ -1402,16 +1429,11 @@ def _get_output_values(self) -> list[ir.Value] | None: output_values.append(self._match.bindings[value_pattern.name]) else: unbound_values.append(value_pattern.name) - elif isinstance(value_pattern, NodeOutputPattern): - i = value_pattern.output_index - node = value_pattern.producer() - matched_node = self._match.lookup_node(node) - if matched_node is not None: - output_values.append(matched_node.outputs[i]) + else: + if value_pattern in self._match.value_bindings: + output_values.append(self._match.value_bindings[value_pattern]) else: unbound_values.append(f"output_{j}") - elif isinstance(value_pattern, Constant): - raise NotImplementedError("Constant values as return-values not supported.") if unbound_values: self._match.fail(f"Error: Output values not found: {unbound_values}") return None diff --git 
a/onnxscript/rewriter/pattern_test.py b/onnxscript/rewriter/pattern_test.py index edfff6bc13..6706eea193 100644 --- a/onnxscript/rewriter/pattern_test.py +++ b/onnxscript/rewriter/pattern_test.py @@ -754,6 +754,31 @@ def test_model2(x: FLOAT[16, 32], y: FLOAT[32, 16], bias: FLOAT[16]) -> FLOAT[16 self.assertEqual([x.op_type for x in model.graph], ["GemmRelu"]) self.assertEqual([x.name for x in model.graph.node(0).inputs], ["x", "y", "bias"]) + def test_or_pattern_return_value(self): + """Test that an OrValue can be used as a return value from the source pattern.""" + + def source_pattern(op, x, y): + choice1 = op.Add(x, y) + choice2 = op.Mul(x, y) + t = pattern.OrValue([choice1, choice2]) + z = op.Relu(t) + return z, t + + def replacement(op, x, y): + z, t = op.ReluPlus(x, y, _outputs=2) + return z, t + + rule = pattern.RewriteRule(source_pattern, replacement) + + @script() + def test_model1(x: FLOAT[16, 32], y: FLOAT[16, 32]) -> FLOAT[16, 32]: + return op.Relu(op.Add(x, y)) + + model_proto = test_model1.to_model_proto() + model = ir.serde.deserialize_model(model_proto) + rule.apply_to_model(model) + self.assertEqual([x.op_type for x in model.graph], ["ReluPlus"]) + class PatternBuilderTest(unittest.TestCase): def test_pattern_builder_context(self):