
Conversation

tadani3

@tadani3 tadani3 commented Jul 24, 2025

This PR modifies the GQA causal mask fusion rule to handle attention-mask fusion for LongRoPE models such as Phi-4-mini-reasoning. The modified causal-mask handling produces a result that matches the optimizations made in ModelBuilder.

mask_key = self._get_mask_key(attention_mask)

if mask_key in self._mask_cache:
    total_seq_length_int32, seqlens_k_int32 = self._mask_cache[mask_key]

Check notice

Code scanning / CodeQL

Unused local variable Note

Variable total_seq_length_int32 is not used.

Check notice

Code scanning / CodeQL

Unused local variable Note

Variable seqlens_k_int32 is not used.

codecov bot commented Jul 24, 2025

Codecov Report

❌ Patch coverage is 32.53235% with 365 lines in your changes missing coverage. Please review.
✅ Project coverage is 69.03%. Comparing base (da23d76) to head (2772f77).

Files with missing lines                                  Patch %   Lines
...ipt/rewriter/phi4_mini_reasoning_post_processor.py     18.24%    345 Missing ⚠️
onnxscript/rewriter/ort_fusions/gqa.py                    83.19%    20 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2461      +/-   ##
==========================================
- Coverage   69.81%   69.03%   -0.78%     
==========================================
  Files         209      210       +1     
  Lines       25313    25790     +477     
  Branches     2525     2603      +78     
==========================================
+ Hits        17673    17805     +132     
- Misses       6762     7110     +348     
+ Partials      878      875       -3     


Contributor

@github-advanced-security github-advanced-security bot left a comment

lintrunner found more than 20 potential problems in the proposed changes. Check the Files changed tab for more details.

Contributor

@Copilot Copilot AI left a comment

Pull Request Overview

This PR adds causal mask pattern fusion support specifically for LongRoPE models such as Phi-4-mini-reasoning. The implementation extends the existing GQA (Group Query Attention) fusion rules to handle the complex attention mask patterns used by LongRoPE models, optimizing the mask computation process while maintaining compatibility with ModelBuilder optimizations.

Key changes:

  • Addition of a new LongRoPeGQACausalMask class that implements specialized mask pattern matching and fusion
  • Extension of the GQA rewrite rules to include LongRoPE-specific optimizations
  • Implementation of mask caching mechanism to avoid recomputation


_basic_gqa_rule = GroupQueryAttention.rule()
_longrope_gqa_causal_mask_rule = LongRoPeGQACausalMask.rule()

gqa_rules = pattern.RewriteRuleSet([_basic_gqa_rule])
Copilot AI Jul 24, 2025

The gqa_rules variable is being reassigned, which overwrites the previous assignment on line 514. This means the first assignment gqa_rules = pattern.RewriteRuleSet([_basic_gqa_rule]) is completely ignored.

Suggested change
gqa_rules = pattern.RewriteRuleSet([_basic_gqa_rule])

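If the intent is for both rules to apply, one alternative to deleting the first assignment is to register both rules in a single rule set. A minimal sketch, assuming both rule objects are defined as shown above:

# Register both the basic GQA rule and the LongRoPE causal-mask rule so the
# earlier assignment is not silently overwritten.
gqa_rules = pattern.RewriteRuleSet([_basic_gqa_rule, _longrope_gqa_causal_mask_rule])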

# Propagation to GQA
mask_sliced = op.Slice(mask_A_B_C_scaled, [0], pattern.ANY_VALUE, [3], [1], _outputs=["mask_sliced"])

#mask_where = op.Where(mask_sliced, pattern.ANY_VALUE, pattern.ANY_VALUE, _outputs=["mask_where"])
Copilot AI Jul 24, 2025

This commented-out code should be removed if it's not needed, or properly implemented if it serves a purpose. Dead code reduces maintainability.

Suggested change
#mask_where = op.Where(mask_sliced, pattern.ANY_VALUE, pattern.ANY_VALUE, _outputs=["mask_where"])


mask_expanded_C = op.Expand(reshaped_range_C, mask_shape_C_abs, _outputs=["mask_expanded_C"])

# EXPAND A/B TO AND
mask_expanded_A_sub = op.Sub(mask_expanded_A, 262144, _outputs=["mask_expanded_A_sub"])
Copilot AI Jul 24, 2025

The magic number 262144 should be defined as a named constant to improve code readability and maintainability. Consider defining it as a class constant with a descriptive name.

Suggested change
mask_expanded_A_sub = op.Sub(mask_expanded_A, 262144, _outputs=["mask_expanded_A_sub"])
mask_expanded_A_sub = op.Sub(mask_expanded_A, MASK_OFFSET, _outputs=["mask_expanded_A_sub"])

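If the literal stays in the pattern, naming it is a one-line change. A sketch, with the caveat that the discussion below suggests the value is really the sliding-window size, so both the name and keeping it as a plain constant are assumptions:

# Placeholder name; per the review below this is likely the LongRoPE sliding-window size.
LONGROPE_SLIDING_WINDOW = 262144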

Collaborator

Better to make it a pattern-variable, I think ... if I understand right, this is actually a magic sequence-length constant? Perhaps model-specific?

Collaborator

On second thoughts, I am guessing this is the window_size, which should become an attribute-parameter to the GQA op.

Comment on lines 365 to 368
Generate a unique key for the mask based on input_ids and past_kv_cache.
This is used to cache the mask to avoid recomputation.
"""
return (id(attention_mask))
Copilot AI Jul 24, 2025

Using id() for cache keys is fragile because object ids can be reused after garbage collection. This could lead to incorrect cache hits with different attention_mask objects that happen to have the same id.

Suggested change
Generate a unique key for the mask based on input_ids and past_kv_cache.
This is used to cache the mask to avoid recomputation.
"""
return (id(attention_mask))
Generate a unique key for the mask based on the content of attention_mask.
This is used to cache the mask to avoid recomputation.
"""
if isinstance(attention_mask, np.ndarray):
return hash(attention_mask.tobytes())
elif isinstance(attention_mask, (list, tuple)):
return hash(tuple(attention_mask))
else:
raise TypeError("Unsupported type for attention_mask: {}".format(type(attention_mask)))


Collaborator

If a cache is used, it should be cleaned up like in this example so that it is not carried over from one graph/model to another

Collaborator

And I am not sure if we need to handle np arrays? If the key is either one or two ir.Values, that should be fine ... ir.Values can be used as keys in dictionaries directly, and that should avoid the garbage-collection problem.

Collaborator

I agree _get_mask_key seems unnecessary. We can use the Value objects directly as keys.
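A minimal sketch of that combination, with the cache keyed directly on the Value object and cleared between models; the helper name and API here are hypothetical, not part of the PR:

# Hypothetical helper: key the cache on the ir.Value itself (hashable), and
# clear it once per model so state never leaks from one graph to the next.
class _MaskCache:
    def __init__(self):
        self._entries = {}

    def get_or_compute(self, key, compute_fn):
        # key may be a single ir.Value or a tuple of ir.Values
        if key not in self._entries:
            self._entries[key] = compute_fn()
        return self._entries[key]

    def cleanup(self):
        # call after all rewrites on a graph are done
        self._entries.clear()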

Collaborator

@justinchuby justinchuby left a comment

Congrats on your first PR to onnxscript! Could you resolve the lint errors by running lintrunner -a and fix any issues?

"""
return (id(attention_mask))

def compute_mask(self, op, attention_mask : _onnx_types.INT64['batch', 'seq_len']):
Collaborator

The rewriter doesn't use onnxscript types (yet). Could you instead use a comment to document the shape of the attention_mask?
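For example, a minimal sketch of the suggested documentation style (the body is elided as a stub):

def compute_mask(self, op, attention_mask):
    # attention_mask: int64 tensor of shape [batch, seq_len]
    ...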

@@ -354,9 +355,163 @@ def rewrite(
_outputs=3,
)

class LongRoPeGQACausalMask(pattern.RewriteRuleClassBase):
Collaborator

Could you use the docstring to document the pattern and its replacement? For the branches A, B, and C, I would consider giving them descriptive names.

Collaborator

The following is my understanding; if it is correct, maybe the branches can be renamed accordingly: I believe A constructs the kv_range, B constructs the query_range, and C constructs the batch_range. Each constructs the corresponding range as a 4D tensor with 1s in the other positions (for constructing a final attention-mask of shape [Batch, NumHeads, QueryRange, KVRange] via broadcast).

I am a bit puzzled that query_range and kv_range look to be the same here; it might be an artifact of this model usage, I guess.

Author

I wasn't sure what the branches referred to but I'll make changes following what Rama is suggesting.
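Based on that reading, a rough docstring sketch that could go into LongRoPeGQACausalMask; the branch roles come from the discussion above and are not verified against the model:

"""Fuses the LongRoPE causal-mask subgraph that feeds GQA.

Per the review discussion (branch names are suggestions):
  * the kv_range branch builds a range over the key/value positions,
  * the query_range branch builds a range over the query positions,
  * the batch_range branch builds a range over the batch,
each expanded to a 4D tensor with size-1 axes elsewhere so that, broadcast
together, they form an attention mask of shape [batch, num_heads, query_len, kv_len].
The rewrite replaces this subgraph with the seqlens_k / total_sequence_length
values consumed by the fused GQA op, as in the cached values shown in the diff.
"""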

total_seq_length,
):
seq_len = op.Shape(input_ids, end=2, start=1, _outputs=["seq_len"])
seq_len_0D = op.Squeeze(seq_len, _outputs=["seq_len_0D"])
Collaborator

Suggested change
seq_len_0D = op.Squeeze(seq_len, _outputs=["seq_len_0D"])
seq_len_0d = op.Squeeze(seq_len, _outputs=["seq_len_0d"])

prefer snake case for variable names when possible

mask_A_B_combined = op.And(mask_A_B_greater_bitwise, mask_A_B_less, _outputs=["mask_A_B_combined"])
mask_A_B_combined_bitwise = op.And(True, mask_A_B_combined, _outputs=["mask_A_B_combined_bitwise"])

# EXPAND B/C TO AND
Collaborator

I would document the branches in plain English for readers

class LongRoPeGQACausalMask(pattern.RewriteRuleClassBase):
    def __init__(self):
        super().__init__("LongRoPeGQACausalMask", remove_nodes=False)
        self._mask_cache = {}
Collaborator

@justinchuby justinchuby Jul 24, 2025

The copilot review is reasonable: the rewrite rule class should be stateless. Is there a different way to do this other than keeping a self._mask_cache?

Collaborator

I think use of state for this purpose is okay? It has been used before for a similar purpose, namely to introduce values that are reused across multiple rewrites. (Now that we have CSE, there is an alternative path, which is to create duplicate copies and then eliminate them via CSE ... but I am not sure it is worth the bother.)

Collaborator

BTW: my GQA fusion doesn't use state, and produces multiple copies (as described above).

Collaborator

My concern is that state will carry over from one model to another if we are not careful, which is probably not a good idea. Maybe we can have a class-managed state dict that is cleared by the class?

@gramalingam
Collaborator

Hi @tadani3, sorry about the concurrent changes I recently merged into the GQA fusion, which might impact some of your changes ... but I am a bit confused by the diffs shown here, which don't seem to reflect the changes I had made. Briefly, the earlier version did the fusion in two steps: the first rule ignored the attention-mask and focused on the rest of the computation, and the second rule explicitly handled the attention-mask. The more recent version merged the two into one, for various reasons.

I think this shouldn't impact your changes much, except that you will have to make the changes in rule 1 instead of rule 2. But, as I said, I am a bit confused about why I am not seeing my recent changes in the diffs.

        super().__init__("LongRoPeGQACausalMask", remove_nodes=False)
        self._mask_cache = {}

    def _get_mask_key(self, attention_mask):
Collaborator

In general, avoid creating class methods that do not require state from self; instead, make them module-level private functions for testability and clarity.

@@ -7,6 +7,7 @@
import numpy as np
import onnx_ir as ir

import onnxscript.onnx_types as _onnx_types
Collaborator

Suggested change
import onnxscript.onnx_types as _onnx_types

_onnx_types is incompatible with the rewriter (yet)

@tadani3
Author

tadani3 commented Jul 24, 2025

@microsoft-github-policy-service agree company="Microsoft"

# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import onnx

Check notice

Code scanning / CodeQL

Unused import Note

Import of 'onnx' is not used.
# --------------------------------------------------------------------------
import onnx
from onnxscript import ir
import onnx.helper

Check notice

Code scanning / CodeQL

Unused import Note

Import of 'onnx' is not used.
cache_length = self.rotemb_attrs["cache_length"]
position_ids = torch.arange(cache_length, dtype=torch.int64).unsqueeze(0) # Shape: (1, cache_length)

inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) # (1, dim//2, 1)

Check failure

Code scanning / CodeQL

Potentially uninitialized local variable Error

Local variable 'inv_freq' may be used before it is initialized.
with torch.autocast(device_type=device_type, enabled=False):
    freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)  # (1, cache_length, dim//2)
    emb = torch.cat((freqs, freqs), dim=-1)  # (1, cache_length, dim)
    cos_cache = emb.cos() * attention_factor  # (1, cache_length, dim)

Check failure

Code scanning / CodeQL

Potentially uninitialized local variable Error

Local variable 'attention_factor' may be used before it is initialized.
attention_factor = self.rotemb_attrs["multi_cache"]["short_mscale"]

inv_freq_shape = torch.arange(0, dim, 2, dtype=torch.int64, device="cpu").float() / dim
inv_freq = 1.0 / (ext_factors * base**inv_freq_shape)

Check failure

Code scanning / CodeQL

Potentially uninitialized local variable Error

Local variable 'ext_factors' may be used before it is initialized.
if "rescale_inv_freq" in self.rotemb_attrs:
inv_freq = self.make_inv_freq_rescaled(inv_freq)

return inv_freq, attention_factor

Check failure

Code scanning / CodeQL

Potentially uninitialized local variable Error

Local variable 'attention_factor' may be used before it is initialized.
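For reference, the shape of the fix these CodeQL errors usually call for, shown as a small hypothetical helper; only the rotemb_attrs / multi_cache / short_mscale names are taken from the diff, and the real branching logic lives in the surrounding code:

def _get_attention_factor(rotemb_attrs: dict) -> float:
    # Every path either binds the value or raises, so attention_factor can
    # never be read while unbound.
    if "multi_cache" in rotemb_attrs:
        return rotemb_attrs["multi_cache"]["short_mscale"]
    raise KeyError("rotemb_attrs['multi_cache'] is required to compute attention_factor")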
@tadani3 tadani3 changed the title from "Added Causal Mask Pattern Fusion for LongRoPe Models" to "Added Causal Mask Pattern Fusion for LongRoPe Models and Cache Insertion for Phi4-mini-reasoning" Jul 31, 2025
@tadani3
Author

tadani3 commented Jul 31, 2025

I added a class called Phi4MiniReasoningPostProcessor which uses the ONNX IR to fulfill two tasks:

  1. Inserts the cached Cos/Sin values used in the Rotary Embeddings
  • The Cos/Sin caches are computed offline following the logic found in Transformers' modeling_phi3.py file (see the sketch after this list).
  • The existing pattern that computes Cos and Sin is found in the ONNX model graph and removed.
  • An If node containing both sets of caches is inserted into the graph. The selection condition between the two sets of caches is the total_sequence_length, which we obtain from the attention_mask.
  • GQA nodes within the graph are updated to have their cos/sin inputs come from the newly inserted caches.
  2. Removes Position Ids and remaining child nodes from the model graph.
  • Matches the position ids and the following nodes that were used in the old Cos/Sin value computation.
  • Uses the ONNX IR to remove these nodes from the graph.
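As referenced above, here is a sketch of the offline Cos/Sin cache computation, distilled from the torch snippet shown in the diff; the function name and the assumption that inv_freq, cache_length, and attention_factor come from the LongRoPE config are mine:

import torch

def build_rope_caches(inv_freq: torch.Tensor, cache_length: int, attention_factor: float):
    # position ids 0..cache_length-1, shape (1, cache_length)
    position_ids = torch.arange(cache_length, dtype=torch.int64).unsqueeze(0)
    inv_freq_expanded = inv_freq[None, :, None].float()                  # (1, dim//2, 1)
    position_ids_expanded = position_ids[:, None, :].float()             # (1, 1, cache_length)
    freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)  # (1, cache_length, dim//2)
    emb = torch.cat((freqs, freqs), dim=-1)                              # (1, cache_length, dim)
    cos_cache = emb.cos() * attention_factor                             # (1, cache_length, dim)
    sin_cache = emb.sin() * attention_factor
    return cos_cache, sin_cache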

Reasoning and Motivation

  • These changes allow the version of Phi4-mini-reasoning produced using the DynamoExporter to match the version found in Model Builder

Collaborator

If you can move this file to a separate PR we can merge the fusion rules. Thanks
