Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 156 additions & 1 deletion onnxscript/rewriter/ort_fusions/gqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import numpy as np
import onnx_ir as ir

import onnxscript.onnx_types as _onnx_types
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
import onnxscript.onnx_types as _onnx_types

_onnx_types is incompatible with the rewriter (yet)

import onnxscript.rewriter._fusion_utils as _fusion_utils
from onnxscript.rewriter import _basics, _ir_utils, pattern

Expand Down Expand Up @@ -354,9 +355,163 @@
_outputs=3,
)

class LongRoPeGQACausalMask(pattern.RewriteRuleClassBase):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you use the docstring to document the pattern and its replacement? For the branches A, B, and C, I would consider giving them descriptive names.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The following is my understanding: if this is correct, maybe they can be renamed appropriately: I believe that A constructs the kv_range, B constructs the query_range, and C constructs the batch_range. Each constructs the corresponding range as a 4D tensor with 1s in other position (for constructing a final attention-mask of shape [Batch, NumHeads, QueryRange, KVRange] via broadcast).

I am a bit puzzled that query_range and kv_range look to be the same here, it might be an artifact of this model-usage, I guess.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wasn't sure what the branches referred to but I'll make changes following what Rama is suggesting.

def __init__(self):
super().__init__("LongRoPeGQACausalMask", remove_nodes=False)
self._mask_cache = {}
Copy link
Collaborator

@justinchuby justinchuby Jul 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The copilot review is reasonable: the rewrite rule class should be stateless. Is there a different way to do this other than keeping a self._mask_cache?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think use of state for this purpose is okay? It has been used before for a similar purpose: which is to introduce values that are reused across multiple rewrites. (Now that we have CSE, there is an alternative path, which is to create duplicate copies and then eliminate them via CSE ... but I am not sure it is worth the bother.)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BTW: my GQA fusion doesn't use state, and produces multiple copies (as described above).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My concern is that the states will transfer from model to another if not careful, which is probably not a good idea. Maybe we can have a class managed state dict that will be cleared by the class?


Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

def _get_mask_key(self, attention_mask):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In general, avoid creating class methods that do not require states from self, and instead make them module-level private functions for testability and clarity.

"""
Generate a unique key for the mask based on input_ids and past_kv_cache.
This is used to cache the mask to avoid recomputation.
"""
return (id(attention_mask))
Copy link
Preview

Copilot AI Jul 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using id() for cache keys is fragile because object ids can be reused after garbage collection. This could lead to incorrect cache hits with different attention_mask objects that happen to have the same id.

Suggested change
Generate a unique key for the mask based on input_ids and past_kv_cache.
This is used to cache the mask to avoid recomputation.
"""
return (id(attention_mask))
Generate a unique key for the mask based on the content of attention_mask.
This is used to cache the mask to avoid recomputation.
"""
if isinstance(attention_mask, np.ndarray):
return hash(attention_mask.tobytes())
elif isinstance(attention_mask, (list, tuple)):
return hash(tuple(attention_mask))
else:
raise TypeError("Unsupported type for attention_mask: {}".format(type(attention_mask)))

Copilot uses AI. Check for mistakes.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If a cache is used, it should be cleaned up like in this example so that it is not carried over from one graph/model to another

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And I am not sure if we need to handle np arrays? If the key is either one or two ir.Values, that should be fine ... ir.Values can be used as keys in dictionaries directly, and that should avoid the garbage-collection problem.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree _get_mask_key seems unecessary. We can use the Value objects directly as keys.


Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

def compute_mask(self, op, attention_mask : _onnx_types.INT64['batch', 'seq_len']):

Check warning

Code scanning / lintrunner

RUFF/UP037 Warning

Remove quotes from type annotation.
See https://docs.astral.sh/ruff/rules/quoted-annotation

Check failure

Code scanning / lintrunner

RUFF/F821 Error

Check warning

Code scanning / lintrunner

RUFF/UP037 Warning

Remove quotes from type annotation.
See https://docs.astral.sh/ruff/rules/quoted-annotation

Check failure

Code scanning / lintrunner

