|
| 1 | +# Copyright (c) Microsoft Corporation. |
| 2 | +# Licensed under the MIT License. |
| 3 | +"""Basic types for the pattern matching and rewriter API.""" |
| 4 | + |
| 5 | +from __future__ import annotations |
| 6 | + |
| 7 | +import dataclasses |
| 8 | +import enum |
| 9 | +from collections import defaultdict |
| 10 | +from typing import TYPE_CHECKING, Any, MutableSequence, Sequence, Union |
| 11 | + |
| 12 | +from onnxscript import ir |
| 13 | + |
| 14 | +if TYPE_CHECKING: |
| 15 | + import onnxscript.rewriter._pattern_ir as _pattern_ir |
| 16 | + import onnxscript.rewriter._rewrite_rule as _rewrite_rule |
| 17 | + |
| 18 | + |
class MatchResult:
    """Mutable state threaded through the pattern-matching algorithm.

    A match either succeeds or fails. On success it carries the nodes that
    matched the pattern and the bindings established for the pattern's
    variables.

    Example:
    ::
        def pattern(x, shape1, shape2):
            t1 = op.Reshape(x, shape1)
            t2 = op.Reshape(t1, shape2)
            return t2
    This pattern matches two consecutive Reshape ops. The matched nodes are
    the two Reshape ops, and the bindings hold the values bound to the
    variables `x`, `shape1`, and `shape2`.
    """

    def __init__(self) -> None:
        # OR patterns need backtracking, so partial matches are kept on a stack;
        # the bottom entry is the top-level match.
        self._partial_matches: list[PartialMatchResult] = [PartialMatchResult()]

    @property
    def _current_match(self) -> PartialMatchResult:
        """The partial match on top of the stack."""
        return self._partial_matches[-1]

    def enter_new_match(self) -> None:
        """Pushes a fresh sub-match used to try one of several alternatives."""
        self._partial_matches.append(PartialMatchResult())

    def abandon_current_match(self) -> PartialMatchResult:
        """Pops and returns the current sub-match because the alternative failed.

        Raises:
            ValueError: If only the top-level match is on the stack.
        """
        if len(self._partial_matches) < 2:
            raise ValueError("No match to abandon.")
        return self._partial_matches.pop()

    def merge_current_match(self) -> None:
        """Pops the current sub-match and merges it into its parent.

        Raises:
            ValueError: If only the top-level match is on the stack, or if the
                popped sub-match is not successful.
        """
        if len(self._partial_matches) < 2:
            raise ValueError("No match to merge.")
        # The sub-match is popped before it is validated, as in the original flow.
        finished = self._partial_matches.pop()
        parent = self._partial_matches[-1]
        if finished:
            parent.merge(finished)
            return
        raise ValueError("Current match is not successful.")

    def __bool__(self) -> bool:
        """True when the current (top-of-stack) match is successful."""
        return True if self._current_match else False

    def fail(
        self,
        reason: str = "",
        failure_source: Union[ir.Node, ir.Value, list[Union[ir.Node, ir.Value]]] | None = None,
    ) -> MatchResult:
        """Marks the current match as failed and returns this object."""
        current = self._current_match
        current.fail(reason, failure_source)
        return self

    @property
    def reason(self) -> str:
        """The failure reason recorded on the current match."""
        return self._current_match.reason

    @property
    def nodes(self) -> Sequence[ir.Node]:
        """The nodes matched so far by the current match."""
        return self._current_match.nodes

    def bind_node(self, pattern_node: _pattern_ir.NodePattern, node: ir.Node):
        """Records `node` as matched and binds it to `pattern_node`."""
        self.add_node(node)
        node_map = self._current_match.node_bindings
        node_map[pattern_node] = node

    def add_node(self, node: ir.Node) -> None:
        """Appends `node` to the current match's list of matched nodes."""
        self._current_match.add_node(node)

    def bind_value(self, pattern_value: _pattern_ir.ValuePattern, value: Any) -> bool:
        """Binds `value` to `pattern_value`, failing the match on a conflict.

        Named pattern values are bound by name via `bind`; unnamed ones are
        keyed by the pattern value itself. Returns True if the binding is new
        or consistent with an existing one, False otherwise.
        """
        # TODO(rama): Simplify the following. We currently bind values to
        # pattern variables in two different ways: via their name, or via the
        # pattern-value itself.
        name = pattern_value.name
        if name is not None:
            return self.bind(name, value)
        for partial in self._partial_matches:
            if pattern_value not in partial.value_bindings:
                continue
            # TODO(rama): Use appropriate equality-check here.
            if partial.value_bindings[pattern_value] == value:
                return True
            self._current_match.fail(
                f"Binding failure: {pattern_value} bound to two different values.",
                [partial.value_bindings[pattern_value], value],
            )
            return False
        self._current_match.value_bindings[pattern_value] = value
        return True

    def bind(self, var: str, value: Any) -> bool:
        """Binds `value` to the variable named `var`, failing on a conflict.

        Every partial match on the stack is consulted for an existing binding.
        Returns True if the binding is new or consistent, False otherwise.
        """
        for partial in self._partial_matches:
            if var not in partial.bindings:
                continue
            # TODO(rama): Use appropriate equality-check here.
            if partial.bindings[var] == value:
                return True
            self._current_match.fail(
                f"Binding failure: {var} bound to two different values.",
                [partial.bindings[var], value],
            )
            return False
        self._current_match.bindings[var] = value
        return True

    @property
    def bindings(self) -> dict[str, Any]:
        """The name-to-value bindings; only available at the top-level match."""
        if len(self._partial_matches) <= 1:
            return self._current_match.bindings
        raise ValueError("Bindings can be accessed only at the top-level match.")

    @property
    def value_bindings(self) -> dict[_pattern_ir.ValuePattern, ir.Value]:
        """The pattern-value bindings; only available at the top-level match."""
        if len(self._partial_matches) <= 1:
            return self._current_match.value_bindings
        raise ValueError("Value bindings can be accessed only at the top-level match.")

    @property
    def outputs(self) -> MutableSequence[ir.Value]:
        """The matched output values; only available at the top-level match."""
        if len(self._partial_matches) <= 1:
            return self._current_match.outputs
        raise ValueError("Outputs can be accessed only at the top-level match.")

    @property
    def failure_nodes_and_values(self) -> list[Union[ir.Node, ir.Value]]:
        """The nodes and values recorded as the cause of the failure."""
        return self._current_match._failure_nodes_and_values

    def lookup_node(self, pattern_node: _pattern_ir.NodePattern) -> ir.Node | None:
        """Returns the node bound to `pattern_node`, or None if it is unbound."""
        return next(
            (
                partial.node_bindings[pattern_node]
                for partial in self._partial_matches
                if pattern_node in partial.node_bindings
            ),
            None,
        )

    def num_matched_nodes(self) -> int:
        """Counts the node bindings held across all partial matches."""
        total = 0
        for partial in self._partial_matches:
            total += len(partial.node_bindings)
        return total
