Skip to content

[docs] Document rewriter pattern options #2406

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions docs/tutorial/rewriter/allow_other_inputs.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Specifying variable inputs in the pattern

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning documentation

Final newline expected

This section demonstrates the use of the `_allow_other_inputs` option in pattern-based rewriting.
The `_allow_other_inputs` option allows the pattern to match nodes that have additional inputs
beyond those specified in the pattern. If it is set to `False` (the default), then the node must

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning documentation

Trailing whitespace
have exactly the specified inputs for a successful match. If set to `True`, the pattern will
match nodes that have the specified inputs plus any number of additional inputs.

This is particularly useful when matching operations like `Conv` that can have optional inputs
(such as bias), or when creating generic patterns that should work with various input configurations.

```{literalinclude} examples/allow_other_inputs.py
:pyobject: conv_pattern
```

```{literalinclude} examples/allow_other_inputs.py
:pyobject: conv_replacement
```

```{literalinclude} examples/allow_other_inputs.py
:pyobject: apply_rewrite
```

In this example, the pattern matches `Conv` operations with any number of inputs. A `Conv` operation
might have 2 inputs (input and weight) or 3 inputs (input, weight, and bias). By setting

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning documentation

Trailing whitespace
`_allow_other_inputs=True`, our pattern will match both cases even though we only specify 2 inputs
in the pattern definition.
1 change: 1 addition & 0 deletions docs/tutorial/rewriter/attributes.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ This section demonstrates the use of attribute values in pattern-based rewriting
First, write a target pattern and replacement pattern in a similar way to the previous examples.
The example pattern below will match successfully only against Dropout nodes with the
attribute value `training_mode` set to `False`.

The `_allow_other_attributes` option allows the pattern to match nodes that have additional attributes
not specified in the pattern. If it is set to `False`, then the node must have only the specified
attribute values, and no other attributes, for a successful match. The default value for this
Expand Down
38 changes: 38 additions & 0 deletions docs/tutorial/rewriter/domain_option.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Specifying domains in the pattern

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning documentation

Final newline expected

This section demonstrates the use of the `_domain` option in pattern-based rewriting.
The `_domain` option allows you to specify which operator domain the pattern should match against,
and also allows you to create replacement operations in specific domains.

ONNX operators can belong to different domains:
- The default ONNX domain (empty string or "ai.onnx")
- Custom domains like "com.microsoft" for Microsoft-specific operations
- User-defined domains for custom operations

## Matching operations from a specific domain

```{literalinclude} examples/domain_option.py
:pyobject: custom_relu_pattern
```

In this pattern, `_domain="custom.domain"` ensures that only `Relu` operations from the

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning documentation

Trailing whitespace
"custom.domain" domain will be matched, not standard ONNX `Relu` operations.

## Creating replacement operations in a specific domain

```{literalinclude} examples/domain_option.py
:pyobject: microsoft_relu_replacement
```

Here, the replacement operation is created in the "com.microsoft" domain, which might
provide optimized implementations of standard operations.

## Complete rewrite example

```{literalinclude} examples/domain_option.py
:pyobject: apply_rewrite
```

This example shows how domain-specific pattern matching can be used to migrate operations
between different operator domains, such as replacing custom domain operations with
standard ONNX operations or vice versa.
71 changes: 71 additions & 0 deletions docs/tutorial/rewriter/examples/allow_other_inputs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Copyright (c) Microsoft Corporation.

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning documentation

Final newline expected

Check warning

Code scanning / lintrunner

RUFF/format Warning documentation

Run lintrunner -a to apply this patch.

Check warning

Code scanning / lintrunner

RUFF-FORMAT/format Warning documentation

Run lintrunner -a to apply this patch.
# Licensed under the MIT License.
"""ONNX Pattern Rewriting with variable number of inputs

This script shows how to define a rewriting rule based on patterns that
can match nodes with additional inputs beyond those specified in the pattern.
"""

import onnx

import onnxscript
from onnxscript import FLOAT, opset18, script
from onnxscript.rewriter import pattern


@script()
def original_model(A: FLOAT[2, 2], B: FLOAT[2, 2], C: FLOAT[2, 2]) -> FLOAT[2, 2]:
# Conv with bias - has 3 inputs: input, weight, bias
result = opset18.Conv(A, B, C)
return result


_model = original_model.to_model_proto()
onnx.checker.check_model(_model)


####################################
# The target pattern
# =====================