RUFF/F821 Error

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The rewriter doesn't use onnxscript type (yet). Could you instead use a comment to document the shape of the attention_mask?

mask_key = self._get_mask_key(attention_mask)

if mask_key in self._mask_cache:
total_seq_length_int32, seqlens_k_int32 = self._mask_cache[mask_key]

Check notice

Code scanning / CodeQL

Unused local variable Note

Variable total_seq_length_int32 is not used.

Check notice

Code scanning / CodeQL

Unused local variable Note

Variable seqlens_k_int32 is not used.

else:
# Construct total_seq_length_int32 and seqlens_k
attention_shape = op.Shape(attention_mask, _outputs=["seq_len"])
total_seq_length = op.Gather(attention_shape, op.Constant(value=ir.tensor(1, ir.DataType.INT64)), axis=0, _outputs=["total_seq_length"])
reduced_attention = op.ReduceSum(attention_mask, op.Constant(value=ir.tensor([1], ir.DataType.INT64)), _outputs=["reduced_attention"])

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace

Check warning

Code scanning / lintrunner

RUFF/W291 Warning

sub_reduced_attention = op.Sub(reduced_attention, op.Constant(value=ir.tensor([1], ir.DataType.INT64)), _outputs=["sub_reduced_attention"])
total_seq_length_int32 = op.Cast(total_seq_length, to=ir.DataType.INT32, _outputs=["total_seq_length_int32"])
seqlens_k_int32 = op.Cast(sub_reduced_attention, to=ir.DataType.INT32, _outputs=["seqlens_k_int32"])
self._mask_cache[mask_key] = (total_seq_length_int32, seqlens_k_int32)

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

return self._mask_cache[mask_key]

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace

Check warning

Code scanning / lintrunner

RUFF/W293 Warning


def pattern(
self,
op,
mask,
input_ids,
past_kv_cache_1,
past_kv_cache_2,
attention_mask,
past_seq_length,
total_seq_length,
):
seq_len = op.Shape(input_ids, end=2, start=1, _outputs=["seq_len"])
seq_len_0D = op.Squeeze(seq_len, _outputs=["seq_len_0D"])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
seq_len_0D = op.Squeeze(seq_len, _outputs=["seq_len_0D"])
seq_len_0d = op.Squeeze(seq_len, _outputs=["seq_len_0d"])

prefer snake case for variable names when possible

past_seq_len = op.Shape(past_kv_cache_1, end=3, start=2, _outputs=["past_seq_len"])
past_seq_len_0D = op.Squeeze(past_seq_len, _outputs=["past_seq_len_0D"])
total_seq_len_0D = op.Add(past_seq_len_0D, seq_len_0D, _outputs=["total_seq_len_0D"])

# All of the Add node's outputs
current_range_A = op.Range(past_seq_len_0D, total_seq_len_0D, 1, _outputs=["current_range_A"])
total_seq_len_A = op.Reshape(total_seq_len_0D, [-1], allowzero=0, _outputs=["total_seq_len_A"])
current_range_B = op.Range(0, total_seq_len_0D, 1, _outputs=["current_range_B"])
total_seq_len_B = op.Reshape(total_seq_len_0D, [-1], allowzero=0, _outputs=["total_seq_len_B"])
total_seq_len_C = op.Reshape(total_seq_len_0D, [-1], allowzero=0, _outputs=["total_seq_len_C"])

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

total_seq_len_final = op.Reshape(total_seq_len_0D, pattern.ANY_VALUE, allowzero=0, _outputs=["total_seq_len_final"])

Check warning

Code scanning / lintrunner

RUFF/F841 Warning

Local variable total\_seq\_len\_final is assigned to but never used.
See https://docs.astral.sh/ruff/rules/unused-variable

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

# EXPAND BRANCH A
batch_size = op.Shape(past_kv_cache_2, end=1, start=0, _outputs=["batch_size"])
mask_shape_A = op.Concat(batch_size, [1], seq_len, total_seq_len_A, axis=0, _outputs=["mask_shape_A"])
mask_shape_A_abs = op.Abs(mask_shape_A, _outputs=["mask_shape_A_abs"])
reshaped_range_A = op.Reshape(current_range_A, [1, 1, -1, 1], allowzero=1, _outputs=["reshaped_range_A"])
mask_expanded_A = op.Expand(reshaped_range_A, mask_shape_A_abs, _outputs=["mask_expanded_A"])

