Added Causal Mask Pattern Fusion for LongRoPe Models and Cache Insertion for Phi4-mini-reasoning #2461
base: main
Conversation
mask_key = self._get_mask_key(attention_mask)

if mask_key in self._mask_cache:
    total_seq_length_int32, seqlens_k_int32 = self._mask_cache[mask_key]

Check notice (Code scanning / CodeQL): Unused local variable
Codecov Report ❌ Patch coverage is
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2461      +/-   ##
==========================================
- Coverage   69.81%   69.03%   -0.78%
==========================================
  Files         209      210       +1
  Lines       25313    25790     +477
  Branches     2525     2603      +78
==========================================
+ Hits        17673    17805     +132
- Misses       6762     7110     +348
+ Partials      878      875       -3

☔ View full report in Codecov by Sentry.
lintrunner found more than 20 potential problems in the proposed changes. Check the Files changed tab for more details.
Pull Request Overview
This PR adds causal mask pattern fusion support specifically for LongRoPe models such as Phi-4-mini-reasoning. The implementation extends the existing GQA (Group Query Attention) fusion rules to handle the complex attention mask patterns used by LongRoPe models, optimizing the mask computation process while maintaining compatibility with ModelBuilder optimizations.
Key changes:
- Addition of a new LongRoPeGQACausalMask class that implements specialized mask pattern matching and fusion
- Extension of the GQA rewrite rules to include LongRoPe-specific optimizations
- Implementation of a mask caching mechanism to avoid recomputation
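A rough skeleton of the structure these key changes describe, assembled only from names visible in this PR (signatures are guesses and method bodies are elided; this is not the author's actual implementation, and it assumes the `pattern` module import used by the existing GQA fusion rules):

```python
from onnxscript.rewriter import pattern


class LongRoPeGQACausalMask(pattern.RewriteRuleClassBase):
    """Fuses the LongRoPe-style causal-mask subgraph into GQA-compatible inputs."""

    def __init__(self):
        super().__init__("LongRoPeGQACausalMask", remove_nodes=False)
        # Cache of already-computed mask values; how this state is keyed and
        # cleared is discussed in the review comments below.
        self._mask_cache = {}

    def pattern(self, op, input_ids, attention_mask, total_seq_length):
        # Match the LongRoPe attention-mask construction subgraph (elided).
        ...

    def rewrite(self, op, input_ids, attention_mask, total_seq_length):
        # Emit the fused mask inputs expected by GQA (elided).
        ...
```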
_basic_gqa_rule = GroupQueryAttention.rule()
_longrope_gqa_causal_mask_rule = LongRoPeGQACausalMask.rule()

gqa_rules = pattern.RewriteRuleSet([_basic_gqa_rule])
The gqa_rules variable is being reassigned, which overwrites the previous assignment on line 514. This means the first assignment gqa_rules = pattern.RewriteRuleSet([_basic_gqa_rule]) is completely ignored.

gqa_rules = pattern.RewriteRuleSet([_basic_gqa_rule])
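One way to resolve the overwrite, sketched from the names in the diff (the author may intend a different wiring): register both rules in a single rule set so the first assignment is not discarded.

```python
gqa_rules = pattern.RewriteRuleSet([_basic_gqa_rule, _longrope_gqa_causal_mask_rule])
```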
# Propagation to GQA
mask_sliced = op.Slice(mask_A_B_C_scaled, [0], pattern.ANY_VALUE, [3], [1], _outputs=["mask_sliced"])

#mask_where = op.Where(mask_sliced, pattern.ANY_VALUE, pattern.ANY_VALUE, _outputs=["mask_where"])
This commented-out code should be removed if it's not needed, or properly implemented if it serves a purpose. Dead code reduces maintainability.
mask_expanded_C = op.Expand(reshaped_range_C, mask_shape_C_abs, _outputs=["mask_expanded_C"])

# EXPAND A/B TO AND
mask_expanded_A_sub = op.Sub(mask_expanded_A, 262144, _outputs=["mask_expanded_A_sub"])
The magic number 262144 should be defined as a named constant to improve code readability and maintainability. Consider defining it as a class constant with a descriptive name.
Suggested change:
- mask_expanded_A_sub = op.Sub(mask_expanded_A, 262144, _outputs=["mask_expanded_A_sub"])
+ mask_expanded_A_sub = op.Sub(mask_expanded_A, MASK_OFFSET, _outputs=["mask_expanded_A_sub"])
Better to make it a pattern-variable, I think ... if I understand right, this is actually a magic sequence-length constant? Perhaps model-specific?
On second thoughts, I am guessing this is the window_size, which should become an attribute-parameter to the GQA op.
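A minimal sketch of the named-constant option inside the pattern definition, assuming (per the discussion above) that 262144 is a model-specific window/offset value; the constant name is hypothetical, and promoting it to a pattern variable or a GQA attribute would be the more general fix:

```python
# Hypothetical name; 262144 is the value hard-coded in the current diff and may
# in fact be the model's sliding-window size.
LONGROPE_MASK_OFFSET = 262144

mask_expanded_A_sub = op.Sub(
    mask_expanded_A, LONGROPE_MASK_OFFSET, _outputs=["mask_expanded_A_sub"]
)
```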
Generate a unique key for the mask based on input_ids and past_kv_cache.
This is used to cache the mask to avoid recomputation.
"""
return (id(attention_mask))
Using id() for cache keys is fragile because object ids can be reused after garbage collection. This could lead to incorrect cache hits with different attention_mask objects that happen to have the same id.
Suggested change:
Generate a unique key for the mask based on the content of attention_mask.
This is used to cache the mask to avoid recomputation.
"""
if isinstance(attention_mask, np.ndarray):
    return hash(attention_mask.tobytes())
elif isinstance(attention_mask, (list, tuple)):
    return hash(tuple(attention_mask))
else:
    raise TypeError("Unsupported type for attention_mask: {}".format(type(attention_mask)))
If a cache is used, it should be cleaned up like in this example so that it is not carried over from one graph/model to another
And I am not sure if we need to handle np arrays? If the key is either one or two ir.Values, that should be fine ... ir.Values can be used as keys in dictionaries directly, and that should avoid the garbage-collection problem.
I agree _get_mask_key seems unnecessary. We can use the Value objects directly as keys.
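A minimal sketch of that suggestion (assumed method shape, not the author's code): drop _get_mask_key and use the attention_mask ir.Value itself as the dictionary key, which is hashable and avoids the id()-reuse issue.

```python
def rewrite(self, op, attention_mask, **_):
    # The ir.Value for attention_mask serves directly as the cache key.
    if attention_mask not in self._mask_cache:
        self._mask_cache[attention_mask] = self.compute_mask(op, attention_mask)
    total_seq_length_int32, seqlens_k_int32 = self._mask_cache[attention_mask]
    ...
```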
Congrats on your first PR to onnxscript! Could you resolve the lint errors by running lintrunner -a and fix any issues?
""" | ||
return (id(attention_mask)) | ||
|
||
def compute_mask(self, op, attention_mask : _onnx_types.INT64['batch', 'seq_len']): |
The rewriter doesn't use onnxscript types (yet). Could you instead use a comment to document the shape of the attention_mask?
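A possible way to follow that suggestion (a sketch, not the author's code):

```python
def compute_mask(self, op, attention_mask):
    # attention_mask: INT64 tensor of shape [batch, seq_len]
    ...
```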
@@ -354,9 +355,163 @@ def rewrite(
    _outputs=3,
)

class LongRoPeGQACausalMask(pattern.RewriteRuleClassBase):
Could you use the docstring to document the pattern and its replacement? For the branches A, B, and C, I would consider giving them descriptive names.
The following is my understanding; if this is correct, maybe they can be renamed appropriately. I believe that A constructs the kv_range, B constructs the query_range, and C constructs the batch_range. Each constructs the corresponding range as a 4D tensor with 1s in the other positions (for constructing a final attention-mask of shape [Batch, NumHeads, QueryRange, KVRange] via broadcast).
I am a bit puzzled that query_range and kv_range look to be the same here, it might be an artifact of this model-usage, I guess.
I wasn't sure what the branches referred to but I'll make changes following what Rama is suggesting.
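A small NumPy sketch of the broadcast structure described in the comment above (my reading of the review, not the PR's code): each branch contributes one axis of a final [batch, num_heads, query_range, kv_range] boolean mask; the specific ranges used here are only illustrative.

```python
import numpy as np

batch, q_len, kv_len = 2, 4, 6

kv_range = np.arange(kv_len).reshape(1, 1, 1, kv_len)                    # "A": kv positions
query_range = np.arange(kv_len - q_len, kv_len).reshape(1, 1, q_len, 1)  # "B": query positions
batch_range = np.ones((batch, 1, 1, 1), dtype=bool)                      # "C": batch broadcast

# Causal condition: a query may attend to kv positions at or before its own.
causal_mask = (kv_range <= query_range) & batch_range
print(causal_mask.shape)  # (2, 1, 4, 6)
```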
total_seq_length,
):
seq_len = op.Shape(input_ids, end=2, start=1, _outputs=["seq_len"])
seq_len_0D = op.Squeeze(seq_len, _outputs=["seq_len_0D"])
Suggested change:
- seq_len_0D = op.Squeeze(seq_len, _outputs=["seq_len_0D"])
+ seq_len_0d = op.Squeeze(seq_len, _outputs=["seq_len_0d"])
Prefer snake case for variable names when possible.
mask_A_B_combined = op.And(mask_A_B_greater_bitwise, mask_A_B_less, _outputs=["mask_A_B_combined"])
mask_A_B_combined_bitwise = op.And(True, mask_A_B_combined, _outputs=["mask_A_B_combined_bitwise"])

# EXPAND B/C TO AND
I would document the branches in plain English for readers
class LongRoPeGQACausalMask(pattern.RewriteRuleClassBase):
    def __init__(self):
        super().__init__("LongRoPeGQACausalMask", remove_nodes=False)
        self._mask_cache = {}
The copilot review is reasonable: the rewrite rule class should be stateless. Is there a different way to do this other than keeping a self._mask_cache?
I think use of state for this purpose is okay? It has been used before for a similar purpose: which is to introduce values that are reused across multiple rewrites. (Now that we have CSE, there is an alternative path, which is to create duplicate copies and then eliminate them via CSE ... but I am not sure it is worth the bother.)
BTW: my GQA fusion doesn't use state, and produces multiple copies (as described above).
My concern is that the state will transfer from one model to another if we are not careful, which is probably not a good idea. Maybe we can have a class-managed state dict that is cleared by the class?
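A sketch of the class-managed state idea (the hook name is hypothetical; the rewriter may expose a different extension point for this):

```python
def reset_state(self):
    """Clear per-model caches so values from one graph never leak into another."""
    # Intended to be called once before each new graph/model is rewritten.
    self._mask_cache.clear()
```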
Hi @tadani3, sorry about the concurrent changes I had merged into GQA fusion recently, which might impact some of your changes ... but I am a bit confused by the diffs shown, which don't seem to reflect the changes I had made. Briefly, the earlier version did the fusion in two steps: the first rule ignored the attention-mask and focused on the rest of the computation, and the second rule explicitly handled the attention-mask. The more recent version merged the two into one, for various reasons. I think it shouldn't impact your changes much, except that you will have to make the changes in rule 1 instead of rule 2. But, as I said, I am a bit confused why I am not seeing those in the diffs.
        super().__init__("LongRoPeGQACausalMask", remove_nodes=False)
        self._mask_cache = {}

    def _get_mask_key(self, attention_mask):
In general, avoid creating class methods that do not require state from self; instead, make them module-level private functions for testability and clarity.
@@ -7,6 +7,7 @@
import numpy as np
import onnx_ir as ir

import onnxscript.onnx_types as _onnx_types
Suggested change:
- import onnxscript.onnx_types as _onnx_types
_onnx_types is incompatible with the rewriter (yet).
@microsoft-github-policy-service agree company="Microsoft"
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import onnx
Check notice (Code scanning / CodeQL): Unused import
# --------------------------------------------------------------------------
import onnx
from onnxscript import ir
import onnx.helper
Check notice (Code scanning / CodeQL): Unused import
cache_length = self.rotemb_attrs["cache_length"]
position_ids = torch.arange(cache_length, dtype=torch.int64).unsqueeze(0)  # Shape: (1, cache_length)

inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)  # (1, dim//2, 1)
Check failure (Code scanning / CodeQL): Potentially uninitialized local variable
with torch.autocast(device_type=device_type, enabled=False):
    freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)  # (1, cache_length, dim//2)
    emb = torch.cat((freqs, freqs), dim=-1)  # (1, cache_length, dim)
    cos_cache = emb.cos() * attention_factor  # (1, cache_length, dim)
Check failure (Code scanning / CodeQL): Potentially uninitialized local variable
attention_factor = self.rotemb_attrs["multi_cache"]["short_mscale"]

inv_freq_shape = torch.arange(0, dim, 2, dtype=torch.int64, device="cpu").float() / dim
inv_freq = 1.0 / (ext_factors * base**inv_freq_shape)
Check failure (Code scanning / CodeQL): Potentially uninitialized local variable
if "rescale_inv_freq" in self.rotemb_attrs: | ||
inv_freq = self.make_inv_freq_rescaled(inv_freq) | ||
|
||
return inv_freq, attention_factor |
Check failure (Code scanning / CodeQL): Potentially uninitialized local variable
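A self-contained sketch of the cos-cache computation shown in these fragments, with toy values (variable names follow the diff, but this is not the PR's code); binding ext_factors and attention_factor in an exhaustive if/else is one way to address the "potentially uninitialized local variable" findings:

```python
import torch

dim, base, cache_length = 64, 10000.0, 8
use_long_factor = False                         # assumption: which rescale branch applies
long_factor = short_factor = [1.0] * (dim // 2)  # toy rescale factors
long_mscale = short_mscale = 1.0                 # toy attention scales

# Every local used below is assigned on both paths, so none can be uninitialized.
if use_long_factor:
    ext_factors = torch.tensor(long_factor, dtype=torch.float32)
    attention_factor = long_mscale
else:
    ext_factors = torch.tensor(short_factor, dtype=torch.float32)
    attention_factor = short_mscale

inv_freq_shape = torch.arange(0, dim, 2, dtype=torch.int64).float() / dim
inv_freq = 1.0 / (ext_factors * base**inv_freq_shape)                      # (dim//2,)

position_ids = torch.arange(cache_length, dtype=torch.int64).unsqueeze(0)  # (1, cache_length)
inv_freq_expanded = inv_freq[None, :, None].expand(position_ids.shape[0], -1, 1)  # (1, dim//2, 1)
position_ids_expanded = position_ids[:, None, :].float()                   # (1, 1, cache_length)

freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)        # (1, cache_length, dim//2)
emb = torch.cat((freqs, freqs), dim=-1)                                     # (1, cache_length, dim)
cos_cache = emb.cos() * attention_factor
sin_cache = emb.sin() * attention_factor
print(cos_cache.shape, sin_cache.shape)  # torch.Size([1, 8, 64]) torch.Size([1, 8, 64])
```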
I added a class called Phi4MiniReasoningPostProcessor which uses the ONNX IR to fulfill two tasks:

Reasoning and Motivation
…xscript into longrope_causal_mask
If you can move this file to a separate PR, we can merge the fusion rules. Thanks!
Modification of the GQA causal mask fusion rule to handle attention mask fusion for LongRoPe models such as Phi-4-mini-reasoning. The causal mask modification produces a result that matches the optimizations made in ModelBuilder.