
Conversation

@tadani3 commented Aug 1, 2025

This PR introduces a specialized LongRoPe (Long Range Rotary Position Embedding) GQA (Group Query Attention) causal mask fusion rule specifically designed for Phi-4-mini-reasoning and similar models. The implementation optimizes attention mask computation for models using sliding window attention with LongRoPe position embeddings.

New LongRoPeGQACausalMask Class

  • Specialized Pattern Matching: Implements complex pattern matching for LongRoPe attention mechanisms with sliding window support.
  • Mask Caching: Introduces caching using _get_mask_key() to avoid recomputation of expensive mask operations across layers.
  • Sliding Window Support: Handles configurable sliding window sizes (currently hardcoded to 262144) for long-context attention.
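The caching idea in the bullets above can be sketched as follows. This is a hypothetical illustration in plain Python and numpy: `_get_mask_key` is named in the PR, but its implementation, the cache layout, and the cached values shown here are assumptions, not the PR's actual code.

```python
import numpy as np

class LongRoPeMaskCache:
    """Hypothetical sketch of per-mask caching; not the PR's actual class."""

    def __init__(self):
        self._mask_cache = {}

    def _get_mask_key(self, attention_mask: np.ndarray):
        # A cheap, hashable fingerprint of the mask: its shape plus the
        # per-row count of valid (non-padding) positions.
        return (attention_mask.shape, attention_mask.sum(axis=-1).tobytes())

    def get_or_compute(self, attention_mask: np.ndarray):
        mask_key = self._get_mask_key(attention_mask)
        if mask_key not in self._mask_cache:
            # The expensive mask-derived values are computed once, then
            # reused by every layer that sees the same attention mask.
            seqlens_k_int32 = attention_mask.sum(axis=-1).astype(np.int32) - 1
            total_seq_length_int32 = np.int32(attention_mask.shape[-1])
            self._mask_cache[mask_key] = (total_seq_length_int32, seqlens_k_int32)
        return self._mask_cache[mask_key]
```

With this shape, the first layer pays for the computation and all subsequent layers hit the cache.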

Advanced Mask Computation

  • Multi-Branch Processing: Implements three parallel branches for KV range, query range, and batch processing.
  • Efficient Range Operations: Uses optimized tensor operations for creating position-based masks.
  • Boolean Logic Optimization: Combines sliding window masks with attention mask lookups using efficient boolean operations.
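The boolean combination described above can be illustrated with a small numpy sketch. The window semantics used here (a token attends to itself and the previous `window - 1` positions, subject to causality) are an assumption for illustration; the PR's actual tensor operations live in the fusion rule.

```python
import numpy as np

def sliding_window_causal_mask(q_len: int, kv_len: int, window: int) -> np.ndarray:
    """Boolean mask: True where a query position may attend to a KV position."""
    kv_pos = np.arange(kv_len)[None, :]                 # KV-range branch
    q_pos = np.arange(kv_len - q_len, kv_len)[:, None]  # query-range branch
    causal = kv_pos <= q_pos             # no attending to future positions
    in_window = kv_pos > q_pos - window  # only the last `window` positions
    return causal & in_window            # combined with cheap boolean ops
```

For contexts shorter than the window (the PR hardcodes 262144), `in_window` is all-True and the result reduces to a plain causal mask.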

Note: This PR is meant to replace #2461 by introducing the requested changes.

"""
Pattern for LongRoPe GQA Causal Mask.
This pattern computes the causal mask for Group Query Attention with LongRoPe.
It constructs the mask based on input_ids and past_kv_cache, and handles the

Check notice (Code scanning / CodeQL): Unused local variable
Variable total_seq_length_int32 is not used.
"""
Pattern for LongRoPe GQA Causal Mask.
This pattern computes the causal mask for Group Query Attention with LongRoPe.
It constructs the mask based on input_ids and past_kv_cache, and handles the

Check notice (Code scanning / CodeQL): Unused local variable
Variable seqlens_k_int32 is not used.
mask_key = _get_mask_key(attention_mask)

if mask_key in self._mask_cache:
    total_seq_length_int32, seqlens_k_int32 = self._mask_cache[mask_key]

Check notice (Code scanning / CodeQL): Unused local variable
Variable total_seq_length_int32 is not used.
mask_key = _get_mask_key(attention_mask)

if mask_key in self._mask_cache:
    total_seq_length_int32, seqlens_k_int32 = self._mask_cache[mask_key]

Check notice (Code scanning / CodeQL): Unused local variable
Variable seqlens_k_int32 is not used.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import onnx

Check notice (Code scanning / CodeQL): Unused import
Import of 'onnx' is not used.
# --------------------------------------------------------------------------
import onnx
from onnxscript import ir
import onnx.helper

Check notice (Code scanning / CodeQL): Unused import
Import of 'onnx' is not used.
cache_length = self.rotemb_attrs["cache_length"]
position_ids = torch.arange(cache_length, dtype=torch.int64).unsqueeze(0) # Shape: (1, cache_length)

inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) # (1, dim//2, 1)

Check failure (Code scanning / CodeQL): Potentially uninitialized local variable
Local variable 'inv_freq' may be used before it is initialized.
with torch.autocast(device_type=device_type, enabled=False):
    freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)  # (1, cache_length, dim//2)
    emb = torch.cat((freqs, freqs), dim=-1)  # (1, cache_length, dim)
    cos_cache = emb.cos() * attention_factor  # (1, cache_length, dim)

Check failure (Code scanning / CodeQL): Potentially uninitialized local variable
Local variable 'attention_factor' may be used before it is initialized.
attention_factor = self.rotemb_attrs["multi_cache"]["short_mscale"]

inv_freq_shape = torch.arange(0, dim, 2, dtype=torch.int64, device="cpu").float() / dim
inv_freq = 1.0 / (ext_factors * base**inv_freq_shape)

Check failure (Code scanning / CodeQL): Potentially uninitialized local variable
Local variable 'ext_factors' may be used before it is initialized.
if "rescale_inv_freq" in self.rotemb_attrs:
    inv_freq = self.make_inv_freq_rescaled(inv_freq)

return inv_freq, attention_factor

Check failure (Code scanning / CodeQL): Potentially uninitialized local variable
Local variable 'attention_factor' may be used before it is initialized.
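One way to address the uninitialized-variable failures above is to assign `ext_factors` and `attention_factor` on every branch before they are read. The sketch below is an assumed numpy reconstruction: the `rotemb_attrs` keys mirror the snippets above (`multi_cache`, `short_mscale`), while the `long_factor`/`long_mscale` branch is an assumption based on LongRoPe's dual-cache design, not the PR's actual code.

```python
import numpy as np

def longrope_inv_freq(dim: int, base: float, rotemb_attrs: dict,
                      long_context: bool = False):
    multi = rotemb_attrs["multi_cache"]
    # Both variables are assigned on every path before first use, which is
    # exactly what the CodeQL "potentially uninitialized" check asks for.
    if long_context:
        ext_factors = np.asarray(multi["long_factor"], dtype=np.float32)
        attention_factor = multi["long_mscale"]
    else:
        ext_factors = np.asarray(multi["short_factor"], dtype=np.float32)
        attention_factor = multi["short_mscale"]

    inv_freq_shape = np.arange(0, dim, 2, dtype=np.int64).astype(np.float32) / dim
    inv_freq = 1.0 / (ext_factors * base**inv_freq_shape)
    return inv_freq, attention_factor
```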
@github-advanced-security bot (Contributor) left a comment:

lintrunner found more than 20 potential problems in the proposed changes. Check the Files changed tab for more details.

tadani3 and others added 13 commits August 1, 2025 17:40
Provide a way to indicate that a pattern variable can match successfully
against a None-valued input. Clean up the current handling, which was
inconsistent in one place. Add test cases.

---------

Signed-off-by: Ganesan Ramalingam <[email protected]>
Co-authored-by: Copilot <[email protected]>
This PR adds comprehensive documentation for the rewriter pattern
options that were previously undocumented. The rewriter pattern system
supports four key options for controlling pattern matching and
replacement behavior:

## New Documentation Added

### `_allow_other_inputs` option
- **File**: `docs/tutorial/rewriter/allow_other_inputs.md`
- **Purpose**: Controls whether patterns can match nodes with additional
inputs beyond those specified
- **Default**: `False` (exact input matching)
- **Example**: Matching `Conv` operations that may have optional bias
inputs

```python
def conv_pattern(op, input, weight):
    # Matches Conv with 2 or 3 inputs (weight + optional bias)
    return op.Conv(input, weight, _allow_other_inputs=True)
```

### `_domain` option  
- **File**: `docs/tutorial/rewriter/domain_option.md`
- **Purpose**: Specifies operator domains for pattern matching and
replacement
- **Use cases**: Domain-specific rewrites, migrating between operator
domains
- **Example**: Targeting operations from specific domains like
"com.microsoft"

```python
def custom_relu_pattern(op, input):
    # Only matches Relu from custom domain
    return op.Relu(input, _domain="custom.domain")
```

### `_outputs` option
- **File**: `docs/tutorial/rewriter/outputs_option.md` 
- **Purpose**: Specifies number and names of operation outputs
- **Formats**: Integer count (`_outputs=2`) or named list
(`_outputs=["first", "second"]`)
- **Example**: Handling multi-output operations like `Split`

```python
def split_pattern(op, input):
    # Matches Split operations with exactly 2 outputs
    return op.Split(input, num_outputs=2, axis=0, _outputs=2)
```

### Enhanced `_allow_other_attributes` documentation
- **File**: `docs/tutorial/rewriter/attributes.md` (improved formatting)
- **Already documented**: Controls whether patterns match nodes with
additional attributes
- **Default**: `True` (allows extra attributes)

## Documentation Structure Improvements

- Added "Pattern Options" section to main rewriter documentation
- Integrated all option docs into the tutorial flow
- Created working code examples for each option
- Followed existing documentation patterns and style
- All examples compile and run successfully
- Documentation builds correctly with Sphinx

The documentation now provides complete coverage of all rewriter pattern
options with practical examples showing real-world usage patterns.

Fixes microsoft#2405.


---------

Co-authored-by: copilot-swe-agent[bot] <[email protected]>
Co-authored-by: justinchuby <[email protected]>
Co-authored-by: gramalingam <[email protected]>
In onnx2script, nan, inf, etc. were converted to plain text, which caused
evaluation to fail because those names don't exist in the generated script.
I updated the logic to replace them with their np. equivalents (e.g. np.nan,
np.inf).
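A minimal reproduction of the issue described in this commit message (an assumed illustration, not the onnx2script code itself):

```python
import numpy as np

# A generated script that emits bare literals fails to evaluate:
# `nan` and `inf` are not Python builtins.
try:
    exec("x = nan", {})
    raise AssertionError("expected NameError")
except NameError:
    pass

# Replacing them with their numpy equivalents makes the script executable.
ns = {}
exec("import numpy as np\nx = np.nan\ny = np.inf", ns)
print(np.isnan(ns["x"]), np.isinf(ns["y"]))
```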

---------

Signed-off-by: Justin Chu <[email protected]>
Simplify implementation for `aten_chunk` and allow it to work on all
data types.

Original author: @xadupre 
Updated: Conditionally use the new implementation when torch>=2.7

---------

Signed-off-by: Justin Chu <[email protected]>
Co-authored-by: Xavier Dupré <[email protected]>

codecov bot commented Aug 1, 2025

Codecov Report

❌ Patch coverage is 41.01509% with 430 lines in your changes missing coverage. Please review.
✅ Project coverage is 69.01%. Comparing base (da23d76) to head (d5383f0).

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| ...ipt/rewriter/phi4_mini_reasoning_post_processor.py | 18.24% | 345 Missing ⚠️ |
| onnxscript/rewriter/ort_fusions/longrope_gqa.py | 65.42% | 65 Missing ⚠️ |
| onnxscript/rewriter/ort_fusions/gqa.py | 83.19% | 20 Missing ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2473      +/-   ##
==========================================
- Coverage   69.81%   69.01%   -0.81%     
==========================================
  Files         209      211       +2     
  Lines       25313    25978     +665     
  Branches     2525     2612      +87     
==========================================
+ Hits        17673    17928     +255     
- Misses       6762     7175     +413     
+ Partials      878      875       -3     

☔ View full report in Codecov by Sentry.

@tadani3 tadani3 marked this pull request as draft August 1, 2025 17:54