def conv_pattern(op, input, weight):
# Pattern to match Conv operations, allowing additional inputs like bias
# _allow_other_inputs=True allows the pattern to match Conv with bias (3 inputs)
# even though we only specify 2 inputs in the pattern
return op.Conv(input, weight, _allow_other_inputs=True)


####################################
# The replacement pattern
# =====================


def conv_replacement(op, input, weight, **_):
# Replace with a custom operation in a different domain
return op.OptimizedConv(input, weight, _domain="custom.domain")


####################################
# Create Rewrite Rule and Apply to Model
# =====================


def apply_rewrite(model):
# Create rewrite rules
conv_rule = pattern.RewriteRule(
conv_pattern, # target pattern
conv_replacement, # replacement pattern
)
# Create a Rewrite Rule Set
rewrite_rule_set = pattern.RewriteRuleSet([conv_rule])
# Apply rewrite
model_with_rewrite = onnxscript.rewriter.rewrite(
model,
pattern_rewrite_rules=rewrite_rule_set,
)
return model_with_rewrite


_model_with_rewrite = apply_rewrite(_model)
onnx.checker.check_model(_model_with_rewrite)

Check warning

Code scanning / lintrunner

RUFF/W292 Warning documentation

83 changes: 83 additions & 0 deletions docs/tutorial/rewriter/examples/domain_option.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Copyright (c) Microsoft Corporation.

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning documentation

Final newline expected

Check warning

Code scanning / lintrunner

RUFF/format Warning documentation

Run lintrunner -a to apply this patch.

Check warning

Code scanning / lintrunner

RUFF-FORMAT/format Warning documentation

Run lintrunner -a to apply this patch.
# Licensed under the MIT License.
"""ONNX Pattern Rewriting with domain specification

This script shows how to define a rewriting rule that targets operations
from specific domains and replaces them with operations in other domains.
"""

import onnx

import onnxscript
from onnxscript import FLOAT, opset18, script
from onnxscript.rewriter import pattern


@script()
def original_model(A: FLOAT[2, 2]) -> FLOAT[2, 2]:
# This would represent a custom operation in a specific domain
# For demonstration, we'll use a standard Relu but imagine it's in a custom domain
result = opset18.Relu(A)
return result


_model = original_model.to_model_proto()
onnx.checker.check_model(_model)


####################################
# The target pattern
# =====================


def custom_relu_pattern(op, input):
# Pattern to match Relu operations from a specific domain
# _domain="custom.domain" specifies we only want to match operations from this domain
return op.Relu(input, _domain="custom.domain")


####################################
# The replacement pattern
# =====================


def standard_relu_replacement(op, input, **_):
# Replace with standard ONNX Relu (default domain)
return op.Relu(input)


####################################
# Alternative: Replace with operation in different domain
# =====================


def microsoft_relu_replacement(op, input, **_):
# Replace with operation in Microsoft's domain
return op.OptimizedRelu(input, _domain="com.microsoft")


####################################
# Create Rewrite Rule and Apply to Model
# =====================


def apply_rewrite(model):
# Create rewrite rules
relu_rule = pattern.RewriteRule(
custom_relu_pattern, # target pattern - matches custom domain operations
standard_relu_replacement, # replacement pattern - uses standard domain
)
# Create a Rewrite Rule Set
rewrite_rule_set = pattern.RewriteRuleSet([relu_rule])
# Apply rewrite
model_with_rewrite = onnxscript.rewriter.rewrite(
model,
pattern_rewrite_rules=rewrite_rule_set,
)
return model_with_rewrite


# Note: This example is demonstrative. In practice, you would modify the model
# to have operations in the custom domain before applying the rewrite.
_model_with_rewrite = apply_rewrite(_model)
onnx.checker.check_model(_model_with_rewrite)

Check warning

Code scanning / lintrunner

RUFF/W292 Warning documentation

94 changes: 94 additions & 0 deletions docs/tutorial/rewriter/examples/outputs_option.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# Copyright (c) Microsoft Corporation.

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning documentation

Final newline expected

Check warning

Code scanning / lintrunner

RUFF/format Warning documentation

Run lintrunner -a to apply this patch.
# Licensed under the MIT License.
"""ONNX Pattern Rewriting with output specification

This script shows how to define a rewriting rule that specifies
the number and names of outputs from operations.
"""

import onnx

import onnxscript
from onnxscript import FLOAT, opset18, script
from onnxscript.rewriter import pattern


