Added Causal Mask Pattern Fusion for LongRoPe Models and Cache Insertion for Phi4-mini-reasoning #2461
base: main
@@ -7,6 +7,7 @@
import numpy as np
import onnx_ir as ir

import onnxscript.onnx_types as _onnx_types
import onnxscript.rewriter._fusion_utils as _fusion_utils
from onnxscript.rewriter import _basics, _ir_utils, pattern
@@ -354,9 +355,163 @@
            _outputs=3,
        )


class LongRoPeGQACausalMask(pattern.RewriteRuleClassBase):

Review comment: Could you use the docstring to document the pattern and its replacement? For the branches A, B, and C, I would consider giving them descriptive names.

Reply: The following is my understanding; if it is correct, maybe they can be renamed appropriately: I believe that A constructs the kv_range, B constructs the query_range, and C constructs the batch_range. Each constructs the corresponding range as a 4D tensor with 1s in the other positions (for constructing a final attention mask of shape [Batch, NumHeads, QueryRange, KVRange] via broadcast). I am a bit puzzled that query_range and kv_range look to be the same here; it might be an artifact of this model usage, I guess.

Reply: I wasn't sure what the branches referred to, but I'll make changes following what Rama is suggesting.
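
For intuition, here is a small numpy analogue of the broadcast construction described above; all sizes and names are illustrative, not from the PR:

    import numpy as np

    # Hypothetical sizes for illustration only.
    batch, num_heads, q_len, kv_len = 2, 4, 3, 5

    # Each range is reshaped to 4D with 1s in the other positions so that
    # comparisons broadcast to [Batch, NumHeads, QueryRange, KVRange].
    kv_range = np.arange(kv_len).reshape(1, 1, 1, -1)                     # cf. Reshape(..., [1, 1, 1, -1])
    query_range = np.arange(kv_len - q_len, kv_len).reshape(1, 1, -1, 1)  # cf. Reshape(..., [1, 1, -1, 1])
    batch_range = np.arange(batch).reshape(-1, 1, 1, 1)                   # cf. Reshape(..., [-1, 1, 1, 1])

    # Causal constraint: a query position may attend to kv positions at or before it.
    causal = kv_range <= query_range                                      # broadcasts to (1, 1, 3, 5)
    print(np.broadcast_shapes(batch_range.shape, causal.shape))           # (2, 1, 3, 5)
    mask = np.broadcast_to(causal, (batch, num_heads, q_len, kv_len))
    print(mask.shape)                                                     # (2, 4, 3, 5)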

    def __init__(self):
        super().__init__("LongRoPeGQACausalMask", remove_nodes=False)
        self._mask_cache = {}

(justinchuby marked this conversation as resolved.)

Review comment: The Copilot review is reasonable: the rewrite rule class should be stateless. Is there a different way to do this other than keeping a cache on the instance?

Reply: I think use of state for this purpose is okay? It has been used before for a similar purpose, which is to introduce values that are reused across multiple rewrites. (Now that we have CSE, there is an alternative path, which is to create duplicate copies and then eliminate them via CSE ... but I am not sure it is worth the bother.)

Reply: BTW: my GQA fusion doesn't use state, and produces multiple copies (as described above).

Reply: My concern is that the state will transfer from one model to another if we are not careful, which is probably not a good idea. Maybe we can have a class-managed state dict that will be cleared by the class?
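
A minimal sketch of the class-managed state idea, assuming some hook exists to run between models (the clear_state name is hypothetical, not part of RewriteRuleClassBase):

    class LongRoPeGQACausalMask(pattern.RewriteRuleClassBase):
        def __init__(self):
            super().__init__("LongRoPeGQACausalMask", remove_nodes=False)
            self._mask_cache = {}

        def clear_state(self):
            # Hypothetical cleanup hook: call after rewriting one model so that
            # cached ir.Values never leak into the next model's graph.
            self._mask_cache.clear()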

    def _get_mask_key(self, attention_mask):

Review comment: In general, avoid creating class methods that do not require state from self; instead, make them module-level private functions for testability and clarity.

        """
        Generate a unique key for the mask based on input_ids and past_kv_cache.
        This is used to cache the mask to avoid recomputation.
        """
        return (id(attention_mask))

Copilot review: Using id() for cache keys is fragile because object ids can be reused after garbage collection. This could lead to incorrect cache hits with different attention_mask objects that happen to have the same id.

Reply: If a cache is used, it should be cleaned up, like in this example, so that it is not carried over from one graph/model to another.

Reply: And I am not sure if we need to handle np arrays? If the key is either one or two ir.Values, that should be fine ... ir.Values can be used as keys in dictionaries directly, and that should avoid the garbage-collection problem.

Reply: I agree, _get_mask_key seems unnecessary. We can use the Value objects directly as keys.
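
A minimal sketch of that suggestion, assuming attention_mask arrives as an ir.Value (which hashes by identity); _build_mask is a hypothetical helper holding the current else-branch logic:

    def compute_mask(self, op, attention_mask):
        # Use the ir.Value itself as the cache key; no id()-based helper needed.
        if attention_mask not in self._mask_cache:
            self._mask_cache[attention_mask] = self._build_mask(op, attention_mask)
        return self._mask_cache[attention_mask]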

    def compute_mask(self, op, attention_mask : _onnx_types.INT64['batch', 'seq_len']):

[lintrunner] RUFF/UP037 Warning: Remove quotes from type annotation (https://docs.astral.sh/ruff/rules/quoted-annotation). RUFF/F821 Error: Undefined names `batch` and `seq_len` (https://docs.astral.sh/ruff/rules/undefined-name).

Review comment: The rewriter doesn't use onnxscript types (yet). Could you instead use a comment to document the shape of the attention_mask?

        mask_key = self._get_mask_key(attention_mask)

        if mask_key in self._mask_cache:
            total_seq_length_int32, seqlens_k_int32 = self._mask_cache[mask_key]

[CodeQL] Note: Unused local variable `total_seq_length_int32`. Note: Unused local variable `seqlens_k_int32`.

        else:
            # Construct total_seq_length_int32 and seqlens_k
            attention_shape = op.Shape(attention_mask, _outputs=["seq_len"])
            total_seq_length = op.Gather(attention_shape, op.Constant(value=ir.tensor(1, ir.DataType.INT64)), axis=0, _outputs=["total_seq_length"])
            reduced_attention = op.ReduceSum(attention_mask, op.Constant(value=ir.tensor([1], ir.DataType.INT64)), _outputs=["reduced_attention"])
            sub_reduced_attention = op.Sub(reduced_attention, op.Constant(value=ir.tensor([1], ir.DataType.INT64)), _outputs=["sub_reduced_attention"])
            total_seq_length_int32 = op.Cast(total_seq_length, to=ir.DataType.INT32, _outputs=["total_seq_length_int32"])
            seqlens_k_int32 = op.Cast(sub_reduced_attention, to=ir.DataType.INT32, _outputs=["seqlens_k_int32"])
            self._mask_cache[mask_key] = (total_seq_length_int32, seqlens_k_int32)

        return self._mask_cache[mask_key]
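
For intuition, a numpy sketch of what the emitted subgraph computes at runtime (the values are illustrative; in the rewriter these are graph nodes, not eager computations):

    import numpy as np

    # Padded attention mask of shape [batch, total_seq_len]: 1 = token, 0 = padding.
    attention_mask = np.array([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]], dtype=np.int64)

    # total_seq_length: dimension 1 of the mask (Shape followed by Gather at index 1).
    total_seq_length = np.int32(attention_mask.shape[1])           # 5

    # seqlens_k: per-batch token count minus one (ReduceSum, Sub, Cast to int32).
    seqlens_k = (attention_mask.sum(axis=1) - 1).astype(np.int32)  # [2, 4]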

    def pattern(
        self,
        op,
        mask,
        input_ids,
        past_kv_cache_1,
        past_kv_cache_2,
        attention_mask,
        past_seq_length,
        total_seq_length,
    ):
        seq_len = op.Shape(input_ids, end=2, start=1, _outputs=["seq_len"])
        seq_len_0D = op.Squeeze(seq_len, _outputs=["seq_len_0D"])

Review comment: Prefer snake_case for variable names when possible.

        past_seq_len = op.Shape(past_kv_cache_1, end=3, start=2, _outputs=["past_seq_len"])
        past_seq_len_0D = op.Squeeze(past_seq_len, _outputs=["past_seq_len_0D"])
        total_seq_len_0D = op.Add(past_seq_len_0D, seq_len_0D, _outputs=["total_seq_len_0D"])

        # All of the Add node's outputs
        current_range_A = op.Range(past_seq_len_0D, total_seq_len_0D, 1, _outputs=["current_range_A"])
        total_seq_len_A = op.Reshape(total_seq_len_0D, [-1], allowzero=0, _outputs=["total_seq_len_A"])
        current_range_B = op.Range(0, total_seq_len_0D, 1, _outputs=["current_range_B"])
        total_seq_len_B = op.Reshape(total_seq_len_0D, [-1], allowzero=0, _outputs=["total_seq_len_B"])
        total_seq_len_C = op.Reshape(total_seq_len_0D, [-1], allowzero=0, _outputs=["total_seq_len_C"])

        total_seq_len_final = op.Reshape(total_seq_len_0D, pattern.ANY_VALUE, allowzero=0, _outputs=["total_seq_len_final"])

[lintrunner] RUFF/F841 Warning: Local variable `total_seq_len_final` is assigned to but never used (https://docs.astral.sh/ruff/rules/unused-variable).

        # EXPAND BRANCH A
        batch_size = op.Shape(past_kv_cache_2, end=1, start=0, _outputs=["batch_size"])
        mask_shape_A = op.Concat(batch_size, [1], seq_len, total_seq_len_A, axis=0, _outputs=["mask_shape_A"])
        mask_shape_A_abs = op.Abs(mask_shape_A, _outputs=["mask_shape_A_abs"])
        reshaped_range_A = op.Reshape(current_range_A, [1, 1, -1, 1], allowzero=1, _outputs=["reshaped_range_A"])
        mask_expanded_A = op.Expand(reshaped_range_A, mask_shape_A_abs, _outputs=["mask_expanded_A"])

        # EXPAND BRANCH B
        mask_shape_B = op.Concat(batch_size, [1], seq_len, total_seq_len_B, axis=0, _outputs=["mask_shape_B"])
        mask_shape_B_abs = op.Abs(mask_shape_B, _outputs=["mask_shape_B_abs"])
        reshaped_range_B = op.Reshape(current_range_B, [1, 1, 1, -1], allowzero=1, _outputs=["reshaped_range_B"])
        mask_expanded_B = op.Expand(reshaped_range_B, mask_shape_B_abs, _outputs=["mask_expanded_B"])

        # EXPAND BRANCH C
        mask_shape_C = op.Concat(batch_size, [1], seq_len, total_seq_len_C, axis=0, _outputs=["mask_shape_C"])
        mask_shape_C_abs = op.Abs(mask_shape_C, _outputs=["mask_shape_C_abs"])
        batch_size_squeezed = op.Squeeze(batch_size, _outputs=["batch_size_squeezed"])
        batch_range = op.Range(0, batch_size_squeezed, 1, _outputs=["batch_range"])
        reshaped_range_C = op.Reshape(batch_range, [-1, 1, 1, 1], allowzero=1, _outputs=["reshaped_range_C"])
        mask_expanded_C = op.Expand(reshaped_range_C, mask_shape_C_abs, _outputs=["mask_expanded_C"])

        # EXPAND A/B TO AND
        mask_expanded_A_sub = op.Sub(mask_expanded_A, 262144, _outputs=["mask_expanded_A_sub"])

Copilot review: The magic number 262144 should be defined as a named constant to improve code readability and maintainability. Consider defining it as a class constant with a descriptive name.

Reply: Better to make it a pattern variable, I think ... if I understand right, this is actually a magic sequence-length constant? Perhaps model-specific?

Reply: On second thought, I am guessing this is the window_size, which should become an attribute parameter to the GQA op.
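
A rough sketch of the pattern-variable idea (the signatures follow the rewriter conventions used elsewhere in this file, but this is untested; treating the constant as a window size is only the reviewer's guess):

    def pattern(self, op, mask_expanded_A, window_size):
        # window_size is a pattern variable, so 262144 is matched, not hard-coded.
        return op.Sub(mask_expanded_A, window_size, _outputs=["mask_expanded_A_sub"])

    def check(self, context, window_size, **_):
        # Recover the matched constant so it could later become an attribute
        # (e.g. a hypothetical local_window_size) on the fused GQA op.
        return _ir_utils.get_singleton_value(window_size) == 262144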

        mask_A_B_greater = op.Greater(mask_expanded_B, mask_expanded_A_sub, _outputs=["mask_A_B_greater"])
        mask_A_B_greater_bitwise = op.And(True, mask_A_B_greater, _outputs=["mask_A_B_greater_bitwise"])
        mask_A_B_less = op.LessOrEqual(mask_expanded_B, mask_expanded_A, _outputs=["mask_A_B_less"])
        mask_A_B_combined = op.And(mask_A_B_greater_bitwise, mask_A_B_less, _outputs=["mask_A_B_combined"])
        mask_A_B_combined_bitwise = op.And(True, mask_A_B_combined, _outputs=["mask_A_B_combined_bitwise"])

        # EXPAND B/C TO AND

Review comment: I would document the branches in plain English for readers.

        unsqueezed_mask_expanded_B = op.Unsqueeze(mask_expanded_B, [-1], _outputs=["unsqueezed_mask_expanded_B"])
        unsqueezed_mask_expanded_C = op.Unsqueeze(mask_expanded_C, [-1], _outputs=["unsqueezed_mask_expanded_C"])
        mask_B_C_concat = op.Concat(unsqueezed_mask_expanded_C, unsqueezed_mask_expanded_B, axis=-1, _outputs=["mask_B_C_concat"])
        attention_mask_bool = op.Cast(attention_mask, to=ir.DataType.BOOL, _outputs=["attention_mask_bool"])
        mask_gatherND = op.GatherND(attention_mask_bool, mask_B_C_concat, batch_dims=0, _outputs=["mask_gatherND"])
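
For readers, the GatherND step effectively looks up attention_mask[batch, kv_position] for every broadcast cell; a small numpy analogue with simplified shapes:

    import numpy as np

    attention_mask = np.array([[1, 1, 0],
                               [1, 1, 1]], dtype=bool)            # [batch=2, total_seq=3]
    batch_idx = np.broadcast_to(np.arange(2).reshape(2, 1, 1, 1), (2, 1, 1, 3))
    kv_idx = np.broadcast_to(np.arange(3).reshape(1, 1, 1, 3), (2, 1, 1, 3))
    indices = np.stack([batch_idx, kv_idx], axis=-1)  # index pairs, like mask_B_C_concat

    # GatherND(attention_mask, indices) selects attention_mask[b, kv] per cell:
    padding_mask = attention_mask[indices[..., 0], indices[..., 1]]
    print(padding_mask[0, 0, 0])  # [ True  True False ] -> padding positions masked out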

        mask_A_B_C_combined = op.And(mask_A_B_combined_bitwise, mask_gatherND, _outputs=["mask_A_B_C_combined"])
        mask_A_B_C_negated = op.Not(mask_A_B_C_combined, _outputs=["mask_A_B_C_negated"])
        mask_A_B_C_fp32 = op.Cast(mask_A_B_C_negated, to=ir.DataType.FLOAT, _outputs=["mask_A_B_C_fp32"])
        mask_A_B_C_scaled = op.Mul(mask_A_B_C_fp32, pattern.ANY_VALUE)
        # Propagation to GQA
        mask_sliced = op.Slice(mask_A_B_C_scaled, [0], pattern.ANY_VALUE, [3], [1], _outputs=["mask_sliced"])

        #mask_where = op.Where(mask_sliced, pattern.ANY_VALUE, pattern.ANY_VALUE, _outputs=["mask_where"])

Copilot review: This commented-out code should be removed if it's not needed, or properly implemented if it serves a purpose. Dead code reduces maintainability.

        return op.GQA(
            mask_sliced,
            pattern.ANY_VALUE,  # position_ids_k
            pattern.ANY_VALUE,  # position_ids_q
            pattern.ANY_VALUE,  # query
            pattern.ANY_VALUE,  # key
            pattern.ANY_VALUE,  # value
            pattern.ANY_VALUE,  # past_key
            pattern.ANY_VALUE,  # past_value
            pattern.ANY_VALUE,  # seqlens_k (optional)
            pattern.ANY_VALUE,  # total_seq_length (optional)
            pattern.ANY_VALUE,  # cos
            pattern.ANY_VALUE,  # sin
            _allow_other_inputs=True,
            _domain="ai.onnxruntime._fusion",
            _outputs=["attn_output", "key_seq", "value_seq"],
        )

    def rewrite(
        self,
        op,
        attention_mask,
        attn_output,
        **_,
    ):
        # Compute total_seq_length_int32 and seqlens_k_int32
        total_seq_length_int32, seqlens_k_int32 = self.compute_mask(op, attention_mask)

        gqa_node = attn_output.producer()
        assert len(gqa_node.inputs) == 12, (
            f"Expected 12 inputs for GQA node, got {len(gqa_node.inputs)}"
        )
        query, key, value, past_key, past_value = gqa_node.inputs[3:8]
        cos, sin = gqa_node.inputs[10:12]
        updated_inputs = [
            query,
            key,
            value,
            past_key,
            past_value,
            seqlens_k_int32,
            total_seq_length_int32,
            cos,
            sin,
        ]
        attributes = gqa_node.attributes
        return op.GroupQueryAttention(
            *updated_inputs, **attributes, _domain="com.microsoft", _outputs=3
        )

_basic_gqa_rule = GroupQueryAttention.rule()
_longrope_gqa_causal_mask_rule = LongRoPeGQACausalMask.rule()

gqa_rules = pattern.RewriteRuleSet([_basic_gqa_rule])

[CodeQL] Warning: Variable defined multiple times. This assignment to `gqa_rules` is unnecessary, as it is redefined below.

Copilot review: The gqa_rules variable is being reassigned, which overwrites the previous assignment on line 514. This means the first assignment is never used.

gqa_rules = pattern.RewriteRuleSet([_basic_gqa_rule, _longrope_gqa_causal_mask_rule])

fuse_gqa = _fusion_utils.apply_fusion_rules(gqa_rules)
fuse_gqa = _fusion_utils.apply_fusion_rules(gqa_rules)

[lintrunner] RUFF/W292 Warning: No newline at end of file (https://docs.astral.sh/ruff/rules/missing-newline-at-end-of-file).

Review comment: _onnx_types is incompatible with the rewriter (yet).