diff --git a/onnxscript/rewriter/ort_fusions/mha.py b/onnxscript/rewriter/ort_fusions/mha.py
index 8ce05369c7..802cd37349 100644
--- a/onnxscript/rewriter/ort_fusions/mha.py
+++ b/onnxscript/rewriter/ort_fusions/mha.py
@@ -40,7 +40,6 @@ def __init__(
         transpose_4d: bool,
         pre_scale_q: bool,
         is_rotary: bool,
-        use_mask: bool,
         has_past_present: bool,
         is_cross_attention: bool,
     ):
@@ -49,7 +48,6 @@ def __init__(
         self._transpose_4d = transpose_4d
         self._pre_scale_q = pre_scale_q
         self._is_rotary = is_rotary
-        self._use_mask = use_mask
         self._has_past_present = has_past_present
         self._is_cross_attention = is_cross_attention
 
@@ -59,13 +57,11 @@ def pattern(
         query_BSD,
         key,
         value,
-        mask,
         past_key,
         past_value,
         position_ids,
         cos,
         sin,
-        key_perm,
         q_scale,
     ):
         # First, query, key, and value are reshaped+transposed from (B, S, D) to (B, H, S, D/H)
@@ -80,6 +76,12 @@ def pattern(
         if not self._is_cross_attention:
             # Reshape from (B, S, D) to (B, S, H, D/H)
             key = op.Reshape(key, pattern.ANY_VALUE, _outputs=["key_BSHDh"])
+            # Key may or may not be transposed at this point, based on usage pattern
+            key = pattern.OrValue(
+                [op.Transpose(key, perm=[0, 2, 1, 3]), key],
+                tag_var="key_transposed",
+                tag_values=[True, False],
+            )
 
             # Reshape from (B, S, D) to (B, S, H, D/H)
             value_BSHDh = op.Reshape(value, pattern.ANY_VALUE, _outputs=["value_BSHDh"])
@@ -133,22 +135,14 @@ def pattern(
             key_seq_to_sdpa = key_seq
             value_seq_to_sdpa = value_seq
 
-        # TODO: Remove use_mask once SDPA op is usable
-        if self._use_mask:
-            sdpa = op.SDPA(
-                query_BHSDh_emb,
-                key_seq_to_sdpa,
-                value_seq_to_sdpa,
-                mask,
-                _domain="ai.onnxruntime._fusion",
-            )
-        else:
-            sdpa = op.SDPA(
-                query_BHSDh_emb,
-                key_seq_to_sdpa,
-                value_seq_to_sdpa,
-                _domain="ai.onnxruntime._fusion",
-            )
+        sdpa = op.SDPA(
+            query_BHSDh_emb,
+            key_seq_to_sdpa,
+            value_seq_to_sdpa,
+            _allow_other_inputs=True,
+            _outputs=["sdpa_output"],
+            _domain="ai.onnxruntime._fusion",
+        )
 
         # Transpose attention back to (B, S, H, D/H)
         attention_transposed = op.Transpose(sdpa, perm=[0, 2, 1, 3])
@@ -167,17 +161,19 @@ def check(
         query_BSD,
         key,
         value,
-        mask,
+        sdpa_output,
         past_key,
         past_value,
-        key_perm,
         query_BSHDh,
+        key_transposed=None,
         key_BSHDh=None,
         value_BSHDh=None,
         **_,
     ) -> pattern.MatchResult:  # type: ignore[name-defined]
         check_result = pattern.MatchResult()
 
+        sdpa_node = sdpa_output.producer()
+
         bindings: dict[str, Dim] = {}
 
         def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
@@ -223,6 +219,13 @@ def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
                 f"Shape mismatch: {key} does not match expected dimensions ['B', 'Skv', 'D']",
                 query_BSD,
             )
+        sdpa_key_format = sdpa_node.attributes.get_string("key_format")
+        expected_key_format = "BHSd" if key_transposed else "BSHd"
+        if sdpa_key_format != expected_key_format:
+            return check_result.fail(
+                f"Unexpected key format: {sdpa_key_format}. Expected: {expected_key_format}",
+                sdpa_node,
+            )
         if no_match(value, ["B", "Skv", "D"]):
             return check_result.fail(
                 f"Shape mismatch: {value} does not match expected dimensions ['B', 'Skv', 'D']",
@@ -245,7 +248,11 @@ def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
         # ORT's contrib ops (MHA, Attention) allow a mask of shape (1 or B, 1 or H, S, St)
         # That is: broadcast allowed only for the first two dimensions. (Even that is not
         # supported by some earlier versions of ORT, which are not supported here.)
-        if self._use_mask:
+        mask = None
+        if len(sdpa_node.inputs) > 3:
+            mask = sdpa_node.inputs[3]
+        self.mask = mask
+        if mask is not None:
             if (mask_shape := mask.shape) is None:
                 return check_result.fail(
                     "Mask shape cannot be determined.",
@@ -293,7 +300,6 @@ def rewrite(
         query_BSD,
         key,
         value,
-        mask,
         past_key,
         past_value,
         query_BSHDh,
@@ -335,6 +341,7 @@ def rewrite(
             query_BSD_emb = query_BSD
             key_BSD_emb = key
 
+        mask = self.mask
         if self._use_mask_broadcast:
             one = op.Constant(value_ints=[1])
             S = op.Shape(query_BSD, start=1, end=2)
@@ -365,7 +372,6 @@ def _make_rule_set(has_past_present: bool):
             "transpose_4d": transpose_4d,
             "pre_scale_q": pre_scale_q,
             "is_rotary": is_rotary,
-            "use_mask": use_mask,
            "has_past_present": has_past_present,
            "is_cross_attention": is_cross_attention,
        }
@@ -375,7 +381,6 @@ def _make_rule_set(has_past_present: bool):
        )  # Only generate patterns when double_transpose is True
        for pre_scale_q in [True, False]
        for is_rotary in [False, True]
-       for use_mask in [False, True]
        for is_cross_attention in ([False] if has_past_present else [False, True])
    ]
 
@@ -387,7 +392,6 @@ def _make_rule_set(has_past_present: bool):
                f"{'_Twice' if params['double_transpose'] else ''}"
                f"{'_PreScaleQ' if params['pre_scale_q'] else ''}"
                f"{'_Rotary' if params['is_rotary'] else ''}"
-               f"{'_Masked' if params['use_mask'] else ''}"
                f"{'_Past' if params['has_past_present'] else ''}"
                f"{'_CrossAttention' if params['is_cross_attention'] else ''}",
                **params,