-
Notifications
You must be signed in to change notification settings - Fork 81
Added LongRoPe Model Causal Mask Pattern Fusion #2473
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
…xscript into longrope_causal_mask
""" | ||
Pattern for LongRoPe GQA Causal Mask. | ||
This pattern computes the causal mask for Group Query Attention with LongRoPe. | ||
It constructs the mask based on input_ids and past_kv_cache, and handles the |
Check notice
Code scanning / CodeQL
Unused local variable Note
""" | ||
Pattern for LongRoPe GQA Causal Mask. | ||
This pattern computes the causal mask for Group Query Attention with LongRoPe. | ||
It constructs the mask based on input_ids and past_kv_cache, and handles the |
Check notice
Code scanning / CodeQL
Unused local variable Note
mask_key = _get_mask_key(attention_mask) | ||
|
||
if mask_key in self._mask_cache: | ||
total_seq_length_int32, seqlens_k_int32 = self._mask_cache[mask_key] |
Check notice
Code scanning / CodeQL
Unused local variable Note
mask_key = _get_mask_key(attention_mask) | ||
|
||
if mask_key in self._mask_cache: | ||
total_seq_length_int32, seqlens_k_int32 = self._mask_cache[mask_key] |
Check notice
Code scanning / CodeQL
Unused local variable Note
# Licensed under the MIT License. See License.txt in the project root for | ||
# license information. | ||
# -------------------------------------------------------------------------- | ||
import onnx |
Check notice
Code scanning / CodeQL
Unused import Note
# -------------------------------------------------------------------------- | ||
import onnx | ||
from onnxscript import ir | ||
import onnx.helper |
Check notice
Code scanning / CodeQL
Unused import Note
cache_length = self.rotemb_attrs["cache_length"] | ||
position_ids = torch.arange(cache_length, dtype=torch.int64).unsqueeze(0) # Shape: (1, cache_length) | ||
|
||
inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) # (1, dim//2, 1) |
Check failure
Code scanning / CodeQL
Potentially uninitialized local variable Error
with torch.autocast(device_type=device_type, enabled=False): | ||
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) # (1, cache_length, dim//2) | ||
emb = torch.cat((freqs, freqs), dim=-1) # (1, cache_length, dim) | ||
cos_cache = emb.cos() * attention_factor # (1, cache_length, dim) |
Check failure
Code scanning / CodeQL
Potentially uninitialized local variable Error
attention_factor = self.rotemb_attrs["multi_cache"]["short_mscale"] | ||
|
||
inv_freq_shape = torch.arange(0, dim, 2, dtype=torch.int64, device="cpu").float() / dim | ||
inv_freq = 1.0 / (ext_factors * base**inv_freq_shape) |
Check failure
Code scanning / CodeQL
Potentially uninitialized local variable Error
if "rescale_inv_freq" in self.rotemb_attrs: | ||
inv_freq = self.make_inv_freq_rescaled(inv_freq) | ||
|
||
return inv_freq, attention_factor |
Check failure
Code scanning / CodeQL
Potentially uninitialized local variable Error
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lintrunner found more than 20 potential problems in the proposed changes. Check the Files changed tab for more details.
…t#2465) Signed-off-by: Justin Chu <[email protected]>
Provide a way to indicate that a pattern-variable can match successfully against a None-valued input. Cleanup current handling which was inconsistent in one place. Add test cases. --------- Signed-off-by: Ganesan Ramalingam <[email protected]> Co-authored-by: Copilot <[email protected]>
This PR adds comprehensive documentation for the rewriter pattern options that were previously undocumented. The rewriter pattern system supports four key options for controlling pattern matching and replacement behavior: ## New Documentation Added ### `_allow_other_inputs` option - **File**: `docs/tutorial/rewriter/allow_other_inputs.md` - **Purpose**: Controls whether patterns can match nodes with additional inputs beyond those specified - **Default**: `False` (exact input matching) - **Example**: Matching `Conv` operations that may have optional bias inputs ```python def conv_pattern(op, input, weight): # Matches Conv with 2 or 3 inputs (weight + optional bias) return op.Conv(input, weight, _allow_other_inputs=True) ``` ### `_domain` option - **File**: `docs/tutorial/rewriter/domain_option.md` - **Purpose**: Specifies operator domains for pattern matching and replacement - **Use cases**: Domain-specific rewrites, migrating between operator domains - **Example**: Targeting operations from specific domains like "com.microsoft" ```python def custom_relu_pattern(op, input): # Only matches Relu from custom domain return op.Relu(input, _domain="custom.domain") ``` ### `_outputs` option - **File**: `docs/tutorial/rewriter/outputs_option.md` - **Purpose**: Specifies number and names of operation outputs - **Formats**: Integer count (`_outputs=2`) or named list (`_outputs=["first", "second"]`) - **Example**: Handling multi-output operations like `Split` ```python def split_pattern(op, input): # Matches Split operations with exactly 2 outputs return op.Split(input, num_outputs=2, axis=0, _outputs=2) ``` ### Enhanced `_allow_other_attributes` documentation - **File**: `docs/tutorial/rewriter/attributes.md` (improved formatting) - **Already documented**: Controls whether patterns match nodes with additional attributes - **Default**: `True` (allows extra attributes) ## Documentation Structure Improvements - Added "Pattern Options" section to main rewriter documentation - Integrated all option docs into the tutorial flow - Created working code examples for each option - Followed existing documentation patterns and style - All examples compile and run successfully - Documentation builds correctly with Sphinx The documentation now provides complete coverage of all rewriter pattern options with practical examples showing real-world usage patterns. Fixes microsoft#2405. > [!WARNING] > > <details> > <summary>Firewall rules blocked me from connecting to one or more addresses</summary> > > #### I tried to connect to the following addresses, but was blocked by firewall rules: > > - `docs.python.org` > - Triggering command: `python -m sphinx docs dist/html -W -q ` (dns block) > - Triggering command: `python -m sphinx docs dist/html -q -E -j 1 ` (dns block) > - `docs.scipy.org` > - Triggering command: `python -m sphinx docs dist/html -W -q ` (dns block) > - Triggering command: `python -m sphinx docs dist/html -q -E -j 1 ` (dns block) > - `matplotlib.org` > - Triggering command: `python -m sphinx docs dist/html -W -q ` (dns block) > - Triggering command: `python -m sphinx docs dist/html -q -E -j 1 ` (dns block) > - `numpy.org` > - Triggering command: `python -m sphinx docs dist/html -W -q ` (dns block) > - Triggering command: `python -m sphinx docs dist/html -q -E -j 1 ` (dns block) > - `onnx.ai` > - Triggering command: `python -m sphinx docs dist/html -W -q ` (dns block) > - Triggering command: `python -m sphinx docs dist/html -q -E -j 1 ` (dns block) > - `onnxruntime.ai` > - Triggering command: `python -m sphinx docs dist/html -W -q ` (dns block) > - Triggering command: `python -m sphinx docs dist/html -q -E -j 1 ` (dns block) > - `pytorch.org` > - Triggering command: `python -m sphinx docs dist/html -W -q ` (dns block) > - Triggering command: `python -m sphinx docs dist/html -q -E -j 1 ` (dns block) > > If you need me to access, download, or install something from one of these locations, you can either: > > - Configure [Actions setup steps](https://gh.io/copilot/actions-setup-steps) to set up my environment, which run before the firewall is enabled > - Add the appropriate URLs or hosts to my [firewall allow list](https://gh.io/copilot/firewall-config) > > </details> <!-- START COPILOT CODING AGENT TIPS --> --- 💬 Share your feedback on Copilot coding agent for the chance to win a $200 gift card! Click [here](https://survey.alchemer.com/s3/8343779/Copilot-Coding-agent) to start the survey. --------- Co-authored-by: copilot-swe-agent[bot] <[email protected]> Co-authored-by: justinchuby <[email protected]> Co-authored-by: gramalingam <[email protected]>
In onnx2script, nan, inf etc. were converted to plain text, which causes evaluation to fail because they don't exist in the script. I updated the logic to replace them with np. values. --------- Signed-off-by: Justin Chu <[email protected]>
Simplify implementation for `aten_chunk` and allow it to work on all data types. Original author: @xadupre Updated: Conditionally use the new implementation when torch>=2.7 --------- Signed-off-by: Justin Chu <[email protected]> Co-authored-by: Xavier Dupré <[email protected]>
Codecov Report❌ Patch coverage is Additional details and impacted files@@ Coverage Diff @@
## main #2473 +/- ##
==========================================
- Coverage 69.81% 69.01% -0.81%
==========================================
Files 209 211 +2
Lines 25313 25978 +665
Branches 2525 2612 +87
==========================================
+ Hits 17673 17928 +255
- Misses 6762 7175 +413
+ Partials 878 875 -3 ☔ View full report in Codecov by Sentry. |
This PR introduces a specialized LongRoPe (Long Range Rotary Position Embedding) GQA (Group Query Attention) causal mask fusion rule specifically designed for Phi-4-mini-reasoning and similar models. The implementation optimizes attention mask computation for models using sliding window attention with LongRoPe position embeddings.
New LongRoPeGQACausalMask Class
Advanced Mask Computation
Note: This PR is meant to replace #2461 by introducing the requested changes.