@script()
def original_model(A: FLOAT[4, 4]) -> FLOAT[2, 4]:
# Split operation that produces 2 outputs
result1, result2 = opset18.Split(A, num_outputs=2, axis=0)

Check warning

Code scanning / lintrunner

PYLINT/W0612 Warning documentation

Unused variable 'result2' (unused-variable)
See unused-variable. To disable, use # pylint: disable=unused-variable

Check warning

Code scanning / lintrunner

RUFF/RUF059 Warning documentation

Unpacked variable result2 is never used.
See https://docs.astral.sh/ruff/rules/unused-unpacked-variable
# We only return the first output for simplicity
return result1


_model = original_model.to_model_proto()
onnx.checker.check_model(_model)


####################################
# The target pattern with multiple outputs
# =====================


def split_pattern(op, input):
# Pattern to match Split operations with 2 outputs
# _outputs=2 specifies that this operation produces 2 outputs
return op.Split(input, num_outputs=2, axis=0, _outputs=2)


####################################
# The replacement pattern with named outputs
# =====================


def custom_split_replacement(op, input, **_):
# Replace with a custom split operation using named outputs
# _outputs=["first_half", "second_half"] assigns names to the outputs
return op.CustomSplit(input, _domain="custom.domain", _outputs=["first_half", "second_half"])


####################################
# Alternative: Single output replacement
# =====================


def identity_replacement(op, input, **_):
# Replace split with identity (single output)
# _outputs=1 or just omitting _outputs (default is 1)
return op.Identity(input, _outputs=1)


####################################
# Example with explicit output count
# =====================


def triple_split_replacement(op, input, **_):
# Replace with operation that produces 3 outputs
# _outputs=3 specifies 3 unnamed outputs

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning documentation

Trailing whitespace

Check warning

Code scanning / lintrunner

RUFF/W291 Warning documentation

return op.Split(input, num_outputs=3, axis=0, _outputs=3)


####################################
# Create Rewrite Rule and Apply to Model
# =====================


def apply_rewrite(model):
# Create rewrite rules
split_rule = pattern.RewriteRule(
split_pattern, # target pattern - matches Split with 2 outputs
custom_split_replacement, # replacement pattern - uses named outputs
)
# Create a Rewrite Rule Set
rewrite_rule_set = pattern.RewriteRuleSet([split_rule])
# Apply rewrite
model_with_rewrite = onnxscript.rewriter.rewrite(
model,
pattern_rewrite_rules=rewrite_rule_set,
)
return model_with_rewrite


_model_with_rewrite = apply_rewrite(_model)
onnx.checker.check_model(_model_with_rewrite)

Check warning

Code scanning / lintrunner

RUFF/W292 Warning documentation

51 changes: 51 additions & 0 deletions docs/tutorial/rewriter/outputs_option.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Specifying outputs in the pattern

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning documentation

Final newline expected

This section demonstrates the use of the `_outputs` option in pattern-based rewriting.
The `_outputs` option allows you to specify the number of outputs an operation produces
and optionally assign names to those outputs for easier reference in replacement patterns.

The `_outputs` option can be specified in two ways:
- As an integer: `_outputs=2` specifies that the operation produces 2 unnamed outputs
- As a list of strings/None: `_outputs=["first", "second"]` specifies 2 named outputs

## Matching operations with multiple outputs

```{literalinclude} examples/outputs_option.py
:pyobject: split_pattern
```

This pattern matches `Split` operations that produce exactly 2 outputs. The `_outputs=2`
specification ensures the pattern only matches operations with this specific output count.

## Creating replacement operations with named outputs

```{literalinclude} examples/outputs_option.py
:pyobject: custom_split_replacement
```

In the replacement, `_outputs=["first_half", "second_half"]` creates two outputs with
descriptive names. This can make the replacement pattern more readable and maintainable.

## Alternative replacement patterns

### Single output replacement
```{literalinclude} examples/outputs_option.py
:pyobject: identity_replacement
```

### Multiple output replacement
```{literalinclude} examples/outputs_option.py
:pyobject: triple_split_replacement
```

## Complete rewrite example

```{literalinclude} examples/outputs_option.py
:pyobject: apply_rewrite
```

The `_outputs` option is particularly important when:
- Working with operations that have variable numbers of outputs (like `Split`)
- Creating custom operations that need specific output configurations
- Ensuring pattern matching precision by specifying exact output counts
- Improving code readability by naming outputs in replacement patterns
Loading
Loading