| 169 | + |
| 170 | + |
class PartialMatchResult:
    """The state object used by the pattern-matching algorithm for a sub-match.

    A partial match starts out successful and accumulates matched nodes and
    bindings until it is either failed via `fail` or merged into another
    partial match via `merge`.
    """

    def __init__(self) -> None:
        self._success: bool = True
        # For a successful match, _matched_nodes is a list of nodes that matched the
        # pattern. These include the internal nodes of the pattern that were matched,
        # but not the leaves (sub-trees) that match against the variables in the
        # pattern. These are what the replacement pattern will replace.
        self._matched_nodes: MutableSequence[ir.Node] = []
        # For a successful match, bindings maps pattern-variable names to values.
        self._bindings: dict[str, Any] = {}
        # Bindings keyed by the pattern value / pattern node object itself.
        self._value_bindings: dict[_pattern_ir.ValuePattern, ir.Value] = {}
        self._node_bindings: dict[_pattern_ir.NodePattern, ir.Node] = {}

        self._outputs: list[ir.Value] = []
        # For a failed match, _reason is a string that describes the reason for the failure.
        self._reason: str = ""
        # Track the node(s) or value(s) that caused the failure.
        self._failure_nodes_and_values: list[Union[ir.Node, ir.Value]] = []

    def __bool__(self) -> bool:
        """Returns True while the match has not been failed."""
        return self._success

    def fail(
        self,
        reason: str = "",
        failure_source: Union[ir.Node, ir.Value, list[Union[ir.Node, ir.Value]]] | None = None,
    ) -> None:
        """Marks this match as failed.

        Args:
            reason: Description of the failure; replaces any earlier reason.
            failure_source: The node or value (or a list of them) blamed for the
                failure. Sources accumulate across repeated calls.
        """
        self._success = False
        self._reason = reason
        if failure_source is not None:
            if isinstance(failure_source, list):
                self._failure_nodes_and_values.extend(failure_source)
            else:
                self._failure_nodes_and_values.append(failure_source)

    @property
    def reason(self) -> str:
        """The reason given to the most recent `fail` call ("" if none)."""
        return self._reason

    @property
    def nodes(self) -> Sequence[ir.Node]:
        """An immutable snapshot of the matched nodes."""
        return tuple(self._matched_nodes)

    def add_node(self, node: ir.Node) -> None:
        """Adds a node to the list of matched nodes."""
        self._matched_nodes.append(node)

    @property
    def bindings(self) -> dict[str, Any]:
        """The live mapping from pattern-variable names to bound values."""
        return self._bindings

    @property
    def value_bindings(self) -> dict[_pattern_ir.ValuePattern, ir.Value]:
        """The live mapping from pattern values to bound values."""
        return self._value_bindings

    @property
    def outputs(self) -> MutableSequence[ir.Value]:
        """The live list of output values of the match."""
        return self._outputs

    @property
    def node_bindings(self) -> dict[_pattern_ir.NodePattern, ir.Node]:
        """The live mapping from pattern nodes to matched nodes."""
        return self._node_bindings

    def merge(self, other: PartialMatchResult) -> None:
        """Merges a successful sub-match for an alternative with the parent one.

        All three kinds of bindings (by name, by pattern value, by pattern node)
        and the matched nodes of `other` are folded into this match, so nothing
        bound inside the alternative is lost once the sub-match is discarded.

        Raises:
            NotImplementedError: If either match has been failed.
        """
        if not (self._success and other._success):
            # This should not happen currently.
            raise NotImplementedError("Merging failed matches is not yet supported.")
        # Merge the two successful matches. Matching algorithm responsible for ensuring
        # that the two matches are compatible. No need to check for conflicts here.
        self._bindings.update(other._bindings)
        # Value and node bindings must be carried over as well; otherwise they
        # vanish together with the sub-match.
        self._value_bindings.update(other._value_bindings)
        self._node_bindings.update(other._node_bindings)
        self._matched_nodes.extend(other.nodes)
        # Note: outputs should be set only at end of the (top-level) match. There
        # should be no outputs in the sub-match.
        assert not other._outputs
