Skip to content

Commit 7ff9bde

Browse files
remove unused nodes call
1 parent c1afac1 commit 7ff9bde

File tree

3 files changed

+4
-20
lines changed

3 files changed

+4
-20
lines changed

onnxscript/rewriter/ort_fusions/cos_sin_cache.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import numpy as np
66

77
import onnxscript.ir as ir
8-
from onnxscript.optimizer import remove_unused_nodes
98
from onnxscript.rewriter import _fusion_utils, _ir_utils, pattern
109

1110
# Rewrite the computation of cos/sin cache into the form expected by ORT's custom ops.
@@ -169,9 +168,4 @@ def rewrite(
169168
cos_sin_cache_rules = pattern.RewriteRuleSet([_cast, _cast_const_freqs, _const_freqs, _basic])
170169

171170

172-
def fuse_cos_sin_cache(model: ir.Model, debug: bool = False) -> int:
173-
fuse_cos_sin_cache = _fusion_utils.apply_fusion_rules(cos_sin_cache_rules)
174-
count = fuse_cos_sin_cache(model, debug=debug)
175-
if count != 0:
176-
remove_unused_nodes(model)
177-
return count
171+
fuse_cos_sin_cache = _fusion_utils.apply_fusion_rules(cos_sin_cache_rules)

onnxscript/rewriter/ort_fusions/gqa.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22
# Licensed under the MIT License.
33
from __future__ import annotations
44

5-
import onnxscript.ir as ir
6-
from onnxscript.optimizer import remove_unused_nodes
75
from onnxscript.rewriter import _fusion_utils, pattern
86

97

@@ -150,8 +148,4 @@ def rewrite(
150148
gqa_rules = pattern.RewriteRuleSet([_rule1])
151149

152150

153-
def fuse_gqa(model: ir.Model, debug: bool = False) -> int:
154-
fuse_gqa = _fusion_utils.apply_fusion_rules(gqa_rules)
155-
count = fuse_gqa(model, debug=debug)
156-
remove_unused_nodes(model)
157-
return count
151+
fuse_gqa = _fusion_utils.apply_fusion_rules(gqa_rules)

onnxscript/rewriter/ort_fusions/skip_normalization.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,7 @@ def _skip_rms_normalization(op, input, skip, gamma, epsilon, stash_type):
3131
return normalized, skip_sum
3232

3333

34-
_skip_rms_rule = pattern.RewriteRule(
35-
_skip_rms_norm_pattern, _skip_rms_normalization, matcher=pattern.SimplePatternMatcher
36-
)
34+
_skip_rms_rule = pattern.RewriteRule(_skip_rms_norm_pattern, _skip_rms_normalization)
3735

3836
skip_rms_normalization_rules = [_skip_rms_rule]
3937
skip_rms_normalization_ruleset = pattern.RewriteRuleSet(skip_rms_normalization_rules)
@@ -67,9 +65,7 @@ def _skip_layer_normalization(op, input, skip, gamma, beta, epsilon, stash_type)
6765
return normalized
6866

6967

70-
_skip_layer_rule = pattern.RewriteRule(
71-
_skip_layer_norm_pattern, _skip_layer_normalization, matcher=pattern.SimplePatternMatcher
72-
)
68+
_skip_layer_rule = pattern.RewriteRule(_skip_layer_norm_pattern, _skip_layer_normalization)
7369

7470
skip_layer_normalization_rules = [_skip_layer_rule]
7571
skip_layer_normalization_ruleset = pattern.RewriteRuleSet(skip_layer_normalization_rules)

0 commit comments

Comments
 (0)