Skip to content

Commit ed82c3b

Browse files
authored
Squeeze Reshape Identity optimization (#2083)
A recent fix to the translation of PyTorch SymInts introduces a Squeeze=>Reshape pattern that can be optimized away. This PR introduces a rewrite rule to perform this optimization. TODO (in a separate PR): for now, this optimization needs to be explicitly invoked; it should instead be applied by default. (But there are several other such optimizations that need to be collected and included in the default-rule list.)
1 parent 89dd454 commit ed82c3b

File tree

3 files changed

+63
-0
lines changed

3 files changed

+63
-0
lines changed

onnxscript/rewriter/llama_rule_sets.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,26 @@
1212
import onnxscript.rewriter.pattern as orp
1313

1414

15+
class SqueezeReshape(orp.RewriteRuleClassBase):
16+
"""Replaces ``Reshape(Squeeze(x), [-1])`` with ``Identity(x)`` for 1D x.
17+
18+
This pattern arises from the translation of pytorch symints.
19+
"""
20+
21+
def __init__(self):
22+
super().__init__("SqueezeReshape1d", remove_nodes=False)
23+
24+
def pattern(self, op, x):
25+
return op.Reshape(op.Squeeze(x), [-1])
26+
27+
def rewrite(self, op, x: ir.Value):
28+
return op.Identity(x)
29+
30+
def check(self, context, x) -> bool:
31+
del context # Unused
32+
return ir_utils.has_rank(x, 1)
33+
34+
1535
class CastIdentity(orp.RewriteRuleAsClass):
1636
"""Replaces ``Cast(., to=to)`` by ``Identity`` if possible."""
1737

@@ -259,6 +279,7 @@ def check(cls, context, x, axes1, axes2) -> bool:
259279
transpose_identity_rule = orp.make_rewrite_rule_from_class(TransposeIdentity)
260280
transpose_transpose_rule = orp.make_rewrite_rule_from_class(TransposeTranspose)
261281
unsqueeze_unsqueeze_rule = orp.make_rewrite_rule_from_class(UnsqueezeUnsqueeze)
282+
squeeze_reshape_1d_rule = SqueezeReshape.rule()
262283

263284

264285
def llama_p0_rule_set() -> orp.RewriteRuleSet:

onnxscript/rewriter/llama_rule_sets_test.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -452,6 +452,43 @@ def test_llama_p0_rule_set_slice_split(self):
452452
self.assertEqual(["Split"], [n.op_type for n in rewritten_model.graph.node])
453453
self._check_model(model_proto, rewritten_model)
454454

455+
def test_squeeze_reshape_1d_test(self):
456+
rule = llama_rule_sets.squeeze_reshape_1d_rule
457+
458+
def check(model_script, expected_count) -> None:
459+
model_proto = model_script.to_model_proto()
460+
ir_model = ir.serde.deserialize_model(model_proto)
461+
count = rule.apply_to_model(ir_model)
462+
self.assertEqual(count, expected_count)
463+
if count > 0:
464+
self.assertEqual([x.op_type for x in ir_model.graph], ["Identity"])
465+
rewritten_proto = ir.serde.serialize_model(ir_model)
466+
self._check_model(model_proto, rewritten_proto)
467+
468+
op = onnxscript.opset17
469+
470+
# input of shape [12]
471+
@onnxscript.script()
472+
def model1(X: ot.FLOAT[12]):
473+
return op.Reshape(op.Squeeze(X), [-1])
474+
475+
check(model1, 1)
476+
477+
# input of shape [1]
478+
@onnxscript.script()
479+
def model2(X: ot.FLOAT[1]):
480+
return op.Reshape(op.Squeeze(X), [-1])
481+
482+
check(model2, 1)
483+
484+
# input of shape [1, 1]
485+
# This should NOT be optimized to Identity
486+
@onnxscript.script()
487+
def model3(X: ot.FLOAT[1, 1]):
488+
return op.Reshape(op.Squeeze(X), [-1])
489+
490+
check(model3, 0)
491+
455492

456493
if __name__ == "__main__":
457494
unittest.main(verbosity=2)

onnxscript/rewriter/pattern.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1627,6 +1627,9 @@ def __init__(self, rules: Sequence[RewriteRule], *, commute: bool = False) -> No
16271627
if commute:
16281628
rules = list(itertools.chain.from_iterable([rule.commute() for rule in rules]))
16291629
self.rules = rules
1630+
# We call remove_unused_nodes at end of rewriting if there is any rule that does
1631+
# NOT remove nodes (immediately when it is applied)
1632+
self.remove_unused_nodes = any(not rule.remove_nodes for rule in rules)
16301633

16311634
def _apply_to_graph_or_function(
16321635
self,
@@ -1759,6 +1762,8 @@ def apply_to_model(
17591762
count += self._apply_to_graph_or_function(
17601763
model, function, verbose=verbose, tracer=tracer
17611764
)
1765+
if self.remove_unused_nodes:
1766+
onnxscript.optimizer.remove_unused_nodes(model)
17621767
if tracer:
17631768
tracer.report()
17641769
return count

0 commit comments

Comments
 (0)