Skip to content

Updates to the rewriter tutorial #2397

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Jun 19, 2025
3 changes: 2 additions & 1 deletion docs/tutorial/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ ONNX perspective, the two assignments to *g* represent two distinct tensors
```{toctree}
:maxdepth: 1

optimizer/index
rewriter/index
optimizer/index

```
14 changes: 13 additions & 1 deletion docs/tutorial/rewriter/commute.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,18 @@
(heading-target-commute)=
# Utilizing `commute` parameter for pattern-matching
Extending the previous [simple example](heading-target-simple), assumming a scenario where we have a graph with the following structure.

```{warning}
Please note that the section below describes a convenience feature for handling commutative operators
in pattern matching. However, the implementation is a simple, brute-force, technique that generates a collection
of rewrite-rules from a given rule, taking commutativity of addition and multiplication into account. This can
lead to an exponential increase in the number of rewrite-rules. So, it should be used with caution. Pattern
disjunctions (_OR Patterns_) described earlier can be used judiciously to get a somewhat more efficient
implementation in practice (even though the potential for exponential increase still exists within the
pattern matching algorithm). Reimplementing commutativity handling using pattern disjunctions is future
work.
```

Extending the previous [simple example](heading-target-simple), assuming a scenario where we have a graph with the following structure.

![commute](examples/img/erfgelu_03_commute.png){align=center width=500px}

Expand Down
6 changes: 5 additions & 1 deletion docs/tutorial/rewriter/conditional_rewrite.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@ The target pattern in this case has 5 inputs `input_a`, `input_b`, `shape_a`, `s
Similarly for writing the condition checking function, we require only `input_a`, `input_b` and `shape_c`. Use `**_` to represent all the unused parameters in the condition matching function signature.
:::

In order to validate whether matmul broadcast is sufficient, we write a condition checking function as follows:
In order to validate whether matmul broadcast is sufficient, we write a condition checking function as below.
Note that the relevant inputs passed to the check function are all instances of :class:`onnx_ir.Value`. These represent
the values in the input graph IR that matched against the corresponding _pattern variables_ in the target
pattern. Please see documentation of the [IR API](https://onnx.ai/ir-py/) for more details on how to use it, for example to identify
the type or shape or rank of these values.

```{literalinclude} examples/broadcast_matmul.py
:pyobject: check_if_not_need_reshape
Expand Down
16 changes: 16 additions & 0 deletions docs/tutorial/rewriter/examples/erfgelu.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,22 @@ def apply_rewrite(model):
return model_with_rewrite_applied


####################################
# Rewrite Rule as a Class
# =====================


class ErfGeluFusion(pattern.RewriteRuleClassBase):
def pattern(self, op, x):
return (x * (op.Erf(x / math.sqrt(2)) + 1.0)) * 0.5

def rewrite(self, op, x):
return op.Gelu(x, _domain="com.microsoft")


erf_gelu_rule_from_class = ErfGeluFusion.rule()


def apply_rewrite_with_ruleset(model):
# Create multiple rules
rule1 = pattern.RewriteRule(
Expand Down
2 changes: 1 addition & 1 deletion docs/tutorial/rewriter/rewrite_patterns.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Introduction

The ONNX Rewriter tool provides the user with the functionality to replace certain patterns in an ONNX graph with another pattern based on rewrite rules provided by the user.
The ONNX Rewriter tool provides the user with the functionality to replace certain patterns in an ONNX graph with another pattern based on conditional rewrite rules provided by the user.

# Usage

Expand Down
22 changes: 16 additions & 6 deletions docs/tutorial/rewriter/simple_example.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,18 +46,28 @@ rule = pattern.RewriteRule(
)
```

Now that the rewrite rule has been created, the next step is to apply these pattern-based rewrite rules. The `rewriter.rewrite` call consists of three main components:
It is more convenient to organize more complex rewrite-rules as a class. The above rule can be
alternatively expressed as below.

1. `model` : The original model on which the pattern rewrite rules are to be applied. This is of type `onnx.ModelProto`.
```{literalinclude} examples/erfgelu.py
:pyobject: ErfGeluFusion
```

The corresponding rewrite-rule can be obtained as below:

```python
erf_gelu_rule_from_class = ErfGeluFusion.rule()
```

Now that the rewrite rule has been created, the next step is to apply these pattern-based rewrite rules. The `rewriter.rewrite (model, pattern_rewrite_rules)` call applies the specified rewrite rules to the given model.

2. `pattern_rewrite_rules` : `(Optional)` This parameter is used to pass rewrite rules based on a provided replacement pattern. This parameter is of either one of these types:
- `Sequence[PatternRewriteRule]`
- `RewriteRuleSet`
1. `model` : The original model on which the pattern rewrite rules are to be applied. This is of type `ir.Model` or `onnx.ModelProto`. If the model is an `ir.Model`, the rewriter applies the changes in-place, modifying the input model. If it is an `ModelProto`, the rewriter returns a new `ModelProto` representing the transformed model.
2. `pattern_rewrite_rules` : This parameter either a `Sequence[PatternRewriteRule]` or a `RewriteRuleSet`.

:::{note}
:name: pattern_rewrite_rules input formatting

`pattern_rewrite_rules` takes a sequence of `PatternRewriteRule` types or a RewriteRuleSet which is also essentially a rule set created using a sequence of `PatternRewriteRule` types, so if only a singular rewrite rule is to be passed, it needs to passed as part of a sequence. For steps on how to create and use Rule-sets, refer to the example in the section [Creating a rule-set with different patterns](#heading-target-commute-ruleset).
For steps on how to create and use Rule-sets, refer to the example in the section [Creating a rule-set with different patterns](#heading-target-commute-ruleset).
:::

The snippet below below demonstrates how to use the `rewriter.rewrite` call for the rewrite rule created above:
Expand Down
1 change: 1 addition & 0 deletions pyproject_pylint.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

[tool.pylint.messages_control]
disable = [
"arguments-differ", # TODO: abstract methods in Rewriter
"attribute-defined-outside-init", # TODO: mostly in onnxscript/converter.py
"cell-var-from-loop", # Bugbear B023
"consider-using-from-import",
Expand Down
Loading