Skip to content

docs: cleanup documentation for function-based rewrites📄 #2359

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ import onnxscript
onnxscript.optimizer.optimize(onnx_model)
```

For a detailed summary of all the optimizations applied by the optimizer call, refer to the tutorial [Optimizing a Model using the Optimizer](https://onnxscript.ai/tutorial/optimizer/optimize.html)
For a detailed summary of all the optimizations applied by the optimizer call, refer to the tutorial [Optimizing a Model using the Optimizer](https://microsoft.github.io/onnxscript/tutorial/optimizer/optimize.html)

### ONNX Rewriter

Expand Down Expand Up @@ -205,11 +205,7 @@ model_with_rewrite_applied = onnxscript.rewriter.rewrite(
return model_with_rewrite_applied
```

For a detailed tutorial on how to create target_pattern, replacement_pattern and match_condition blocks in order to utilize the pattern-based rewriter, refer to the tutorial [Pattern-based Rewrite Using Rules](https://onnxscript.ai/tutorial/rewriter/rewrite_patterns.html)

### Function-based rewriting

This style of rewriting matches a `FUNCTION_KEYWORD` and `PACKAGE_NAME` provided by the user to an existing function within the graph and replaces it with a new function provided by the user.
For a detailed tutorial on how to create target_pattern, replacement_pattern and match_condition blocks in order to utilize the pattern-based rewriter, refer to the tutorial [Pattern-based Rewrite Using Rules](https://microsoft.github.io/onnxscript/tutorial/rewriter/rewrite_patterns.html)

## Development Guidelines

Expand Down
4 changes: 2 additions & 2 deletions docs/tutorial/rewriter/simple_example.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ rule = pattern.RewriteRule(
Now that the rewrite rule has been created, the next step is to apply these pattern-based rewrite rules. The `rewriter.rewrite` call consists of three main components:

1. `model` : The original model on which the pattern rewrite rules are to be applied. This is of type `onnx.ModelProto`.
2. `function_rewrite_rules` : `(Optional)` This parameter is used to pass rewrite rules based on function names. Steps on how to use this parameter will be covered in a different tutorial. This parameter is of type `Sequence[type[FunctionRewriteRule]]`
3. `pattern_rewrite_rules` : `(Optional)` This parameter is used to pass rewrite rules based on a provided replacement pattern. For the purpose of this tutorial, we will be using only this parameter in conjunction with `model`. This parameter is of either one of these types:

2. `pattern_rewrite_rules` : `(Optional)` This parameter is used to pass rewrite rules based on a provided replacement pattern. This parameter is of either one of these types:
- `Sequence[PatternRewriteRule]`
- `RewriteRuleSet`

Expand Down
7 changes: 2 additions & 5 deletions onnxscript/rewriter/onnxruntime/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from __future__ import annotations

from typing import Any
from typing import Any, Sequence

import onnx

Expand All @@ -25,15 +25,12 @@
def rewrite(
model_proto: onnx.ModelProto,
/,
function_rules=None,
pattern_rules: list[pattern.RewriteRule] | None = None,
pattern_rules: Sequence[pattern.RewriteRule] | None = None,
) -> onnx.ModelProto:
"""Rewrite the model using the given rules.

Args:
model_proto: The model to rewrite.
function_rules: The function rewrite rules to apply. If None, the default rules
for onnxruntime are used.
pattern_rules: The pattern rewrite rules to apply. If None, the default rules
for onnxruntime are used.

Expand Down
9 changes: 4 additions & 5 deletions tools/ort_rewriter_profiling/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -127,14 +127,13 @@
5. Develop optimization code.
- `onnx-script/onnxscript/optimizer`: Optimizations such as constant folding, inlining, dead code elimination etc.
- `onnx-script/onnxscript/rewriter`: Pattern based fusions.
- `onnx-script/onnxscript/rewriter/onnxruntime`: Onnxruntime specific pattern based fusions.
- `onnx-script/onnxscript/rewriter/onnxruntime/transformers`: Onnxruntime specific function based fusions.
- `onnx-script/onnxscript/rewriter/ort_fusions`: Onnxruntime specific pattern based fusions.
- Use function unittest producer tool to create function fusion unittest. Example command to distill 4 unittests for function `LlamaSdpaAttention` from `llama_v2_7b` `dynamo` model. The unittest models are named with prefix `sdpa_llama2`:
```
# Under onnx-script/onnxscript/rewriter/transformers
CUDA_VISIBLE_DEVICES="3" python tools/function_unittest_producer.py --model-path ../../../tools/onnx_models/llama_v2_7b_16h/dynamo_ort_rewritten/llama_v2_7b_16h_dynamo_ort_rewritten.onnx --function LlamaSdpaAttention --output-dir ../../testing/rewriter/transformers/unittest_models/ --max-outputs 4 --name sdpa_llama2
# Under onnx-script/onnxscript/rewriter
CUDA_VISIBLE_DEVICES="3" python tools/function_unittest_producer.py --model-path ../../../tools/onnx_models/llama_v2_7b_16h/dynamo_ort_rewritten/llama_v2_7b_16h_dynamo_ort_rewritten.onnx --function LlamaSdpaAttention --output-dir ../../testing/rewriter/unittest_models/ --max-outputs 4 --name sdpa_llama2
```
- Create new testcase under `onnx-script/onnxscript/rewriter/transformers` with the generated unittest models.
- Create new testcase under `onnx-script/onnxscript/rewriter/ort_fusions` with the generated unittest models.
```python
def test_sdpa_llama2(self):
common.test_function_rewrite("sdpa_llama2", 4)
Expand Down
Loading