-
Notifications
You must be signed in to change notification settings - Fork 72
[docs] Document rewriter pattern options #2406
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 2 commits
99cef5c
40ea7ac
d082599
9c2d145
aaceba9
12f0fea
2f72d62
2a6889e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
# Specifying variable inputs in the pattern | ||
Check warningCode scanning / lintrunner EDITORCONFIG-CHECKER/editorconfig Warning documentation
Final newline expected
|
||
|
||
This section demonstrates the use of the `_allow_other_inputs` option in pattern-based rewriting. | ||
The `_allow_other_inputs` option allows the pattern to match nodes that have additional inputs | ||
beyond those specified in the pattern. If it is set to `False` (the default), then the node must | ||
Check warningCode scanning / lintrunner EDITORCONFIG-CHECKER/editorconfig Warning documentation
Trailing whitespace
|
||
have exactly the specified inputs for a successful match. If set to `True`, the pattern will | ||
match nodes that have the specified inputs plus any number of additional inputs. | ||
|
||
This is particularly useful when matching operations like `Conv` that can have optional inputs | ||
(such as bias), or when creating generic patterns that should work with various input configurations. | ||
|
||
```{literalinclude} examples/allow_other_inputs.py | ||
:pyobject: conv_pattern | ||
``` | ||
|
||
```{literalinclude} examples/allow_other_inputs.py | ||
:pyobject: conv_replacement | ||
``` | ||
|
||
```{literalinclude} examples/allow_other_inputs.py | ||
:pyobject: apply_rewrite | ||
``` | ||
|
||
In this example, the pattern matches `Conv` operations with any number of inputs. A `Conv` operation | ||
might have 2 inputs (input and weight) or 3 inputs (input, weight, and bias). By setting | ||
Check warningCode scanning / lintrunner EDITORCONFIG-CHECKER/editorconfig Warning documentation
Trailing whitespace
|
||
`_allow_other_inputs=True`, our pattern will match both cases even though we only specify 2 inputs | ||
in the pattern definition. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
# Specifying domains in the pattern | ||
Check warningCode scanning / lintrunner EDITORCONFIG-CHECKER/editorconfig Warning documentation
Final newline expected
|
||
|
||
This section demonstrates the use of the `_domain` option in pattern-based rewriting. | ||
The `_domain` option allows you to specify which operator domain the pattern should match against, | ||
and also allows you to create replacement operations in specific domains. | ||
|
||
ONNX operators can belong to different domains: | ||
- The default ONNX domain (empty string or "ai.onnx") | ||
- Custom domains like "com.microsoft" for Microsoft-specific operations | ||
- User-defined domains for custom operations | ||
|
||
## Matching operations from a specific domain | ||
|
||
```{literalinclude} examples/domain_option.py | ||
:pyobject: custom_relu_pattern | ||
``` | ||
|
||
In this pattern, `_domain="custom.domain"` ensures that only `Relu` operations from the | ||
Check warningCode scanning / lintrunner EDITORCONFIG-CHECKER/editorconfig Warning documentation
Trailing whitespace
|
||
"custom.domain" domain will be matched, not standard ONNX `Relu` operations. | ||
|
||
## Creating replacement operations in a specific domain | ||
|
||
```{literalinclude} examples/domain_option.py | ||
:pyobject: microsoft_relu_replacement | ||
``` | ||
|
||
Here, the replacement operation is created in the "com.microsoft" domain, which might | ||
provide optimized implementations of standard operations. | ||
|
||
## Complete rewrite example | ||
|
||
```{literalinclude} examples/domain_option.py | ||
:pyobject: apply_rewrite | ||
``` | ||
|
||
This example shows how domain-specific pattern matching can be used to migrate operations | ||
between different operator domains, such as replacing custom domain operations with | ||
standard ONNX operations or vice versa. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
Check warningCode scanning / lintrunner EDITORCONFIG-CHECKER/editorconfig Warning documentation
Final newline expected
Check warningCode scanning / lintrunner RUFF/format Warning documentation
Run lintrunner -a to apply this patch.
Check warningCode scanning / lintrunner RUFF-FORMAT/format Warning documentation
Run lintrunner -a to apply this patch.
|
||
# Licensed under the MIT License. | ||
"""ONNX Pattern Rewriting with variable number of inputs | ||
|
||
This script shows how to define a rewriting rule based on patterns that | ||
can match nodes with additional inputs beyond those specified in the pattern. | ||
""" | ||
|
||
import onnx | ||
|
||
import onnxscript | ||
from onnxscript import FLOAT, opset18, script | ||
from onnxscript.rewriter import pattern | ||
|
||
|
||
@script() | ||
def original_model(A: FLOAT[2, 2], B: FLOAT[2, 2], C: FLOAT[2, 2]) -> FLOAT[2, 2]: | ||
# Conv with bias - has 3 inputs: input, weight, bias | ||
result = opset18.Conv(A, B, C) | ||
return result | ||
|
||
|
||
_model = original_model.to_model_proto() | ||
onnx.checker.check_model(_model) | ||
|
||
|
||
#################################### | ||
# The target pattern | ||
# ===================== | ||
|
||
|
||
def conv_pattern(op, input, weight): | ||
# Pattern to match Conv operations, allowing additional inputs like bias | ||
# _allow_other_inputs=True allows the pattern to match Conv with bias (3 inputs) | ||
# even though we only specify 2 inputs in the pattern | ||
return op.Conv(input, weight, _allow_other_inputs=True) | ||
|
||
|
||
#################################### | ||
# The replacement pattern | ||
# ===================== | ||
|
||
|
||
def conv_replacement(op, input, weight, **_): | ||
# Replace with a custom operation in a different domain | ||
return op.OptimizedConv(input, weight, _domain="custom.domain") | ||
|
||
|
||
#################################### | ||
# Create Rewrite Rule and Apply to Model | ||
# ===================== | ||
|
||
|
||
def apply_rewrite(model): | ||
# Create rewrite rules | ||
conv_rule = pattern.RewriteRule( | ||
conv_pattern, # target pattern | ||
conv_replacement, # replacement pattern | ||
) | ||
# Create a Rewrite Rule Set | ||
rewrite_rule_set = pattern.RewriteRuleSet([conv_rule]) | ||
# Apply rewrite | ||
model_with_rewrite = onnxscript.rewriter.rewrite( | ||
model, | ||
pattern_rewrite_rules=rewrite_rule_set, | ||
) | ||
return model_with_rewrite | ||
|
||
|
||
_model_with_rewrite = apply_rewrite(_model) | ||
onnx.checker.check_model(_model_with_rewrite) | ||
Check warningCode scanning / lintrunner RUFF/W292 Warning documentation
No newline at end of file.
See https://docs.astral.sh/ruff/rules/missing-newline-at-end-of-file |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
Check warningCode scanning / lintrunner EDITORCONFIG-CHECKER/editorconfig Warning documentation
Final newline expected
Check warningCode scanning / lintrunner RUFF/format Warning documentation
Run lintrunner -a to apply this patch.
Check warningCode scanning / lintrunner RUFF-FORMAT/format Warning documentation
Run lintrunner -a to apply this patch.
|
||
# Licensed under the MIT License. | ||
"""ONNX Pattern Rewriting with domain specification | ||
|
||
This script shows how to define a rewriting rule that targets operations | ||
from specific domains and replaces them with operations in other domains. | ||
""" | ||
|
||
import onnx | ||
|
||
import onnxscript | ||
from onnxscript import FLOAT, opset18, script | ||
from onnxscript.rewriter import pattern | ||
|
||
|
||
@script() | ||
def original_model(A: FLOAT[2, 2]) -> FLOAT[2, 2]: | ||
# This would represent a custom operation in a specific domain | ||
# For demonstration, we'll use a standard Relu but imagine it's in a custom domain | ||
result = opset18.Relu(A) | ||
return result | ||
|
||
|
||
_model = original_model.to_model_proto() | ||
onnx.checker.check_model(_model) | ||
|
||
|
||
#################################### | ||
# The target pattern | ||
# ===================== | ||
|
||
|
||
def custom_relu_pattern(op, input): | ||
# Pattern to match Relu operations from a specific domain | ||
# _domain="custom.domain" specifies we only want to match operations from this domain | ||
return op.Relu(input, _domain="custom.domain") | ||
|
||
|
||
#################################### | ||
# The replacement pattern | ||
# ===================== | ||
|
||
|
||
def standard_relu_replacement(op, input, **_): | ||
# Replace with standard ONNX Relu (default domain) | ||
return op.Relu(input) | ||
|
||
|
||
#################################### | ||
# Alternative: Replace with operation in different domain | ||
# ===================== | ||
|
||
|
||
def microsoft_relu_replacement(op, input, **_): | ||
# Replace with operation in Microsoft's domain | ||
return op.OptimizedRelu(input, _domain="com.microsoft") | ||
|
||
|
||
#################################### | ||
# Create Rewrite Rule and Apply to Model | ||
# ===================== | ||
|
||
|
||
def apply_rewrite(model): | ||
# Create rewrite rules | ||
relu_rule = pattern.RewriteRule( | ||
custom_relu_pattern, # target pattern - matches custom domain operations | ||
standard_relu_replacement, # replacement pattern - uses standard domain | ||
) | ||
# Create a Rewrite Rule Set | ||
rewrite_rule_set = pattern.RewriteRuleSet([relu_rule]) | ||
# Apply rewrite | ||
model_with_rewrite = onnxscript.rewriter.rewrite( | ||
model, | ||
pattern_rewrite_rules=rewrite_rule_set, | ||
) | ||
return model_with_rewrite | ||
|
||
|
||
# Note: This example is demonstrative. In practice, you would modify the model | ||
# to have operations in the custom domain before applying the rewrite. | ||
_model_with_rewrite = apply_rewrite(_model) | ||
onnx.checker.check_model(_model_with_rewrite) | ||
Check warningCode scanning / lintrunner RUFF/W292 Warning documentation
No newline at end of file.
See https://docs.astral.sh/ruff/rules/missing-newline-at-end-of-file |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
Check warningCode scanning / lintrunner EDITORCONFIG-CHECKER/editorconfig Warning documentation
Final newline expected
Check warningCode scanning / lintrunner RUFF/format Warning documentation
Run lintrunner -a to apply this patch.
|
||
|
||
# Licensed under the MIT License. | ||
"""ONNX Pattern Rewriting with output specification | ||
|
||
This script shows how to define a rewriting rule that specifies | ||
the number and names of outputs from operations. | ||
""" | ||
|
||
import onnx | ||
|
||
import onnxscript | ||
from onnxscript import FLOAT, opset18, script | ||
from onnxscript.rewriter import pattern | ||
|
||
|
||
@script() | ||
def original_model(A: FLOAT[4, 4]) -> FLOAT[2, 4]: | ||
# Split operation that produces 2 outputs | ||
result1, result2 = opset18.Split(A, num_outputs=2, axis=0) | ||
Check warningCode scanning / lintrunner PYLINT/W0612 Warning documentation
Unused variable 'result2' (unused-variable)
See unused-variable. To disable, use # pylint: disable=unused-variable Check warningCode scanning / lintrunner RUFF/RUF059 Warning documentation
Unpacked variable result2 is never used.
See https://docs.astral.sh/ruff/rules/unused-unpacked-variable |
||
# We only return the first output for simplicity | ||
return result1 | ||
|
||
|
||
_model = original_model.to_model_proto() | ||
onnx.checker.check_model(_model) | ||
|
||
|
||
#################################### | ||
# The target pattern with multiple outputs | ||
# ===================== | ||
|
||
|
||
def split_pattern(op, input): | ||
# Pattern to match Split operations with 2 outputs | ||
# _outputs=2 specifies that this operation produces 2 outputs | ||
return op.Split(input, num_outputs=2, axis=0, _outputs=2) | ||
|
||
|
||
#################################### | ||
# The replacement pattern with named outputs | ||
# ===================== | ||
|
||
|
||
def custom_split_replacement(op, input, **_): | ||
# Replace with a custom split operation using named outputs | ||
# _outputs=["first_half", "second_half"] assigns names to the outputs | ||
return op.CustomSplit(input, _domain="custom.domain", _outputs=["first_half", "second_half"]) | ||
|
||
|
||
#################################### | ||
gramalingam marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# Alternative: Single output replacement | ||
# ===================== | ||
|
||
|
||
def identity_replacement(op, input, **_): | ||
# Replace split with identity (single output) | ||
# _outputs=1 or just omitting _outputs (default is 1) | ||
return op.Identity(input, _outputs=1) | ||
|
||
|
||
#################################### | ||
# Example with explicit output count | ||
# ===================== | ||
|
||
|
||
def triple_split_replacement(op, input, **_): | ||
# Replace with operation that produces 3 outputs | ||
# _outputs=3 specifies 3 unnamed outputs | ||
Check warningCode scanning / lintrunner EDITORCONFIG-CHECKER/editorconfig Warning documentation
Trailing whitespace
Check warningCode scanning / lintrunner RUFF/W291 Warning documentation
Trailing whitespace.
See https://docs.astral.sh/ruff/rules/trailing-whitespace |
||
return op.Split(input, num_outputs=3, axis=0, _outputs=3) | ||
|
||
|
||
#################################### | ||
# Create Rewrite Rule and Apply to Model | ||
# ===================== | ||
|
||
|
||
def apply_rewrite(model): | ||
# Create rewrite rules | ||
split_rule = pattern.RewriteRule( | ||
split_pattern, # target pattern - matches Split with 2 outputs | ||
custom_split_replacement, # replacement pattern - uses named outputs | ||
) | ||
# Create a Rewrite Rule Set | ||
rewrite_rule_set = pattern.RewriteRuleSet([split_rule]) | ||
# Apply rewrite | ||
model_with_rewrite = onnxscript.rewriter.rewrite( | ||
model, | ||
pattern_rewrite_rules=rewrite_rule_set, | ||
) | ||
return model_with_rewrite | ||
|
||
|
||
_model_with_rewrite = apply_rewrite(_model) | ||
onnx.checker.check_model(_model_with_rewrite) | ||
Check warningCode scanning / lintrunner RUFF/W292 Warning documentation
No newline at end of file.
See https://docs.astral.sh/ruff/rules/missing-newline-at-end-of-file |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
# Specifying outputs in the pattern | ||
Check warningCode scanning / lintrunner EDITORCONFIG-CHECKER/editorconfig Warning documentation
Final newline expected
|
||
|
||
This section demonstrates the use of the `_outputs` option in pattern-based rewriting. | ||
The `_outputs` option allows you to specify the number of outputs an operation produces | ||
and optionally assign names to those outputs for easier reference in replacement patterns. | ||
|
||
The `_outputs` option can be specified in two ways: | ||
- As an integer: `_outputs=2` specifies that the operation produces 2 unnamed outputs | ||
- As a list of strings/None: `_outputs=["first", "second"]` specifies 2 named outputs | ||
|
||
## Matching operations with multiple outputs | ||
|
||
```{literalinclude} examples/outputs_option.py | ||
:pyobject: split_pattern | ||
``` | ||
|
||
This pattern matches `Split` operations that produce exactly 2 outputs. The `_outputs=2` | ||
specification ensures the pattern only matches operations with this specific output count. | ||
|
||
## Creating replacement operations with named outputs | ||
|
||
```{literalinclude} examples/outputs_option.py | ||
:pyobject: custom_split_replacement | ||
``` | ||
|
||
In the replacement, `_outputs=["first_half", "second_half"]` creates two outputs with | ||
descriptive names. This can make the replacement pattern more readable and maintainable. | ||
|
||
## Alternative replacement patterns | ||
|
||
### Single output replacement | ||
```{literalinclude} examples/outputs_option.py | ||
:pyobject: identity_replacement | ||
``` | ||
|
||
### Multiple output replacement | ||
```{literalinclude} examples/outputs_option.py | ||
:pyobject: triple_split_replacement | ||
``` | ||
|
||
## Complete rewrite example | ||
|
||
```{literalinclude} examples/outputs_option.py | ||
:pyobject: apply_rewrite | ||
``` | ||
|
||
The `_outputs` option is particularly important when: | ||
- Working with operations that have variable numbers of outputs (like `Split`) | ||
- Creating custom operations that need specific output configurations | ||
- Ensuring pattern matching precision by specifying exact output counts | ||
- Improving code readability by naming outputs in replacement patterns |
Uh oh!
There was an error while loading. Please reload this page.