18
18
MutableSequence ,
19
19
Protocol ,
20
20
Sequence ,
21
- Tuple ,
22
21
TypeVar ,
23
22
Union ,
24
23
)
@@ -511,7 +510,7 @@ def __init__(
511
510
if isinstance (op , str ) and isinstance (domain , StringConstantPattern ):
512
511
# TODO(rama): support overloaded operators.
513
512
overload = ""
514
- self ._op_identifier : tuple [ str , str , str ] | None = (
513
+ self ._op_identifier : ir . OperatorIdentifier | None = (
515
514
domain .value (),
516
515
op ,
517
516
overload ,
@@ -535,7 +534,7 @@ def __str__(self) -> str:
535
534
inputs_and_attributes = f"{ inputs } , { attributes } " if attributes else inputs
536
535
return f"{ outputs } = { qualified_op } ({ inputs_and_attributes } )"
537
536
538
- def op_identifier (self ) -> Tuple [ str , str , str ] | None :
537
+ def op_identifier (self ) -> ir . OperatorIdentifier | None :
539
538
return self ._op_identifier
540
539
541
540
@property
@@ -629,11 +628,6 @@ def producer(self) -> NodePattern:
629
628
Var = ValuePattern
630
629
631
630
632
- def _is_pattern_variable (x : Any ) -> bool :
633
- # The derived classes of ValuePattern represent constant patterns and node-output patterns.
634
- return type (x ) is ValuePattern
635
-
636
-
637
631
class AnyValue (ValuePattern ):
638
632
"""Represents a pattern that matches against any value."""
639
633
@@ -718,6 +712,92 @@ def __str__(self) -> str:
718
712
return str (self ._value )
719
713
720
714
715
+ class OrValue (ValuePattern ):
716
+ """Represents a (restricted) form of value pattern disjunction."""
717
+
718
+ def __init__ (
719
+ self ,
720
+ values : Sequence [ValuePattern ],
721
+ name : str | None = None ,
722
+ tag_var : str | None = None ,
723
+ tag_values : Sequence [Any ] | None = None ,
724
+ ) -> None :
725
+ """
726
+ Initialize an OrValue pattern.
727
+
728
+ Args:
729
+ values: A sequence of value patterns to match against.
730
+ Must contain at least two alternatives. All value patterns except the last one
731
+ must have a unique producer id. This allows the pattern-matching to be deterministic,
732
+ without the need for backtracking.
733
+ name: An optional variable name for the pattern. Defaults to None. If present,
734
+ this name will be bound to the value matched by the pattern.
735
+ tag_var: An optional variable name for the tag. Defaults to None. If present,
736
+ it will be bound to a value (from tag_values) indicating which alternative was matched.
737
+ tag_values: An optional sequence of values to bind to the tag_var. Defaults to None.
738
+ If present, the length of tag_values must match the number of alternatives in values.
739
+ In a successful match, tag-var will be bound to the i-th value in tag_values if the i-th
740
+ alternative pattern matched. If omitted, the default value of (0, 1, 2, ...) will be used.
741
+ """
742
+ super ().__init__ (name )
743
+ if len (values ) < 2 :
744
+ raise ValueError ("OrValue must have at least two alternatives." )
745
+ if tag_values is not None :
746
+ if tag_var is None :
747
+ raise ValueError ("tag_var must be specified if tag_values is provided." )
748
+ if len (tag_values ) != len (values ):
749
+ raise ValueError (
750
+ "tag_values must have the same length as the number of alternatives."
751
+ )
752
+ else :
753
+ tag_values = tuple (range (len (values )))
754
+ self ._tag_var = tag_var
755
+ self ._tag_values = tag_values
756
+ self ._values = values
757
+
758
+ mapping : dict [ir .OperatorIdentifier , tuple [Any , NodeOutputPattern ]] = {}
759
+ for i , alternative in enumerate (values [:- 1 ]):
760
+ if not isinstance (alternative , NodeOutputPattern ):
761
+ raise TypeError (
762
+ f"Invalid type { type (alternative )} for OrValue. Expected NodeOutputPattern."
763
+ )
764
+ producer = alternative .producer ()
765
+ id = producer .op_identifier ()
766
+ if id is None :
767
+ raise ValueError (
768
+ f"Invalid producer { producer } for OrValue. Expected a NodePattern with op identifier."
769
+ )
770
+ if id in mapping :
771
+ raise ValueError (
772
+ f"Invalid producer { producer } for OrValue. Expected a unique producer id for each alternative."
773
+ )
774
+ mapping [id ] = (tag_values [i ], alternative )
775
+ self ._op_to_pattern = mapping
776
+ self ._default_pattern = (tag_values [- 1 ], values [- 1 ])
777
+
778
+ @property
779
+ def tag_var (self ) -> str | None :
780
+ """Returns the tag variable associated with the OrValue pattern."""
781
+ return self ._tag_var
782
+
783
+ def clone (self , node_map : dict [NodePattern , NodePattern ]) -> OrValue :
784
+ return OrValue (
785
+ [v .clone (node_map ) for v in self ._values ],
786
+ self .name ,
787
+ self ._tag_var ,
788
+ self ._tag_values ,
789
+ )
790
+
791
+ def get_pattern (self , value : ir .Value ) -> tuple [Any , ValuePattern ]:
792
+ """Returns the pattern that should be tried for the given value."""
793
+ producer = value .producer ()
794
+ if producer is not None :
795
+ id = producer .op_identifier ()
796
+ if id is not None and id in self ._op_to_pattern :
797
+ return self ._op_to_pattern [id ]
798
+ return self ._default_pattern
799
+
800
+
721
801
def _nodes_in_pattern (outputs : Sequence [ValuePattern ]) -> list [NodePattern ]:
722
802
"""Returns all nodes used in a pattern, given the outputs of the pattern."""
723
803
node_patterns : list [NodePattern ] = []
@@ -1136,6 +1216,15 @@ def _match_value(self, pattern_value: ValuePattern, value: ir.Value | None) -> b
1136
1216
if value is None :
1137
1217
return self .fail ("Mismatch: Constant pattern does not match None." )
1138
1218
return self ._match_constant (pattern_value , value )
1219
+ if isinstance (pattern_value , OrValue ):
1220
+ if value is None :
1221
+ return self .fail ("Mismatch: OrValue pattern does not match None." )
1222
+ i , pattern_choice = pattern_value .get_pattern (value )
1223
+ result = self ._match_value (pattern_choice , value )
1224
+ if result :
1225
+ if pattern_value .tag_var is not None :
1226
+ self ._match .bind (pattern_value .tag_var , i )
1227
+ return result
1139
1228
return True
1140
1229
1141
1230
def _match_node_output (self , pattern_value : NodeOutputPattern , value : ir .Value ) -> bool :
0 commit comments