Skip to content

Commit 88a6f75

Browse files
authored
Split pattern.py (#2296)
Splitting pattern.py into multiple files, each more focused: * _basics.py * _pattern_ir.py: the IR for pattern-graphs * _rewrite_rule.py: Rewrite Rules * _matcher.py: the pattern-matching algorithm There is more cleanup to be done in each part, but keeping this PR simple, focused only on the split-up, to avoid any major merge-issues.
1 parent db3dc8c commit 88a6f75

File tree

5 files changed

+2248
-2139
lines changed

5 files changed

+2248
-2139
lines changed

onnxscript/rewriter/_basics.py

Lines changed: 349 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,349 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
"""Basic types for the pattern matching and rewriter API."""
4+
5+
from __future__ import annotations
6+
7+
import dataclasses
8+
import enum
9+
from collections import defaultdict
10+
from typing import TYPE_CHECKING, Any, MutableSequence, Sequence, Union
11+
12+
from onnxscript import ir
13+
14+
if TYPE_CHECKING:
15+
import onnxscript.rewriter._pattern_ir as _pattern_ir
16+
import onnxscript.rewriter._rewrite_rule as _rewrite_rule
17+
18+
19+
class MatchResult:
20+
"""The state object used by the pattern-matching algorithm.
21+
22+
A match can either succeed or fail.
23+
If it succeeds, it returns a list of nodes that matched the pattern
24+
and a set of bindings for the variables in the pattern.
25+
26+
Example:
27+
::
28+
def pattern(x, shape1, shape2):
29+
t1 = op.Reshape(x, shape1)
30+
t2 = op.Reshape(t1, shape2)
31+
return t2
32+
The above pattern matches a sequence of two Reshape ops.
33+
The matched_nodes will contain the two Reshape ops, and the bindings will
34+
contain the values that are bound to the variables `x`, `shape1`, and `shape2`.
35+
"""
36+
37+
def __init__(self) -> None:
38+
# We use a stack of partial matches to handle OR patterns that require backtracking.
39+
self._partial_matches: list[PartialMatchResult] = [PartialMatchResult()]
40+
41+
@property
42+
def _current_match(self) -> PartialMatchResult:
43+
"""Returns the current match result."""
44+
return self._partial_matches[-1]
45+
46+
def enter_new_match(self) -> None:
47+
"""Starts a new sub-match to try out one of multiple alternatives."""
48+
match = PartialMatchResult()
49+
self._partial_matches.append(match)
50+
51+
def abandon_current_match(self) -> PartialMatchResult:
52+
"""Abandons the current alternative due to failure."""
53+
if len(self._partial_matches) < 2:
54+
raise ValueError("No match to abandon.")
55+
return self._partial_matches.pop()
56+
57+
def merge_current_match(self) -> None:
58+
"""Merges a successful sub-match for an alternative with the parent one."""
59+
if len(self._partial_matches) < 2:
60+
raise ValueError("No match to merge.")
61+
current_match = self._partial_matches.pop()
62+
previous_match = self._partial_matches[-1]
63+
if not current_match:
64+
raise ValueError("Current match is not successful.")
65+
# Merge the two matches.
66+
previous_match.merge(current_match)
67+
68+
def __bool__(self) -> bool:
69+
"""Returns True if the current match is successful."""
70+
return bool(self._current_match)
71+
72+
def fail(
73+
self,
74+
reason: str = "",
75+
failure_source: Union[ir.Node, ir.Value, list[Union[ir.Node, ir.Value]]] | None = None,
76+
) -> MatchResult:
77+
self._current_match.fail(reason, failure_source)
78+
return self
79+
80+
@property
81+
def reason(self) -> str:
82+
"""Returns the reason for the failure."""
83+
return self._current_match.reason
84+
85+
@property
86+
def nodes(self) -> Sequence[ir.Node]:
87+
"""Returns the list of nodes that matched the pattern."""
88+
return self._current_match.nodes
89+
90+
def bind_node(self, pattern_node: _pattern_ir.NodePattern, node: ir.Node):
91+
"""Binds a pattern node to a matched node."""
92+
self.add_node(node)
93+
self._current_match.node_bindings[pattern_node] = node
94+
95+
def add_node(self, node: ir.Node) -> None:
96+
"""Adds a node to the list of matched nodes."""
97+
self._current_match.add_node(node)
98+
99+
def bind_value(self, pattern_value: _pattern_ir.ValuePattern, value: Any) -> bool:
100+
var_name = pattern_value.name
101+
# TODO(rama): Simplify the following. We currently bind values to
102+
# pattern variables in two different ways: via their name, or via the
103+
# pattern-value itself.
104+
if var_name is None:
105+
for match in self._partial_matches:
106+
if pattern_value in match.value_bindings:
107+
# TODO(rama): Use appropriate equality-check here.
108+
if match.value_bindings[pattern_value] == value:
109+
return True
110+
self._current_match.fail(
111+
f"Binding failure: {pattern_value} bound to two different values.",
112+
[match.value_bindings[pattern_value], value],
113+
)
114+
return False
115+
self._current_match.value_bindings[pattern_value] = value
116+
return True
117+
return self.bind(var_name, value)
118+
119+
def bind(self, var: str, value: Any) -> bool:
120+
for match in self._partial_matches:
121+
if var in match.bindings:
122+
# TODO(rama): Use appropriate equality-check here.
123+
if match.bindings[var] == value:
124+
return True
125+
self._current_match.fail(
126+
f"Binding failure: {var} bound to two different values.",
127+
[match.bindings[var], value],
128+
)
129+
return False
130+
self._current_match.bindings[var] = value
131+
return True
132+
133+
@property
134+
def bindings(self) -> dict[str, Any]:
135+
"""Returns the bindings for the pattern variables."""
136+
if len(self._partial_matches) > 1:
137+
raise ValueError("Bindings can be accessed only at the top-level match.")
138+
return self._current_match.bindings
139+
140+
@property
141+
def value_bindings(self) -> dict[_pattern_ir.ValuePattern, ir.Value]:
142+
"""Returns the bindings for the value variables."""
143+
if len(self._partial_matches) > 1:
144+
raise ValueError("Value bindings can be accessed only at the top-level match.")
145+
return self._current_match.value_bindings
146+
147+
@property
148+
def outputs(self) -> MutableSequence[ir.Value]:
149+
"""Returns the list of output values that matched the pattern."""
150+
if len(self._partial_matches) > 1:
151+
raise ValueError("Outputs can be accessed only at the top-level match.")
152+
return self._current_match.outputs
153+
154+
@property
155+
def failure_nodes_and_values(self) -> list[Union[ir.Node, ir.Value]]:
156+
"""Returns the nodes and values that caused the failure."""
157+
return self._current_match._failure_nodes_and_values
158+
159+
def lookup_node(self, pattern_node: _pattern_ir.NodePattern) -> ir.Node | None:
160+
"""Looks up the node that matched the given pattern node."""
161+
for match in self._partial_matches:
162+
if pattern_node in match.node_bindings:
163+
return match.node_bindings[pattern_node]
164+
return None
165+
166+
def num_matched_nodes(self) -> int:
167+
"""Returns the number of nodes matched so far."""
168+
return sum(len(match.node_bindings) for match in self._partial_matches)
169+
170+
171+
class PartialMatchResult:
172+
"""The state object used by the pattern-matching algorithm for a sub-match."""
173+
174+
def __init__(self) -> None:
175+
self._success: bool = True
176+
# For a successful match, _matched_nodes is a list of values that matched the pattern.
177+
# These include the internal nodes of the pattern that were matched, but not
178+
# the leaves (sub-trees) that match against the variables in the pattern.
179+
# These represent the values that will be replaced by the replacement pattern.
180+
self._matched_nodes: MutableSequence[ir.Node] = []
181+
# For a successful match, bindings is a dictionary of mapping pattern-variable-names
182+
# to values.
183+
self._bindings: dict[str, Any] = {}
184+
self._value_bindings: dict[_pattern_ir.ValuePattern, ir.Value] = {}
185+
self._node_bindings: dict[_pattern_ir.NodePattern, ir.Node] = {}
186+
187+
self._outputs: list[ir.Value] = []
188+
# For a failed match, _reason is a string that describes the reason for the failure.
189+
self._reason: str = ""
190+
# Track the node(s) or value(s) that caused the failure.
191+
self._failure_nodes_and_values: list[Union[ir.Node, ir.Value]] = []
192+
193+
def __bool__(self):
194+
return self._success
195+
196+
def fail(
197+
self,
198+
reason: str = "",
199+
failure_source: Union[ir.Node, ir.Value, list[Union[ir.Node, ir.Value]]] | None = None,
200+
) -> None:
201+
self._success = False
202+
self._reason = reason
203+
if failure_source is not None:
204+
if isinstance(failure_source, list):
205+
self._failure_nodes_and_values.extend(failure_source)
206+
else:
207+
self._failure_nodes_and_values.append(failure_source)
208+
209+
@property
210+
def reason(self) -> str:
211+
return self._reason
212+
213+
@property
214+
def nodes(self) -> Sequence[ir.Node]:
215+
return tuple(self._matched_nodes)
216+
217+
def add_node(self, node: ir.Node) -> None:
218+
"""Adds a node to the list of matched nodes."""
219+
self._matched_nodes.append(node)
220+
221+
@property
222+
def bindings(self) -> dict[str, Any]:
223+
return self._bindings
224+
225+
@property
226+
def value_bindings(self) -> dict[_pattern_ir.ValuePattern, ir.Value]:
227+
return self._value_bindings
228+
229+
@property
230+
def outputs(self) -> MutableSequence[ir.Value]:
231+
return self._outputs
232+
233+
@property
234+
def node_bindings(self) -> dict[_pattern_ir.NodePattern, ir.Node]:
235+
return self._node_bindings
236+
237+
def merge(self, other: PartialMatchResult) -> None:
238+
"""Merges a successful sub-match for an alternative with the parent one."""
239+
if self._success and other._success:
240+
# Merge the two successful matches. Matching algorithm responsible for ensuring
241+
# that the two matches are compatible. No need to check for conflicts here.
242+
self._bindings.update(other._bindings)
243+
self._matched_nodes.extend(other.nodes)
244+
# Note: outputs should be set only at end of the (top-level) match. There
245+
# should be no outputs in the sub-match.
246+
assert not other._outputs
247+
else:
248+
# This should not happen currently.
249+
raise NotImplementedError("Merging failed matches is not yet supported.")
250+
251+
252+
class MatchStatus(enum.IntEnum):
253+
"""The status of a pattern-matching operation."""
254+
255+
NO_MATCH = 0 # No successful match found for entire pattern graph
256+
CONDITION_FAILED = 1 # Subsequent validation check failed
257+
REPLACEMENT_FAILED = 2 # Replacement subgraph could not be created
258+
SUCCESS = 3 # A successful match was found
259+
260+
261+
@dataclasses.dataclass
262+
class MatchInfo:
263+
"""The status of a pattern-matching operation. An extension of MatchResult."""
264+
265+
match_result: MatchResult
266+
root_node: ir.Node
267+
container: ir.Graph | ir.Function
268+
status: MatchStatus
269+
270+
def score(self) -> int:
271+
"""Return a score for the match."""
272+
return len(self.match_result.nodes) + int(self.status.value) * 100
273+
274+
def print(self):
275+
separator = "-" * 80
276+
print(separator)
277+
print(f"Status: {self.status.name}")
278+
if self.status != MatchStatus.SUCCESS:
279+
reason = self.match_result.reason
280+
if reason:
281+
if self.status == MatchStatus.CONDITION_FAILED:
282+
print(f"Graph matching failed due to failing check condition : {reason}")
283+
else:
284+
print(f"Graph matching failed: {reason}")
285+
else:
286+
print("Graph matching failed.")
287+
failure_nodes_and_values = self.match_result.failure_nodes_and_values
288+
print("Failure at or around nodes/values:")
289+
if failure_nodes_and_values:
290+
for failure_cause in failure_nodes_and_values:
291+
failure_cause.display()
292+
print("Matched nodes:")
293+
import onnxscript.rewriter._ir_utils as ir_utils
294+
295+
ir_utils.display_nodes(self.match_result.nodes)
296+
print(separator)
297+
298+
299+
class MatchingTracer:
300+
"""A debugging helper class to trace the matching of a pattern against a graph.
301+
302+
This is used to track the best matches found for each rule, and to report the
303+
results at the end of the matching.
304+
"""
305+
306+
def __init__(self) -> None:
307+
self._best_matches_map: dict[_rewrite_rule.RewriteRule, list[MatchInfo]] = defaultdict(
308+
list
309+
)
310+
311+
@property
312+
def best_matches_map(self) -> dict[_rewrite_rule.RewriteRule, list[MatchInfo]]:
313+
return self._best_matches_map
314+
315+
def log(
316+
self,
317+
rule: _rewrite_rule.RewriteRule,
318+
container: ir.Graph | ir.Function,
319+
node: ir.Node,
320+
match_result: MatchResult,
321+
status: MatchStatus,
322+
) -> None:
323+
this_match = MatchInfo(match_result, node, container, status)
324+
this_score = this_match.score()
325+
if this_score == 0:
326+
return
327+
best_matches = self._best_matches_map[rule]
328+
if best_matches:
329+
if this_score < best_matches[0].score():
330+
return
331+
if this_score > best_matches[0].score():
332+
best_matches.clear()
333+
best_matches.append(this_match)
334+
335+
def report(self) -> None:
336+
best_score = 0
337+
for rule, matches in self._best_matches_map.items():
338+
if not matches:
339+
continue
340+
if matches[0].score() > best_score:
341+
best_score = matches[0].score()
342+
best_match = matches[0]
343+
best_rule = rule
344+
345+
if best_score > 0:
346+
print(f"Rule: {best_rule}")
347+
best_match.print()
348+
else:
349+
print("No matches found.")

0 commit comments

Comments
 (0)