| 250 | + |
| 251 | + |
class MatchStatus(enum.IntEnum):
    """Outcome of a pattern-matching operation.

    The integer values are used by `MatchInfo.score` to weight a match.
    """

    # No successful match found for the entire pattern graph.
    NO_MATCH = 0
    # The pattern matched, but a subsequent validation check failed.
    CONDITION_FAILED = 1
    # The replacement subgraph could not be created.
    REPLACEMENT_FAILED = 2
    # A successful match was found.
    SUCCESS = 3
| 259 | + |
| 260 | + |
@dataclasses.dataclass
class MatchInfo:
    """A MatchResult bundled with the status of the matching operation."""

    match_result: MatchResult
    root_node: ir.Node
    container: ir.Graph | ir.Function
    status: MatchStatus

    def score(self) -> int:
        """Returns the matched-node count plus 100 times the status value."""
        matched_count = len(self.match_result.nodes)
        status_weight = int(self.status.value) * 100
        return matched_count + status_weight

    def print(self):
        """Prints the status, failure details (if any), and the matched nodes."""
        rule_line = "-" * 80
        print(rule_line)
        print(f"Status: {self.status.name}")
        if self.status != MatchStatus.SUCCESS:
            reason = self.match_result.reason
            if not reason:
                print("Graph matching failed.")
            elif self.status == MatchStatus.CONDITION_FAILED:
                print(f"Graph matching failed due to failing check condition : {reason}")
            else:
                print(f"Graph matching failed: {reason}")
            culprits = self.match_result.failure_nodes_and_values
            print("Failure at or around nodes/values:")
            if culprits:
                for culprit in culprits:
                    culprit.display()
        print("Matched nodes:")
        # Imported only at this point, after the header has been printed.
        import onnxscript.rewriter._ir_utils as ir_utils

        ir_utils.display_nodes(self.match_result.nodes)
        print(rule_line)
| 297 | + |
| 298 | + |
class MatchingTracer:
    """Debugging aid that traces how well patterns match against a graph.

    For each rule it keeps the highest-scoring matches logged so far, and it
    can print the single best one once matching is done.
    """

    def __init__(self) -> None:
        self._best_matches_map: dict[_rewrite_rule.RewriteRule, list[MatchInfo]] = defaultdict(
            list
        )

    @property
    def best_matches_map(self) -> dict[_rewrite_rule.RewriteRule, list[MatchInfo]]:
        """Mapping from each rule to its best-scoring matches seen so far."""
        return self._best_matches_map

    def log(
        self,
        rule: _rewrite_rule.RewriteRule,
        container: ir.Graph | ir.Function,
        node: ir.Node,
        match_result: MatchResult,
        status: MatchStatus,
    ) -> None:
        """Records a match attempt if it ties or beats the rule's best score.

        Attempts scoring zero are ignored. A strictly better attempt replaces
        the stored ones; an equally good attempt is kept alongside them.
        """
        candidate = MatchInfo(match_result, node, container, status)
        candidate_score = candidate.score()
        if candidate_score == 0:
            return
        leaders = self._best_matches_map[rule]
        if leaders:
            leading_score = leaders[0].score()
            if candidate_score < leading_score:
                return
            if candidate_score > leading_score:
                leaders.clear()
        leaders.append(candidate)

    def report(self) -> None:
        """Prints the best match across all rules, or a note that none exists."""
        top_rule = None
        top_match = None
        top_score = 0
        for rule, found in self._best_matches_map.items():
            if not found:
                continue
            contender = found[0]
            contender_score = contender.score()
            # Strict comparison: on ties the earliest rule stays the winner.
            if contender_score > top_score:
                top_rule, top_match, top_score = rule, contender, contender_score

        if top_score <= 0:
            print("No matches found.")
            return
        print(f"Rule: {top_rule}")
        top_match.print()
0 commit comments