diff --git a/onnxscript/rewriter/ort_fusions/_core.py b/onnxscript/rewriter/ort_fusions/_core.py
index 1f4c0c39d8..e0d9331065 100644
--- a/onnxscript/rewriter/ort_fusions/_core.py
+++ b/onnxscript/rewriter/ort_fusions/_core.py
@@ -86,20 +86,18 @@ def fuse(func, **kwargs):
     # We apply shape inference after the SDPA fusion as new nodes are added
     # in the rewrite rule for certain patterns of SDPA.
     fusion_count["sdpa"] = fuse(fuse_sdpa, apply_shape_inference=True)
-    # Optimize to avoid trying multiple attention-based fusions
+
+    fusion_count["gqa"] = fuse(fuse_gqa)
+    fusion_count["packed_qkv_for_gqa"] = fuse(fuse_qkv_gqa)
+
     fusion_count["mha1"] = fuse(fuse_mha1)
     fusion_count["mha2"] = fuse(fuse_mha2)
     if (fusion_count["mha1"] == 0) and (fusion_count["mha2"] == 0):
-        # If no MHA fusion was applied, we can try the GQA fusion.
-        # and avoid trying the attention fusion.
-        fusion_count["gqa"] = fuse(fuse_gqa)
-        fusion_count["packed_qkv_for_gqa"] = fuse(fuse_qkv_gqa)
         fusion_count["mha_bias"] = 0
         fusion_count["attention"] = 0
     else:
         fusion_count["mha_bias"] = fuse(fuse_mha_bias)
         fusion_count["attention"] = fuse(fuse_attention)
-        fusion_count["gqa"] = 0
     fusion_count["gelu"] = fuse(fuse_gelu)
     fusion_count["bias_gelu"] = fuse(fuse_bias_gelu)
     # Finally: inline any intermediate fusion functions introduced that were not
diff --git a/onnxscript/rewriter/ort_fusions/cos_sin_cache.py b/onnxscript/rewriter/ort_fusions/cos_sin_cache.py
index 74405bbe44..b2f0e3af8d 100644
--- a/onnxscript/rewriter/ort_fusions/cos_sin_cache.py
+++ b/onnxscript/rewriter/ort_fusions/cos_sin_cache.py
@@ -106,7 +106,16 @@ def cleanup(self):
         self._inv_freq_cos_sin_cache.clear()
 
     def pattern(
-        self, op, x, inv_freq, position_ids, interleaved, num_heads, freqs, dtype, extra_dims
+        self,
+        op,
+        x,
+        inv_freq,
+        position_ids,
+        interleaved,
+        num_heads,
+        freqs,
+        dtype,
+        extra_dims,
     ):
         if not self._const_freqs:
             # Compute freqs from inv_freq and position_ids. In the _const_freqs case,
