Description
Feature request
Summary
Could support be added for StaticSlidingWindowLayer to be fully compatible with torch.compile() by avoiding its dynamic control flow?
When using a StaticCache with StaticSlidingWindowLayer (e.g. GPT-OSS, Mistral) and torch.compile(), the compiled graph recompiles on every decode step. The recompile limit is hit quickly and generation is effectively uncompiled or very slow.
StaticSlidingWindowLayer keeps cumulative_length as a Python int and updates it each step (self.cumulative_length += key_states.shape[-2]). It is used in:
- update – for branching (is_full, cumulative_length + key_states.shape[-2] > self.max_cache_len, etc.)
- get_mask_sizes – which returns (kv_length, kv_offset) as Python ints derived from self.cumulative_length.
TorchDynamo treats self.cumulative_length as a constant and installs guards on its concrete value. Each decode step changes cumulative_length (e.g. 71 → 72 → 73…), so the guards fail and the graph recompiles on every model forward pass.
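As a rough illustration (a toy stand-in, not Dynamo's actual guard machinery), a compiler that guards on the concrete value of a Python int attribute behaves like a cache keyed on that value, so a counter that changes on every call misses every time:

```python
class SlidingLayerStub:
    """Hypothetical stand-in for StaticSlidingWindowLayer's Python-int state."""
    def __init__(self):
        self.cumulative_length = 0

def compile_with_value_guard(fn, compile_log):
    # Toy model of a value guard: one "compiled graph" per observed
    # cumulative_length value, recorded in compile_log on every miss.
    graphs = {}
    def wrapper(layer, n_new):
        key = layer.cumulative_length  # guard: exact value match
        if key not in graphs:
            compile_log.append(key)    # "recompile"
            graphs[key] = fn
        return graphs[key](layer, n_new)
    return wrapper

def decode_step(layer, n_new):
    layer.cumulative_length += n_new  # mutable Python state, as in update()
    return layer.cumulative_length

recompiles = []
step = compile_with_value_guard(decode_step, recompiles)
layer = SlidingLayerStub()
for _ in range(5):
    step(layer, 1)
print(len(recompiles))  # 5 — every step sees a new value, so a new "graph"
```

A graph keyed on a tensor-derived length (or one with no data-dependent state at all) would hit the same entry every decode step instead.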
Example guard failures:
Recompiling function wrapper in /lib/python3.12/site-packages/transformers/utils/generic.py:912
triggered by the following guard failure(s):
- 0/1: kwargs['past_key_values'].layers[0].cumulative_length == 71 # is_full = self.cumulative_length >= self.max_cache_len # transformers/cache_utils.py:462 in get_mask_sizes
- 0/0: tensor 'kwargs['input_ids']' size mismatch at index 1. expected 71, actual 1
Recompiling function wrapper in /lib/python3.12/site-packages/transformers/utils/generic.py:912
triggered by the following guard failure(s):
- 0/2: kwargs['past_key_values'].layers[0].cumulative_length == 72 # is_full = self.cumulative_length >= self.max_cache_len # transformers/cache_utils.py:462 in get_mask_sizes
- 0/1: kwargs['past_key_values'].layers[0].cumulative_length == 71 # is_full = self.cumulative_length >= self.max_cache_len # transformers/cache_utils.py:462 in get_mask_sizes
- 0/0: tensor 'kwargs['input_ids']' size mismatch at index 1. expected 71, actual 1
Recompiling function wrapper in /lib/python3.12/site-packages/transformers/utils/generic.py:912
triggered by the following guard failure(s):
- 0/3: kwargs['past_key_values'].layers[0].cumulative_length == 73 # is_full = self.cumulative_length >= self.max_cache_len # transformers/cache_utils.py:462 in get_mask_sizes
- 0/2: kwargs['past_key_values'].layers[0].cumulative_length == 72 # is_full = self.cumulative_length >= self.max_cache_len # transformers/cache_utils.py:462 in get_mask_sizes
- 0/1: kwargs['past_key_values'].layers[0].cumulative_length == 71 # is_full = self.cumulative_length >= self.max_cache_len # transformers/cache_utils.py:462 in get_mask_sizes
- 0/0: tensor 'kwargs['input_ids']' size mismatch at index 1. expected 71, actual 1
...
torch._dynamo hit config.recompile_limit (8)

Reproduction / Example
The following snippet (from a GPT-OSS 20B generation example) shows the scenario where recompilation occurs. Every decode step triggers a recompile:
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import StaticCache

# batch_size, max_cache_len, input_ids, attention_mask, next_token_id and
# max_tokens_to_generate are prepared elsewhere in the original example.

model = AutoModelForCausalLM.from_pretrained(
    "openai/gpt-oss-20b",
    dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
    attn_implementation="eager",
)

# config has sliding_window and layer_types with "sliding_attention" → StaticSlidingWindowLayer
static_cache = StaticCache(
    config=model.config,
    max_batch_size=batch_size,
    max_cache_len=max_cache_len,
    device="cpu",
    dtype=torch.bfloat16,
)
static_cache.early_initialization(
    batch_size=batch_size,
    num_heads=model.config.num_key_value_heads,
    head_dim=model.config.head_dim,
    dtype=torch.bfloat16,
    device="cpu",
)

compiled_model = torch.compile(model)

# Prefill pass (prompt length 71 here)
input_args = {
    "input_ids": input_ids,
    "past_key_values": static_cache,
    "cache_position": torch.arange(71),
    "use_cache": True,
    "attention_mask": attention_mask,
}
output = compiled_model(**input_args)

# Decode loop
for step in range(max_tokens_to_generate - 1):
    input_args["input_ids"] = next_token_id.unsqueeze(-1)  # shape (1, 1)
    input_args["cache_position"] = input_args["cache_position"][-1:] + 1  # 1-D, e.g. [71], [72], …
    output = compiled_model(**input_args)  # cumulative_length changed → guard fail → recompile
```

The second and subsequent compiled_model(**input_args) calls see a different past_key_values.layers[i].cumulative_length each time, and TorchDynamo recompiles the graph each time.
Environment
- transformers 4.57.1 (but I am pretty sure the same issue would happen on the latest main branch)
- torch 2.9.0 with torch.compile(model)
- Models that use StaticSlidingWindowLayer (e.g. openai/gpt-oss-20b with default layer_types)
Motivation
Sliding-window attention is a common way to run long-context models with bounded memory and latency (e.g. Mistral, GPT-OSS). With StaticSlidingWindowLayer, decode is effectively uncompilable: every new token changes cumulative_length, torch._dynamo recompiles, and the recompile limit is reached after a few dozen steps. That removes most of the benefit of compilation for long generations and makes it hard to use sliding-window caches in production with torch.compile. Fixing this minimizes the number of distinct decode graphs, making sliding-window + compile viable.
Your contribution
I have written a workaround that compiles and runs openai/gpt-oss-20b for 180+ forward passes with only two compiled graphs: one for the prefill pass and one for the decode passes.
- It replaces StaticSlidingWindowLayer with my own custom class where:
  - The original update method is replaced with a branchless version that always rolls the KV buffer left by n positions and writes the new tokens into the rightmost slots. This produces a single straight-line graph with no mutable Python state, so TorchDynamo never sees a reason to recompile.
  - get_mask_sizes returns the constants (max_cache_len, 0) instead of values derived from cache_position and self.cumulative_length at runtime.
- It replaces create_sliding_window_causal_mask with a custom mask-building function where:
  - It builds the mask using a right-aligned mapping: it computes real_pos = total_tokens_seen - sliding_window + slot_index to figure out which buffer slots hold which sequence positions, then applies validity, causality, and window constraints from there. This masking understands the always-roll layout and the new constant get_mask_sizes return.
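A minimal pure-Python sketch of the always-roll update (the real implementation operates on torch tensors; names here are illustrative):

```python
def branchless_update(kv_buffer, new_tokens):
    """Always roll the fixed-size buffer left by len(new_tokens) and
    write the new tokens into the rightmost slots — no branching on
    whether the cache is full yet."""
    n = len(new_tokens)
    return kv_buffer[n:] + new_tokens  # same buffer length in and out

# A window of 4 slots, padded with None before any token is seen
buf = [None, None, None, None]
for tok in ["a", "b", "c", "d", "e"]:
    buf = branchless_update(buf, [tok])
print(buf)  # ['b', 'c', 'd', 'e'] — the newest tokens, right-aligned
```

Because the same roll-and-write happens whether the window is full or not, the traced graph is identical on every step; the cost is that not-yet-written slots must be masked out instead of being skipped by a branch.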
This solution was made for openai/gpt-oss-20b, and I have also tested it on mistralai/Ministral-8B-Instruct-2410, where it worked. I assume the solution as-is would not be compatible with all models using sliding attention, but it could be a starting point for a non-dynamic StaticSlidingWindowLayer with multi-run torch.compile compatibility.
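The right-aligned mapping used by the custom mask can likewise be sketched in pure Python (the real mask is a boolean tensor; function and parameter names here are illustrative):

```python
def slot_mask(total_tokens_seen, sliding_window, query_pos):
    """For a query at absolute position query_pos, decide which of the
    sliding_window buffer slots it may attend to under the always-roll layout."""
    mask = []
    for slot in range(sliding_window):
        # Absolute sequence position currently held by this slot
        real_pos = total_tokens_seen - sliding_window + slot
        valid = real_pos >= 0                               # slot has been written
        causal = real_pos <= query_pos                      # no future tokens
        in_window = query_pos - real_pos < sliding_window   # within the window
        mask.append(valid and causal and in_window)
    return mask

# 5 tokens seen, window of 4, querying the newest position (index 4):
print(slot_mask(5, 4, 4))  # [True, True, True, True]
# Only 2 tokens seen so far — the two leftmost slots are still empty:
print(slot_mask(2, 4, 1))  # [False, False, True, True]
```

The validity term is what replaces the dynamic branching: slots that the branchless update has not yet filled are simply masked out, so get_mask_sizes can return the constant (max_cache_len, 0).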