# EXPAND BRANCH B
mask_shape_B = op.Concat(batch_size, [1], seq_len, total_seq_len_B, axis=0, _outputs=["mask_shape_B"])
mask_shape_B_abs = op.Abs(mask_shape_B, _outputs=["mask_shape_B_abs"])
reshaped_range_B = op.Reshape(current_range_B, [1, 1, 1, -1], allowzero=1, _outputs=["reshaped_range_B"])
mask_expanded_B = op.Expand(reshaped_range_B, mask_shape_B_abs, _outputs=["mask_expanded_B"])

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

# EXPAND BRANCH C
mask_shape_C = op.Concat(batch_size, [1], seq_len, total_seq_len_C, axis=0, _outputs=["mask_shape_C"])
mask_shape_C_abs = op.Abs(mask_shape_C, _outputs=["mask_shape_C_abs"])
batch_size_squeezed = op.Squeeze(batch_size, _outputs=["batch_size_squeezed"])
batch_range = op.Range(0, batch_size_squeezed, 1, _outputs=["batch_range"])
reshaped_range_C = op.Reshape(batch_range, [-1, 1, 1, 1], allowzero=1, _outputs=["reshaped_range_C"])
mask_expanded_C = op.Expand(reshaped_range_C, mask_shape_C_abs, _outputs=["mask_expanded_C"])

# EXPAND A/B TO AND
mask_expanded_A_sub = op.Sub(mask_expanded_A, 262144, _outputs=["mask_expanded_A_sub"])
Copy link
Preview

Copilot AI Jul 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The magic number 262144 should be defined as a named constant to improve code readability and maintainability. Consider defining it as a class constant with a descriptive name.

Suggested change
mask_expanded_A_sub = op.Sub(mask_expanded_A, 262144, _outputs=["mask_expanded_A_sub"])
mask_expanded_A_sub = op.Sub(mask_expanded_A, MASK_OFFSET, _outputs=["mask_expanded_A_sub"])

Copilot uses AI. Check for mistakes.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Better to make it a pattern-variable, I think ... if I understand right, this is actually a magic sequence-length constant? Perhaps model-specific?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On second thoughts, I am guessing this is the window_size, which should become an attribute-parameter to the GQA op.

mask_A_B_greater = op.Greater(mask_expanded_B, mask_expanded_A_sub, _outputs=["mask_A_B_greater"])
mask_A_B_greater_bitwise = op.And(True, mask_A_B_greater, _outputs=["mask_A_B_greater_bitwise"])
mask_A_B_less = op.LessOrEqual(mask_expanded_B, mask_expanded_A, _outputs=["mask_A_B_less"])
mask_A_B_combined = op.And(mask_A_B_greater_bitwise, mask_A_B_less, _outputs=["mask_A_B_combined"])
mask_A_B_combined_bitwise = op.And(True, mask_A_B_combined, _outputs=["mask_A_B_combined_bitwise"])

# EXPAND B/C TO AND
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would document the branches in plain English for readers

unsqueezed_mask_expanded_B = op.Unsqueeze(mask_expanded_B, [-1], _outputs=["unsqueezed_mask_expanded_B"])
unsqueezed_mask_expanded_C = op.Unsqueeze(mask_expanded_C, [-1], _outputs=["unsqueezed_mask_expanded_C"])
mask_B_C_concat = op.Concat(unsqueezed_mask_expanded_C, unsqueezed_mask_expanded_B, axis=-1, _outputs=["mask_B_C_concat"])
attention_mask_bool = op.Cast(attention_mask, to=ir.DataType.BOOL, _outputs=["attention_mask_bool"])
mask_gatherND = op.GatherND(attention_mask_bool, mask_B_C_concat, batch_dims=0, _outputs=["mask_gatherND"])

