Skip to content

Add support for a non-backtracking version of pattern disjunction #2242

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
May 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api/rewriter_pattern.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
rewriter.pattern.NodeOutputPattern
rewriter.pattern.AnyValue
rewriter.pattern.Constant
rewriter.pattern.OrValue
rewriter.pattern.GraphPattern
rewriter.pattern.ReplacementSubgraph
rewriter.pattern.ReplacementPatternFunction
Expand Down
105 changes: 97 additions & 8 deletions onnxscript/rewriter/pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
MutableSequence,
Protocol,
Sequence,
Tuple,
TypeVar,
Union,
)
Expand Down Expand Up @@ -511,7 +510,7 @@ def __init__(
if isinstance(op, str) and isinstance(domain, StringConstantPattern):
# TODO(rama): support overloaded operators.
overload = ""
self._op_identifier: tuple[str, str, str] | None = (
self._op_identifier: ir.OperatorIdentifier | None = (
domain.value(),
op,
overload,
Expand All @@ -535,7 +534,7 @@ def __str__(self) -> str:
inputs_and_attributes = f"{inputs}, {attributes}" if attributes else inputs
return f"{outputs} = {qualified_op} ({inputs_and_attributes})"

def op_identifier(self) -> Tuple[str, str, str] | None:
def op_identifier(self) -> ir.OperatorIdentifier | None:
    """Return the operator identifier recorded for this pattern, or None if there is none."""
    return self._op_identifier

@property
Expand Down Expand Up @@ -629,11 +628,6 @@ def producer(self) -> NodePattern:
Var = ValuePattern


def _is_pattern_variable(x: Any) -> bool:
    """Return True only when ``x`` is an instance of exactly ``ValuePattern``.

    The exact-type comparison (instead of ``isinstance``) is deliberate: instances
    of classes derived from ``ValuePattern`` must not be treated as pattern variables.
    """
    # The derived classes of ValuePattern represent constant patterns and node-output patterns.
    return type(x) is ValuePattern


class AnyValue(ValuePattern):
"""Represents a pattern that matches against any value."""

Expand Down Expand Up @@ -718,6 +712,92 @@ def __str__(self) -> str:
return str(self._value)


class OrValue(ValuePattern):
    """Represents a (restricted) form of value pattern disjunction.

    The alternative to try for a given value is chosen up front from the operator
    identifier of the node that produces the value, so matching never needs to
    backtrack. Every alternative except the last is registered under the operator
    identifier of its producer; the last alternative is the fallback that is used
    when no registered identifier applies.
    """

    def __init__(
        self,
        values: Sequence[ValuePattern],
        name: str | None = None,
        tag_var: str | None = None,
        tag_values: Sequence[Any] | None = None,
    ) -> None:
        """
        Initialize an OrValue pattern.

        Args:
            values: A sequence of value patterns to match against.
                Must contain at least two alternatives. All value patterns except the last one
                must be node-output patterns whose producers have a unique op identifier.
                This allows the pattern-matching to be deterministic, without the need for
                backtracking.
            name: An optional variable name for the pattern. Defaults to None. If present,
                this name will be bound to the value matched by the pattern.
            tag_var: An optional variable name for the tag. Defaults to None. If present,
                it will be bound to a value (from tag_values) indicating which alternative
                was matched.
            tag_values: An optional sequence of values to bind to the tag_var. Defaults to None.
                If present, the length of tag_values must match the number of alternatives in
                values. In a successful match, tag_var will be bound to the i-th value in
                tag_values if the i-th alternative pattern matched. If omitted, the default
                value of (0, 1, 2, ...) will be used.

        Raises:
            ValueError: If fewer than two alternatives are given, if tag_values is given
                without tag_var, if tag_values and values differ in length, if the producer
                of a non-final alternative has no op identifier, or if two non-final
                alternatives share the same op identifier.
            TypeError: If a non-final alternative is not a NodeOutputPattern.
        """
        super().__init__(name)
        if len(values) < 2:
            raise ValueError("OrValue must have at least two alternatives.")
        if tag_values is not None:
            if tag_var is None:
                raise ValueError("tag_var must be specified if tag_values is provided.")
            if len(tag_values) != len(values):
                raise ValueError(
                    "tag_values must have the same length as the number of alternatives."
                )
            tag_values = tuple(tag_values)
        else:
            tag_values = tuple(range(len(values)))
        # Snapshot both sequences: the dispatch table below is computed once, while
        # clone() rebuilds from _values/_tag_values. Keeping private immutable copies
        # guarantees they cannot drift apart if the caller later mutates its lists.
        alternatives = tuple(values)
        self._tag_var = tag_var
        self._tag_values = tag_values
        self._values = alternatives

        # Dispatch table: producer op identifier -> (tag, alternative pattern).
        mapping: dict[ir.OperatorIdentifier, tuple[Any, NodeOutputPattern]] = {}
        # zip stops at the shorter input, so the i-th tag is paired with the i-th
        # non-final alternative; the final alternative is handled separately below.
        for tag, alternative in zip(tag_values, alternatives[:-1]):
            if not isinstance(alternative, NodeOutputPattern):
                raise TypeError(
                    f"Invalid type {type(alternative)} for OrValue. Expected NodeOutputPattern."
                )
            producer = alternative.producer()
            op_id = producer.op_identifier()
            if op_id is None:
                raise ValueError(
                    f"Invalid producer {producer} for OrValue. Expected a NodePattern with op identifier."
                )
            if op_id in mapping:
                raise ValueError(
                    f"Invalid producer {producer} for OrValue. Expected a unique producer id for each alternative."
                )
            mapping[op_id] = (tag, alternative)
        self._op_to_pattern = mapping
        # The last alternative is unconstrained and serves as the fallback choice.
        self._default_pattern = (tag_values[-1], alternatives[-1])

    @property
    def tag_var(self) -> str | None:
        """Returns the tag variable associated with the OrValue pattern."""
        return self._tag_var

    def clone(self, node_map: dict[NodePattern, NodePattern]) -> OrValue:
        """Returns a new OrValue built from clones of all alternatives.

        The name, tag variable and tag values are carried over unchanged.
        """
        return OrValue(
            [v.clone(node_map) for v in self._values],
            self.name,
            self._tag_var,
            self._tag_values,
        )

    def get_pattern(self, value: ir.Value) -> tuple[Any, ValuePattern]:
        """Returns the pattern that should be tried for the given value.

        The result is a ``(tag, pattern)`` pair. If the value has a producer whose op
        identifier was registered by one of the non-final alternatives, that
        alternative is returned; otherwise the final (default) alternative is returned.
        """
        producer = value.producer()
        if producer is not None:
            op_id = producer.op_identifier()
            if op_id is not None:
                return self._op_to_pattern.get(op_id, self._default_pattern)
        return self._default_pattern


def _nodes_in_pattern(outputs: Sequence[ValuePattern]) -> list[NodePattern]:
"""Returns all nodes used in a pattern, given the outputs of the pattern."""
node_patterns: list[NodePattern] = []
Expand Down Expand Up @@ -1136,6 +1216,15 @@ def _match_value(self, pattern_value: ValuePattern, value: ir.Value | None) -> b
if value is None:
return self.fail("Mismatch: Constant pattern does not match None.")
return self._match_constant(pattern_value, value)
if isinstance(pattern_value, OrValue):
if value is None:
return self.fail("Mismatch: OrValue pattern does not match None.")
i, pattern_choice = pattern_value.get_pattern(value)
result = self._match_value(pattern_choice, value)
if result:
if pattern_value.tag_var is not None:
self._match.bind(pattern_value.tag_var, i)
return result
return True

def _match_node_output(self, pattern_value: NodeOutputPattern, value: ir.Value) -> bool:
Expand Down
33 changes: 33 additions & 0 deletions onnxscript/rewriter/pattern_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -688,6 +688,39 @@ def test_model(x: FLOAT[1024], y: FLOAT[1024]) -> FLOAT[1024]:
self.assertEqual(len(model.graph), 2)
self.assertEqual([x.op_type for x in model.graph], ["Constant", "Identity"])

def test_or_pattern(self):
    """OrValue selects the matching alternative and exposes the choice via its tag variable."""

    def source_pattern(op, x, y, bias):
        # Relu fed by either a bare MatMul or a MatMul followed by an Add.
        matmul_out = op.MatMul(x, y)
        biased_out = op.Add(matmul_out, bias)
        relu_input = pattern.OrValue(
            [matmul_out, biased_out],
            tag_var="has_bias",
            tag_values=[False, True],
        )
        return op.Relu(relu_input)

    def replacement(op, x, y, bias, has_bias):
        # has_bias is the tag bound by the OrValue in source_pattern.
        return op.WithBias(x, y, bias) if has_bias else op.WithoutBias(x, y)

    rule = pattern.RewriteRule(source_pattern, replacement)

    def op_types_after_rewrite(script_function):
        # Round-trip the scripted function through the IR, apply the rule, and
        # report the op types left in the graph.
        rewritten = ir.serde.deserialize_model(script_function.to_model_proto())
        rule.apply_to_model(rewritten)
        return [node.op_type for node in rewritten.graph]

    @script()
    def test_model1(x: FLOAT[16, 32], y: FLOAT[32, 16]) -> FLOAT[16, 16]:
        return op.Relu(op.MatMul(x, y))

    self.assertEqual(op_types_after_rewrite(test_model1), ["WithoutBias"])

    @script()
    def test_model2(x: FLOAT[16, 32], y: FLOAT[32, 16], bias: FLOAT[16]) -> FLOAT[16, 16]:
        return op.Relu(op.Add(op.MatMul(x, y), bias))

    self.assertEqual(op_types_after_rewrite(test_model2), ["WithBias"])


class PatternBuilderTest(unittest.TestCase):
def test_pattern_builder_context(self):
Expand Down
Loading