@@ -121,6 +130,13 @@ def pattern(
             # if self._reshape:
             #     position_ids_expanded = op.Expand(position_ids_expanded, _allow_other_inputs=True)
             #     position_ids_expanded = op.Reshape(position_ids_expanded, _allow_other_inputs=True)
+            # inv_freq may optionally be expanded to shape [B, E, 1]
+            inv_freq = pattern.OrValue(
+                [
+                    op.Expand(inv_freq, pattern.ANY_VALUE, _outputs=["expanded_inv_freq"]),
+                    inv_freq,
+                ]
+            )
             freqs = op.MatMul(inv_freq, position_ids_expanded)  # [B, E, S]
         # if self._reshape:
         #     freqs = op.Reshape(freqs, freqs_3d_shape)  # redundant reshape
@@ -140,11 +156,11 @@ def pattern(
             sin_4d,
             interleaved=interleaved,
             num_heads=num_heads,
-            _domain="ai.onnxruntime.fusion",
+            _domain="ai.onnxruntime._fusion",
         )
 
     def check(
-        self, context, inv_freq, position_ids, freqs, extra_dims, **_
+        self, context, inv_freq, position_ids, freqs, extra_dims, expanded_inv_freq=None, **_
     ) -> pattern.MatchResult:  # type: ignore[name-defined]
         check_result = pattern.MatchResult()
         # TODO(rama): handle redundant reshape/expand
@@ -164,6 +180,10 @@ def check(
         if not _ir_utils.has_rank(inv_freq, 3):
             return check_result.fail("inv_freq is not 3D.", inv_freq)
         inv_freq_shape = inv_freq.shape
+        if expanded_inv_freq is not None:
+            if not _ir_utils.has_rank(expanded_inv_freq, 3):
+                return check_result.fail("expanded_inv_freq is not 3D.", expanded_inv_freq)
+            # TODO: check expanded_inv_freq shape
         if inv_freq.const_value is None:  # TODO: should this be inv_freq_shape?
             return check_result.fail("inv_freq is not a constant.", inv_freq)
         if inv_freq_shape[0] != 1 or inv_freq_shape[2] != 1:
diff --git a/onnxscript/rewriter/ort_fusions/gqa.py b/onnxscript/rewriter/ort_fusions/gqa.py
index 266987dd4d..0ea3718bb0 100644
--- a/onnxscript/rewriter/ort_fusions/gqa.py
+++ b/onnxscript/rewriter/ort_fusions/gqa.py
@@ -78,13 +78,11 @@ def pattern(
         value_BSDkv,
         past_key,
         past_value,
-        input_ids,
-        past_seq_length,
-        total_seq_length,
+        position_ids_q,
+        position_ids_k,
         cos,
         sin,
-        some_kv_cache,
-        shape_B111,
+        mask,
     ):
         # Reshape query from (B, S, D) to (B, S, H, D/H)
         query_BSHDh = op.Reshape(query_BSD, pattern.ANY_VALUE, _outputs=["query_BSHDh"])
@@ -101,10 +99,6 @@ def pattern(
         # Transpose from (B, S, Hkv, D/H) to (B, Hkv, S, D/H)
         value_BHkvSDh = op.Transpose(value_BSHkvDh, perm=[0, 2, 1, 3])
 
-        position_ids = op.Range(past_seq_length, total_seq_length, 1)
-        position_ids_q = op.Unsqueeze(position_ids, [0])
-        position_ids_k = op.Unsqueeze(position_ids, [0])
-
         query_BHSDh_rope = op.RotaryEmbedding(
             query_BHSDh,
             position_ids_q,
@@ -141,15 +135,13 @@ def pattern(
             value_seq_BHkvGTDh, pattern.ANY_VALUE, _outputs=["value_seq_BHTDh"]
         )
 
-        mask = causal_mask_pattern(op, input_ids, some_kv_cache, shape_B111)
-
-        key_seq_BHDhT = op.Transpose(key_seq_BHTDh, perm=[0, 1, 3, 2])
         attention_BHSDh = op.SDPA(
             query_BHSDh_rope,
-            key_seq_BHDhT,
+            key_seq_BHTDh,
             value_seq_BHTDh,
             mask,
-            _domain="ai.onnxruntime.fusion",
+            key_format="BHSd",
+            _domain="ai.onnxruntime._fusion",
         )
 
         # Transpose attention back to (B, S, H, D/H)
@@ -209,8 +201,8 @@ def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
         # Rotary embedding attributes
         query_rotary_attributes = query_BHSDh_rope.producer().attributes
         key_rotary_attributes = key_BHkvSDh_rope.producer().attributes
-        query_interleaved = query_rotary_attributes.get("interleaved", 0)
-        key_interleaved = key_rotary_attributes.get("interleaved", 0)
+        query_interleaved = query_rotary_attributes.get_int("interleaved", 0)
+        key_interleaved = key_rotary_attributes.get_int("interleaved", 0)
         if query_interleaved != key_interleaved:
             return pattern.MatchResult().fail(
                 "Rotary embedding interleaved attribute mismatch",
@@ -228,42 +220,104 @@ def rewrite(
         self,
         op,
         query_BSD,
         key_BSDkv,
         value_BSDkv,
         past_key,
         past_value,
-        total_seq_length,
+        position_ids_q,
+        position_ids_k,
         cos,
         sin,
+        mask,
         **_,
     ):
-        total_seq_length_int32 = op.Cast(total_seq_length, to=ir.DataType.INT32)
-        one_0D = op.Constant(value_int=1)
-        one_0D_int32 = op.Cast(one_0D, to=ir.DataType.INT32)
-        seqlens_k_0D = op.Sub(total_seq_length_int32, one_0D_int32)
-        zero_1D = op.Constant(value_int=0, dtype=ir.DataType.INT64, shape=[1])
-        seqlens_k = op.Unsqueeze(seqlens_k_0D, zero_1D)
-
-        return op.GroupQueryAttention(
+        return op.GQA(
+            mask,
+            position_ids_k,
+            position_ids_q,
             query_BSD,
             key_BSDkv,
             value_BSDkv,
             past_key,
             past_value,
-            seqlens_k,
-            total_seq_length_int32,
+            None,  # seqlens_k,
+            None,  # total_seq_length_int32,
             cos,
             sin,
-            # mask, # TODO: this is not a valid input for GQA
             num_heads=self.num_heads,
             kv_num_heads=self.kv_num_heads,
             do_rotary=1,
             rotary_interleaved=self._interleaved,
             # skipped optional attributes: local_window_size, scale, smooth_softmax, softcap
-            _domain="com.microsoft",
+            _domain="ai.onnxruntime._fusion",
             _outputs=3,
         )
 
 
-_rule1 = GroupQueryAttention.rule()
+class GQACausalMask(pattern.RewriteRuleClassBase):
+    def __init__(self):
+        super().__init__("GQACausalMask", remove_nodes=False)
+
+    def pattern(
+        self,
+        op,
+        mask,
+        input_ids,
+        some_kv_cache,
+        shape_B111,
+        past_seq_length,
+        total_seq_length,
+    ):
+        mask = causal_mask_pattern(op, input_ids, some_kv_cache, shape_B111)
+        position_ids = op.Range(past_seq_length, total_seq_length, 1)
+        position_ids_q = op.Unsqueeze(position_ids, [0])
+        position_ids_k = op.Unsqueeze(position_ids, [0])
+        return op.GQA(
+            mask,
+            position_ids_k,
+            position_ids_q,
+            _allow_other_inputs=True,
+            _domain="ai.onnxruntime._fusion",
+            _outputs=["attn_output", "key_seq", "value_seq"],
+        )
+
+    def rewrite(
+        self,
+        op,
+        total_seq_length,
+        attn_output,
+        **_,
+    ):
+        # Construct total_seq_length_int32 and seqlens_k
+        total_seq_length_int32 = op.Cast(total_seq_length, to=ir.DataType.INT32)
+        one_0D = op.Constant(value_int=1)
+        one_0D_int32 = op.Cast(one_0D, to=ir.DataType.INT32)
+        seqlens_k_0D = op.Sub(total_seq_length_int32, one_0D_int32)
+        zero_1D = op.Constant(value_int=0, dtype=ir.DataType.INT64, shape=[1])
+        seqlens_k = op.Unsqueeze(seqlens_k_0D, zero_1D)
+
+        gqa_node = attn_output.producer()
+        assert len(gqa_node.inputs) == 12, (
+            f"Expected 12 inputs for GQA node, got {len(gqa_node.inputs)}"
+        )
+        query, key, value, past_key, past_value = gqa_node.inputs[3:8]
+        cos, sin = gqa_node.inputs[10:12]
+        updated_inputs = [
+            query,
+            key,
+            value,
+            past_key,
+            past_value,
+            seqlens_k,
+            total_seq_length_int32,
+            cos,
+            sin,
+        ]
+        attributes = gqa_node.attributes
+        return op.GroupQueryAttention(
+            *updated_inputs, **attributes, _domain="com.microsoft", _outputs=3
+        )
+
 
-gqa_rules = pattern.RewriteRuleSet([_rule1])
+_basic_gqa_rule = GroupQueryAttention.rule()
+_gqa_causal_mask_rule = GQACausalMask.rule()
+gqa_rules = pattern.RewriteRuleSet([_basic_gqa_rule, _gqa_causal_mask_rule])
 
 fuse_gqa = _fusion_utils.apply_fusion_rules(gqa_rules)
diff --git a/onnxscript/rewriter/ort_fusions/gqa_test.py b/onnxscript/rewriter/ort_fusions/gqa_test.py
index 18d79d24d0..494dfb8daa 100644
--- a/onnxscript/rewriter/ort_fusions/gqa_test.py
+++ b/onnxscript/rewriter/ort_fusions/gqa_test.py
@@ -307,6 +307,11 @@ def test_fusion(self):
             onnx.TensorProto.FLOAT,
             ["B", self.seqlen, self.num_heads, self.head_size],
         )
+        key_BHSDh_value_info = onnx.helper.make_tensor_value_info(
+            "key_BHSDh",
+            onnx.TensorProto.FLOAT,
+            ["B", self.num_heads, self.total_seqlen, self.head_size],
+        )
         key_BSHkvDh_value_info = onnx.helper.make_tensor_value_info(
             "key_BSHkvDh",
             onnx.TensorProto.FLOAT,
@@ -327,6 +332,7 @@ def test_fusion(self):
             query_BHSDh_rope_value_info,
             key_BHkvSDh_rope_value_info,
             query_BSHDh_value_info,
+            key_BHSDh_value_info,
             key_BSHkvDh_value_info,
             key_transposed_value_info,
             value_BHSDh_value_info,
@@ -338,10 +344,10 @@ def test_fusion(self):
         onnxscript.optimizer.optimize(inferred_model)
 
         count = fuse_sdpa(inferred_model, debug=True)
-        self.assertEqual(count, 1)
+        self.assertGreater(count, 0)
 
         count = fuse_gqa(inferred_model, debug=True)
-        self.assertEqual(count, 1)
+        self.assertGreater(count, 0)
 
         fused_model = ir.serde.to_proto(inferred_model)
         session = ort.InferenceSession(
diff --git a/onnxscript/rewriter/ort_fusions/mha.py b/onnxscript/rewriter/ort_fusions/mha.py
index 03b0506867..8ce05369c7 100644
--- a/onnxscript/rewriter/ort_fusions/mha.py
+++ b/onnxscript/rewriter/ort_fusions/mha.py
@@ -79,17 +79,7 @@ def pattern(
 
         if not self._is_cross_attention:
             # Reshape from (B, S, D) to (B, S, H, D/H)
-            key_BSHDh = op.Reshape(key, pattern.ANY_VALUE, _outputs=["key_BSHDh"])
-
-            # Possible Transpose patterns for key:
-            # This scenario optimizes the need for a double transpose
-            # 1. (B, S, H, D/H) -> (B, H, D/H, S)
-            # Patterns with double transpose of key
-            # Double transpose should handle this optimization
-            # 2. (B, S, H, D/H) -> (B, H, S, D/H) -> (B, H, D/H, S)
-            # Patterns where key is reshaped to 3D, transposed and reshaped back to 4D
-            # 3. (B, S, H, D/H) -> (B, H, S, D/H) -> R (B, S, D) -> (B, D, S) -> R (B, H, D/H, S)
-            key_BHSDh = op.Transpose(key_BSHDh, perm=key_perm)
+            key = op.Reshape(key, pattern.ANY_VALUE, _outputs=["key_BSHDh"])
 
             # Reshape from (B, S, D) to (B, S, H, D/H)
             value_BSHDh = op.Reshape(value, pattern.ANY_VALUE, _outputs=["value_BSHDh"])
@@ -97,7 +87,6 @@ def pattern(
             value_BHSDh = op.Transpose(value_BSHDh, perm=[0, 2, 1, 3])
         else:
             # For cross-attention, key and value are not reshaped
-            key_BHSDh = key
             value_BHSDh = value
 
         if self._is_rotary:
@@ -118,14 +107,14 @@ def pattern(
             )
             if not self._is_cross_attention:
                 key_BHSDh_emb = op.RotaryEmbedding(
-                    key_BHSDh, position_ids_k, cos, sin, _domain="com.microsoft"
+                    key, position_ids_k, cos, sin, _domain="com.microsoft"
                 )
             else:
-                key_BHSDh_emb = key_BHSDh
+                key_BHSDh_emb = key
         else:
             # If rotary embedding is not used, we fuse with positional_embeddings
             query_BHSDh_emb = query_BHSDh
-            key_BHSDh_emb = key_BHSDh
+            key_BHSDh_emb = key
 
         # Concatenate past_key cache and current key, and transpose to enable
         # dot-product attention computation.
@@ -144,20 +133,6 @@ def pattern(
             key_seq_to_sdpa = key_seq
             value_seq_to_sdpa = value_seq
 
-        # Transpose last two axes of key_seq to compute dot-product via matmul.
-        if self._double_transpose:
-            if self._transpose_4d:
-                key_seq_to_sdpa = op.Transpose(key_seq_to_sdpa, perm=[0, 1, 3, 2])
-            else:
-                # Transpose after converting to 3D
-                key_seq_BH_Skv_Dh = op.Reshape(
-                    key_seq_to_sdpa, pattern.ANY_VALUE, _outputs=["key_seq_BH_Skv_Dh"]
-                )
-                key_seq_BH_Dh_Skv = op.Transpose(key_seq_BH_Skv_Dh, perm=[0, 2, 1])
-                key_seq_to_sdpa = op.Reshape(
-                    key_seq_BH_Dh_Skv, pattern.ANY_VALUE, _outputs=["key_seq_B_H_Dh_Skv"]
-                )
-
         # TODO: Remove use_mask once SDPA op is usable
         if self._use_mask:
             sdpa = op.SDPA(
@@ -165,14 +140,14 @@ def pattern(
                 key_seq_to_sdpa,
                 value_seq_to_sdpa,
                 mask,
-                _domain="ai.onnxruntime.fusion",
+                _domain="ai.onnxruntime._fusion",
             )
         else:
             sdpa = op.SDPA(
                 query_BHSDh_emb,
                 key_seq_to_sdpa,
                 value_seq_to_sdpa,
-                _domain="ai.onnxruntime.fusion",
+                _domain="ai.onnxruntime._fusion",
             )
 
         # Transpose attention back to (B, S, H, D/H)
diff --git a/onnxscript/rewriter/ort_fusions/mha_test.py b/onnxscript/rewriter/ort_fusions/mha_test.py
index 8d1c04f970..236f5bcff9 100644
--- a/onnxscript/rewriter/ort_fusions/mha_test.py
+++ b/onnxscript/rewriter/ort_fusions/mha_test.py
@@ -57,7 +57,7 @@ def test_whisper_encoder(self):
         original_outputs = ort_run("original", model, inputs)
 
         # Fuse SDPA and MHA
-        sdpa_count = xformers.fuse_sdpa(model)
+        sdpa_count = xformers.fuse_sdpa(model, debug=True)
         self.assertGreater(sdpa_count, 0)
         model = common_passes.ShapeInferencePass()(model).model
         mha_count = xformers.fuse_mha1(model)
diff --git a/onnxscript/rewriter/ort_fusions/rotary_embedding.py b/onnxscript/rewriter/ort_fusions/rotary_embedding.py
index 0c2a527620..b9d4015f06 100644
--- a/onnxscript/rewriter/ort_fusions/rotary_embedding.py
+++ b/onnxscript/rewriter/ort_fusions/rotary_embedding.py
@@ -56,7 +56,7 @@ def check(self, op, x, start1, end1, start2, end2, **_) -> pattern.MatchResult:
     def rewrite(self, op, x, cos, sin, **_):
         num_heads = x.shape[1]
         return op.RotaryEmbedding(
-            x, cos, sin, interleaved=0, num_heads=num_heads, _domain="ai.onnxruntime.fusion"
+            x, cos, sin, interleaved=0, num_heads=num_heads, _domain="ai.onnxruntime._fusion"
         )
diff --git a/onnxscript/rewriter/ort_fusions/sdpa.py b/onnxscript/rewriter/ort_fusions/sdpa.py
index 1ca4c3b1ff..1d339f43e7 100644
--- a/onnxscript/rewriter/ort_fusions/sdpa.py
+++ b/onnxscript/rewriter/ort_fusions/sdpa.py
@@ -20,13 +20,34 @@ def pattern(
         self,
         op,
         query,
-        key_transposed,
+        key,
         value,
         mask,
         query_scale,
         key_scale,
         qk_scale,
     ):
+        # The last two axes of key must be transposed before computing the dot product with query.
+        # Three patterns are observed in practice:
+
+        # Pattern 1: Transpose 4D key directly: BHSd => BHdS
+        key_transposed_1 = op.Transpose(key, perm=[0, 1, 3, 2])
+
+        # Pattern 2: Transpose key after converting to 3D and then convert back to 4D: BHSd => 3D => BHdS
+        key_3d = op.Reshape(key, pattern.ANY_VALUE)
+        key_3d_transposed = op.Transpose(key_3d, perm=[0, 2, 1])
+        key_transposed_2 = op.Reshape(key_3d_transposed, pattern.ANY_VALUE)
+
+        # Pattern 3: This transpose is sometimes composed with an earlier transpose to convert
+        # the key from BSHd format to BHSd format.
+        key_transposed_3 = op.Transpose(key, perm=[0, 2, 3, 1])
+
+        key_transposed = pattern.OrValue(
+            [key_transposed_1, key_transposed_2, key_transposed_3],
+            tag_var="key_format",
+            tag_values=["BHSd", "BHSd", "BSHd"],
+        )
+
         # Some implementations scale the query and key before computing the dot product
         query = pattern.OrValue(
             [
@@ -74,9 +95,10 @@ def check(
         self,
         context,
         query: ir.Value | None,
-        key_transposed: ir.Value | None,
+        key: ir.Value | None,
         value: ir.Value | None,
         mask: ir.Value | None,
+        key_format: str,
         **match_bindings,
     ):
         check_result = pattern.MatchResult()
@@ -90,7 +112,11 @@ def check(
         # Query and Key should have same head-size (Dh) while value can have different head-size (Dv).
         # Key and Value should have same sequence length (Skv), while Query can have different sequence length (S).
         _fusion_utils.check_shape(bindings, query, ["B", "H", "S", "Dh"])
-        _fusion_utils.check_shape(bindings, key_transposed, ["B", "H", "Dh", "Skv"])
+        if key_format == "BHSd":
+            _fusion_utils.check_shape(bindings, key, ["B", "H", "Skv", "Dh"])
+        else:
+            assert key_format == "BSHd", f"Unexpected key format: {key_format}"
+            _fusion_utils.check_shape(bindings, key, ["B", "Skv", "H", "Dh"])
         _fusion_utils.check_shape(bindings, value, ["B", "H", "Skv", "Dv"])
 
         def get_scale_value(tag_name: str, scale_name: str) -> float:
@@ -132,20 +158,29 @@ def rewrite(
         self,
         op,
         query: ir.Value | None,
-        key_transposed: ir.Value | None,
+        key: ir.Value | None,
         value: ir.Value | None,
         mask: ir.Value | None,
+        key_format: str,
         **_,
     ):
-        sdpa_args = [query, key_transposed, value]
+        sdpa_args = [query, key, value]
         if mask is not None:
             sdpa_args.append(mask)
         # If the scale is None, SDPA will use the default scaling factor, which is 1/sqrt(head_size).
-        return op.SDPA(*sdpa_args, scale=self._scale, _domain="ai.onnxruntime.fusion")
+        return op.SDPA(
+            *sdpa_args,
+            scale=self._scale,
+            key_format=key_format,
+            _domain="ai.onnxruntime._fusion",
+        )
 
 
 # Dynamically create the rules
-sdpa_rules = pattern.RewriteRuleSet([SDPA.rule()])
-
+sdpa_rules = pattern.RewriteRuleSet(
+    [
+        SDPA.rule(),
+    ]
+)
 fuse_sdpa = _fusion_utils.apply_fusion_rules(sdpa_rules)
diff --git a/onnxscript/rewriter/ort_fusions/sdpa_via_mha.py b/onnxscript/rewriter/ort_fusions/sdpa_via_mha.py
index 502e19093a..54c41217ca 100644
--- a/onnxscript/rewriter/ort_fusions/sdpa_via_mha.py
+++ b/onnxscript/rewriter/ort_fusions/sdpa_via_mha.py
@@ -11,20 +11,21 @@
 
 
 class SDPAImplementation(pattern.RewriteRuleClassBase):
-    def pattern(self, op, query, key_transposed, value):
+    def pattern(self, op, query, key, value):
         return op.SDPA(
             query,
-            key_transposed,
+            key,
             value,
+            key_format="BHSd",
             _allow_other_inputs=True,  # Mask is optional
             _outputs=["sdpa_output"],
-            _domain="ai.onnxruntime.fusion",
+            _domain="ai.onnxruntime._fusion",
         )
 
-    def check(self, context, query, key_transposed, value, sdpa_output):
+    def check(self, context, query, key, value, sdpa_output):
         bindings: dict[str, Dim] = {}
         _fusion_utils.check_shape(bindings, query, ["B", "H", "S", "Dh"])
-        _fusion_utils.check_shape(bindings, key_transposed, ["B", "H", "Dh", "Skv"])
+        _fusion_utils.check_shape(bindings, key, ["B", "H", "Skv", "Dh"])
         _fusion_utils.check_shape(bindings, value, ["B", "H", "Skv", "Dv"])
 
         self._num_heads = bindings["H"]
@@ -33,13 +34,13 @@ def check(self, context, query, key_transposed, value, sdpa_output):
         self._use_mask_broadcast = True  # TODO: optimize to avoid broadcast if not needed
         return isinstance(self._num_heads, int)
 
-    def rewrite(self, op, query, key_transposed, value, sdpa_output):
+    def rewrite(self, op, query, key, value, sdpa_output):
         sdpa_node = sdpa_output.producer()
         scale = sdpa_node.attributes.get("scale", None)
         to_3d_shape = op.Constant(value_ints=[0, 0, -1])
         to_4d_shape = op.Constant(value_ints=[0, 0, self._num_heads, -1])
         query_3d = op.Reshape(op.Transpose(query, perm=[0, 2, 1, 3]), to_3d_shape)
-        key_3d = op.Reshape(op.Transpose(key_transposed, perm=[0, 3, 1, 2]), to_3d_shape)
+        key_3d = op.Reshape(op.Transpose(key, perm=[0, 2, 1, 3]), to_3d_shape)
         value_3d = op.Reshape(op.Transpose(value, perm=[0, 2, 1, 3]), to_3d_shape)
 
         inputs = [query_3d, key_3d, value_3d]