mask_A_B_C_combined = op.And(mask_A_B_combined_bitwise, mask_gatherND, _outputs=["mask_A_B_C_combined"])
mask_A_B_C_negated = op.Not(mask_A_B_C_combined, _outputs=["mask_A_B_C_negated"])
mask_A_B_C_fp32 = op.Cast(mask_A_B_C_negated, to=ir.DataType.FLOAT, _outputs=["mask_A_B_C_fp32"])
mask_A_B_C_scaled = op.Mul(mask_A_B_C_fp32, pattern.ANY_VALUE)
# Propagation to GQA
mask_sliced = op.Slice(mask_A_B_C_scaled, [0], pattern.ANY_VALUE, [3], [1], _outputs=["mask_sliced"])

#mask_where = op.Where(mask_sliced, pattern.ANY_VALUE, pattern.ANY_VALUE, _outputs=["mask_where"])
Copy link
Preview

Copilot AI Jul 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This commented-out code should be removed if it's not needed, or properly implemented if it serves a purpose. Dead code reduces maintainability.

Suggested change
#mask_where = op.Where(mask_sliced, pattern.ANY_VALUE, pattern.ANY_VALUE, _outputs=["mask_where"])

Copilot uses AI. Check for mistakes.


return op.GQA(
mask_sliced,
pattern.ANY_VALUE, # position_ids_k
pattern.ANY_VALUE, # position_ids_q

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace

Check warning

Code scanning / lintrunner

RUFF/W291 Warning

pattern.ANY_VALUE, # query
pattern.ANY_VALUE, # key
pattern.ANY_VALUE, # value
pattern.ANY_VALUE, # past_key
pattern.ANY_VALUE, # past_value
pattern.ANY_VALUE, # seqlens_k (optional)
pattern.ANY_VALUE, # total_seq_length (optional)
pattern.ANY_VALUE, # cos
pattern.ANY_VALUE, # sin
_allow_other_inputs=True,
_domain="ai.onnxruntime._fusion",
_outputs=["attn_output", "key_seq", "value_seq"],
)


def rewrite(
self,
op,
attention_mask,
attn_output,
**_,
):
# Compute total_seq_length_int32 and seqlens_k_int32
total_seq_length_int32, seqlens_k_int32 = self.compute_mask(op, attention_mask)

gqa_node = attn_output.producer()
assert len(gqa_node.inputs) == 12, (
f"Expected 12 inputs for GQA node, got {len(gqa_node.inputs)}"
)
query, key, value, past_key, past_value = gqa_node.inputs[3:8]
cos, sin = gqa_node.inputs[10:12]
updated_inputs = [
query,
key,
value,
past_key,
past_value,
seqlens_k_int32,
total_seq_length_int32,
cos,
sin,
]
attributes = gqa_node.attributes
return op.GroupQueryAttention(
*updated_inputs, **attributes, _domain="com.microsoft", _outputs=3
)

_basic_gqa_rule = GroupQueryAttention.rule()
_longrope_gqa_causal_mask_rule = LongRoPeGQACausalMask.rule()

gqa_rules = pattern.RewriteRuleSet([_basic_gqa_rule])

Check warning

Code scanning / CodeQL

Variable defined multiple times Warning

This assignment to 'gqa_rules' is unnecessary as it is
redefined
before this value is used.
Copy link
Preview

Copilot AI Jul 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The gqa_rules variable is being reassigned, which overwrites the previous assignment on line 514. This means the first assignment gqa_rules = pattern.RewriteRuleSet([_basic_gqa_rule]) is completely ignored.

Suggested change
gqa_rules = pattern.RewriteRuleSet([_basic_gqa_rule])

Copilot uses AI. Check for mistakes.

gqa_rules = pattern.RewriteRuleSet([_basic_gqa_rule, _longrope_gqa_causal_mask_rule])

fuse_gqa = _fusion_utils.apply_fusion_rules(gqa_rules)
fuse_gqa = _fusion_utils.apply_fusion_rules(gqa_rules)

Check warning

Code scanning / lintrunner

RUFF/W292 Warning

Loading
Loading