|
5 | 5 | from typing import Sequence, TypeVar, Union
|
6 | 6 |
|
7 | 7 | __all__ = [
|
8 |
| - # Modules |
9 | 8 | "pattern",
|
10 |
| - # Functions |
11 | 9 | "rewrite",
|
| 10 | + "RewritePass", |
12 | 11 | ]
|
13 | 12 |
|
14 | 13 | import onnx
|
15 | 14 |
|
16 | 15 | from onnxscript import ir
|
17 | 16 | from onnxscript.ir.passes.common import unused_removal
|
18 |
| -from onnxscript.rewriter import pattern |
| 17 | +from onnxscript.rewriter import ( |
| 18 | + broadcast_to_matmul, |
| 19 | + cast_constant_of_shape, |
| 20 | + collapse_slices, |
| 21 | + gemm_to_matmul_add, |
| 22 | + llama_rule_sets, |
| 23 | + no_op, |
| 24 | + pattern, |
| 25 | +) |
19 | 26 |
|
20 |
| -PatternRewriteRule = pattern.RewriteRule |
21 |
| - |
22 |
| -ModelProtoOrIr = TypeVar("ModelProtoOrIr", onnx.ModelProto, ir.Model) |
| 27 | +_ModelProtoOrIr = TypeVar("_ModelProtoOrIr", onnx.ModelProto, ir.Model) |
| 28 | +_DEFAULT_REWRITE_RULES: tuple[pattern.RewriteRule, ...] = ( |
| 29 | + *no_op.rules.rules, # TODO: merge this rule into constant folding? |
| 30 | + *broadcast_to_matmul.rules.rules, |
| 31 | + gemm_to_matmul_add.rule, # type: ignore[has-type] |
| 32 | + *cast_constant_of_shape.rules.rules, |
| 33 | + *collapse_slices.rules.rules, |
| 34 | + *llama_rule_sets.llama_p0_rule_set().rules, |
| 35 | +) |
23 | 36 |
|
24 | 37 |
|
25 | 38 | class RewritePass(ir.passes.InPlacePass):
|
26 | 39 | def __init__(
|
27 | 40 | self,
|
28 |
| - pattern_rewrite_rules: Sequence[PatternRewriteRule] | pattern.RewriteRuleSet = (), |
| 41 | + rules: Sequence[pattern.RewriteRule] | pattern.RewriteRuleSet, |
| 42 | + /, |
29 | 43 | ) -> None:
|
30 |
| - if pattern_rewrite_rules: |
31 |
| - if not isinstance(pattern_rewrite_rules, pattern.RewriteRuleSet): |
32 |
| - # Create a pattern rule-set using provided rules |
33 |
| - pattern_rewrite_rules = pattern.RewriteRuleSet(pattern_rewrite_rules) |
34 |
| - assert isinstance(pattern_rewrite_rules, pattern.RewriteRuleSet) |
35 |
| - self.pattern_rewrite_rules: pattern.RewriteRuleSet = pattern_rewrite_rules |
| 44 | + super().__init__() |
| 45 | + if isinstance(rules, Sequence): |
| 46 | + if not rules: |
| 47 | + raise ValueError("rules must not be empty") |
| 48 | + # Create a pattern rule-set using provided rules |
| 49 | + rules = pattern.RewriteRuleSet(rules) |
| 50 | + assert isinstance(rules, pattern.RewriteRuleSet) |
| 51 | + self.rules: pattern.RewriteRuleSet = rules |
36 | 52 |
|
37 | 53 | def call(self, model: ir.Model) -> ir.passes.PassResult:
|
38 |
| - count = self.pattern_rewrite_rules.apply_to_model(model) |
| 54 | + count = self.rules.apply_to_model(model) |
39 | 55 | if count:
|
40 | 56 | print(f"Applied {count} of general pattern rewrite rules.")
|
41 | 57 | return ir.passes.PassResult(model, bool(count))
|
42 | 58 |
|
43 | 59 |
|
44 | 60 | def rewrite(
|
45 |
| - model: ModelProtoOrIr, |
46 |
| - pattern_rewrite_rules: Union[Sequence[PatternRewriteRule], pattern.RewriteRuleSet] = (), |
47 |
| -) -> ModelProtoOrIr: |
| 61 | + model: _ModelProtoOrIr, |
| 62 | + pattern_rewrite_rules: Union[Sequence[pattern.RewriteRule], pattern.RewriteRuleSet] |
| 63 | + | None = None, |
| 64 | +) -> _ModelProtoOrIr: |
| 65 | + """Rewrite the model using the provided pattern rewrite rules. |
| 66 | +
|
| 67 | + Unused nodes, functions, and opsets will be removed after the rewrite. |
| 68 | +
|
| 69 | + Args: |
| 70 | + model: The model to be rewritten. Can be an ONNX ModelProto or an ir.Model. |
| 71 | + pattern_rewrite_rules: A sequence of pattern rewrite rules or a RewriteRuleSet. |
| 72 | + If not provided, default rules will be applied. If empty, no rules will be applied |
| 73 | + and the original model will be returned. |
| 74 | +
|
| 75 | + Returns: |
| 76 | + The rewritten model as the same type as the input model. |
| 77 | + """ |
| 78 | + if pattern_rewrite_rules is None: |
| 79 | + pattern_rewrite_rules = _DEFAULT_REWRITE_RULES |
| 80 | + elif not pattern_rewrite_rules: |
| 81 | + return model |
| 82 | + |
48 | 83 | if isinstance(model, onnx.ModelProto):
|
49 | 84 | model_ir = ir.serde.deserialize_model(model)
|
50 | 85 | proto = True
|
|
0 commit comments