Skip to content

Commit 510fc28

Browse files
authored
Add support for a non-backtracking version of pattern disjunction (#2242)
Several fusions need to support multiple variants of a pattern (for example, one in which an op such as Add is optionally present). This PR adds support for a non-backtracking version of pattern disjunction: we can now use an "Or" between variants such as "Add(...)" and "MatMul(...)". Supporting unrestricted Or patterns is more complicated, because the failure of one alternative requires backtracking, which in turn requires unbinding any bindings added during the unsuccessful partial search. (We can consider that later, if it seems useful.)
1 parent b63ba43 commit 510fc28

File tree

3 files changed

+131
-8
lines changed

3 files changed

+131
-8
lines changed

docs/api/rewriter_pattern.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
rewriter.pattern.NodeOutputPattern
2626
rewriter.pattern.AnyValue
2727
rewriter.pattern.Constant
28+
rewriter.pattern.OrValue
2829
rewriter.pattern.GraphPattern
2930
rewriter.pattern.ReplacementSubgraph
3031
rewriter.pattern.ReplacementPatternFunction

onnxscript/rewriter/pattern.py

Lines changed: 97 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
MutableSequence,
1919
Protocol,
2020
Sequence,
21-
Tuple,
2221
TypeVar,
2322
Union,
2423
)
@@ -511,7 +510,7 @@ def __init__(
511510
if isinstance(op, str) and isinstance(domain, StringConstantPattern):
512511
# TODO(rama): support overloaded operators.
513512
overload = ""
514-
self._op_identifier: tuple[str, str, str] | None = (
513+
self._op_identifier: ir.OperatorIdentifier | None = (
515514
domain.value(),
516515
op,
517516
overload,
@@ -535,7 +534,7 @@ def __str__(self) -> str:
535534
inputs_and_attributes = f"{inputs}, {attributes}" if attributes else inputs
536535
return f"{outputs} = {qualified_op} ({inputs_and_attributes})"
537536

538-
def op_identifier(self) -> Tuple[str, str, str] | None:
537+
def op_identifier(self) -> ir.OperatorIdentifier | None:
539538
return self._op_identifier
540539

541540
@property
@@ -629,11 +628,6 @@ def producer(self) -> NodePattern:
629628
Var = ValuePattern
630629

631630

632-
def _is_pattern_variable(x: Any) -> bool:
633-
# The derived classes of ValuePattern represent constant patterns and node-output patterns.
634-
return type(x) is ValuePattern
635-
636-
637631
class AnyValue(ValuePattern):
638632
"""Represents a pattern that matches against any value."""
639633

@@ -718,6 +712,92 @@ def __str__(self) -> str:
718712
return str(self._value)
719713

720714

715+
class OrValue(ValuePattern):
    """Represents a (restricted) form of value pattern disjunction.

    The alternative to try for a given value is selected from the operator
    identifier of the value's producer (see ``get_pattern``), so exactly one
    alternative is ever tried per value.
    """

    def __init__(
        self,
        values: Sequence[ValuePattern],
        name: str | None = None,
        tag_var: str | None = None,
        tag_values: Sequence[Any] | None = None,
    ) -> None:
        """
        Initialize an OrValue pattern.

        Args:
            values: A sequence of value patterns to match against.
                Must contain at least two alternatives. All value patterns except the last one
                must have a unique producer id. This allows the pattern-matching to be deterministic,
                without the need for backtracking.
            name: An optional variable name for the pattern. Defaults to None. If present,
                this name will be bound to the value matched by the pattern.
            tag_var: An optional variable name for the tag. Defaults to None. If present,
                it will be bound to a value (from tag_values) indicating which alternative was matched.
            tag_values: An optional sequence of values to bind to the tag_var. Defaults to None.
                If present, the length of tag_values must match the number of alternatives in values.
                In a successful match, tag_var will be bound to the i-th value in tag_values if the i-th
                alternative pattern matched. If omitted, the default value of (0, 1, 2, ...) will be used.

        Raises:
            ValueError: If fewer than two alternatives are given, if tag_values is given
                without tag_var, if tag_values and values differ in length, if a
                non-final alternative's producer has no op identifier, or if two
                non-final alternatives share the same op identifier.
            TypeError: If a non-final alternative is not a NodeOutputPattern.
        """
        super().__init__(name)
        if len(values) < 2:
            raise ValueError("OrValue must have at least two alternatives.")
        if tag_values is not None:
            if tag_var is None:
                raise ValueError("tag_var must be specified if tag_values is provided.")
            if len(tag_values) != len(values):
                raise ValueError(
                    "tag_values must have the same length as the number of alternatives."
                )
        else:
            tag_values = tuple(range(len(values)))
        self._tag_var = tag_var
        # Snapshot both sequences: the dispatch table below is derived from them, so a
        # caller mutating its own list afterwards must not desynchronize this pattern.
        self._tag_values = tuple(tag_values)
        self._values = tuple(values)

        # Every alternative except the last is keyed by its producer's op identifier.
        mapping: dict[ir.OperatorIdentifier, tuple[Any, NodeOutputPattern]] = {}
        for tag, alternative in zip(self._tag_values, self._values[:-1]):
            if not isinstance(alternative, NodeOutputPattern):
                raise TypeError(
                    f"Invalid type {type(alternative)} for OrValue. Expected NodeOutputPattern."
                )
            producer = alternative.producer()
            op_id = producer.op_identifier()
            if op_id is None:
                raise ValueError(
                    f"Invalid producer {producer} for OrValue. Expected a NodePattern with op identifier."
                )
            if op_id in mapping:
                raise ValueError(
                    f"Invalid producer {producer} for OrValue. Expected a unique producer id for each alternative."
                )
            mapping[op_id] = (tag, alternative)
        self._op_to_pattern = mapping
        # The last alternative is the fallback used when no keyed alternative applies.
        self._default_pattern = (self._tag_values[-1], self._values[-1])

    @property
    def tag_var(self) -> str | None:
        """Returns the tag variable associated with the OrValue pattern."""
        return self._tag_var

    def clone(self, node_map: dict[NodePattern, NodePattern]) -> OrValue:
        """Returns a copy of this pattern whose alternatives are cloned using node_map."""
        # Without a tag_var, the stored tag values are only the defaults filled in by
        # __init__; forwarding them would trip the "tag_var must be specified" check,
        # so pass None and let the constructor regenerate the same defaults.
        tag_values = self._tag_values if self._tag_var is not None else None
        return OrValue(
            [v.clone(node_map) for v in self._values],
            self.name,
            self._tag_var,
            tag_values,
        )

    def get_pattern(self, value: ir.Value) -> tuple[Any, ValuePattern]:
        """Returns the (tag value, pattern) pair that should be tried for the given value.

        The pair registered for the op identifier of the value's producer is returned
        when there is one; otherwise the pair for the last alternative is returned.
        """
        producer = value.producer()
        if producer is not None:
            op_id = producer.op_identifier()
            if op_id is not None and op_id in self._op_to_pattern:
                return self._op_to_pattern[op_id]
        return self._default_pattern
800+
721801
def _nodes_in_pattern(outputs: Sequence[ValuePattern]) -> list[NodePattern]:
722802
"""Returns all nodes used in a pattern, given the outputs of the pattern."""
723803
node_patterns: list[NodePattern] = []
@@ -1136,6 +1216,15 @@ def _match_value(self, pattern_value: ValuePattern, value: ir.Value | None) -> b
11361216
if value is None:
11371217
return self.fail("Mismatch: Constant pattern does not match None.")
11381218
return self._match_constant(pattern_value, value)
1219+
if isinstance(pattern_value, OrValue):
1220+
if value is None:
1221+
return self.fail("Mismatch: OrValue pattern does not match None.")
1222+
i, pattern_choice = pattern_value.get_pattern(value)
1223+
result = self._match_value(pattern_choice, value)
1224+
if result:
1225+
if pattern_value.tag_var is not None:
1226+
self._match.bind(pattern_value.tag_var, i)
1227+
return result
11391228
return True
11401229

11411230
def _match_node_output(self, pattern_value: NodeOutputPattern, value: ir.Value) -> bool:

onnxscript/rewriter/pattern_test.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -688,6 +688,39 @@ def test_model(x: FLOAT[1024], y: FLOAT[1024]) -> FLOAT[1024]:
688688
self.assertEqual(len(model.graph), 2)
689689
self.assertEqual([x.op_type for x in model.graph], ["Constant", "Identity"])
690690

691+
def test_or_pattern(self):
    """Checks that a rule built on an OrValue rewrites both pattern variants.

    The source pattern accepts Relu applied to either MatMul(x, y) or
    Add(MatMul(x, y), bias); the tag variable ``has_bias`` tells the
    replacement which of the two alternatives was matched.
    """

    def source_pattern(op, x, y, bias):
        matmul_out = op.MatMul(x, y)
        biased_out = op.Add(matmul_out, bias)
        either_out = pattern.OrValue(
            [matmul_out, biased_out], tag_var="has_bias", tag_values=[False, True]
        )
        return op.Relu(either_out)

    def replacement(op, x, y, bias, has_bias):
        # has_bias is True only when the Add alternative matched.
        if not has_bias:
            return op.WithoutBias(x, y)
        return op.WithBias(x, y, bias)

    rule = pattern.RewriteRule(source_pattern, replacement)

    def op_types_after_rewrite(script_function):
        # Round-trip the script through a ModelProto into IR, apply the rule,
        # and report the op types left in the graph.
        proto = script_function.to_model_proto()
        rewritten = ir.serde.deserialize_model(proto)
        rule.apply_to_model(rewritten)
        return [node.op_type for node in rewritten.graph]

    @script()
    def test_model1(x: FLOAT[16, 32], y: FLOAT[32, 16]) -> FLOAT[16, 16]:
        return op.Relu(op.MatMul(x, y))

    self.assertEqual(op_types_after_rewrite(test_model1), ["WithoutBias"])

    @script()
    def test_model2(x: FLOAT[16, 32], y: FLOAT[32, 16], bias: FLOAT[16]) -> FLOAT[16, 16]:
        return op.Relu(op.Add(op.MatMul(x, y), bias))

    self.assertEqual(op_types_after_rewrite(test_model2), ["WithBias"])
723+
691724

692725
class PatternBuilderTest(unittest.TestCase):
693726
def test_pattern_builder_context(self):

0 commit comments

Comments
 (0)