
StaticSlidingWindowLayer triggers torch.compile/dynamo recompilations every decode step #44609

@jazpurTT

Description

Feature request

Summary

Could support be added for making StaticSlidingWindowLayer fully compatible with torch.compile() by avoiding its dynamic Python control flow?

When using a StaticCache with StaticSlidingWindowLayer (e.g. GPT-OSS, Mistral) and torch.compile(), the compiled graph recompiles on every decode step. The recompile limit is hit quickly and generation is effectively uncompiled or very slow.

StaticSlidingWindowLayer keeps cumulative_length as a plain Python int and updates it on each step via self.cumulative_length += key_states.shape[-2]. It is used in:

  • update – for branching (is_full, cumulative_length + key_states.shape[-2] > self.max_cache_len, etc.).
  • get_mask_sizes – which returns (kv_length, kv_offset) as Python ints derived from self.cumulative_length.

TorchDynamo treats self.cumulative_length as a constant and installs guards on its concrete value. Each decode step changes cumulative_length (e.g. 71 → 72 → 73 …), so the guards fail and the graph recompiles on every model forward pass.

Example guard failures:

Recompiling function wrapper in /lib/python3.12/site-packages/transformers/utils/generic.py:912
    triggered by the following guard failure(s):
    - 0/1: kwargs['past_key_values'].layers[0].cumulative_length == 71  # is_full = self.cumulative_length >= self.max_cache_len  # transformers/cache_utils.py:462 in get_mask_sizes
    - 0/0: tensor 'kwargs['input_ids']' size mismatch at index 1. expected 71, actual 1
Recompiling function wrapper in /lib/python3.12/site-packages/transformers/utils/generic.py:912
    triggered by the following guard failure(s):
    - 0/2: kwargs['past_key_values'].layers[0].cumulative_length == 72  # is_full = self.cumulative_length >= self.max_cache_len  # transformers/cache_utils.py:462 in get_mask_sizes
    - 0/1: kwargs['past_key_values'].layers[0].cumulative_length == 71  # is_full = self.cumulative_length >= self.max_cache_len  # transformers/cache_utils.py:462 in get_mask_sizes
    - 0/0: tensor 'kwargs['input_ids']' size mismatch at index 1. expected 71, actual 1
Recompiling function wrapper in /lib/python3.12/site-packages/transformers/utils/generic.py:912
    triggered by the following guard failure(s):
    - 0/3: kwargs['past_key_values'].layers[0].cumulative_length == 73  # is_full = self.cumulative_length >= self.max_cache_len  # transformers/cache_utils.py:462 in get_mask_sizes
    - 0/2: kwargs['past_key_values'].layers[0].cumulative_length == 72  # is_full = self.cumulative_length >= self.max_cache_len  # transformers/cache_utils.py:462 in get_mask_sizes
    - 0/1: kwargs['past_key_values'].layers[0].cumulative_length == 71  # is_full = self.cumulative_length >= self.max_cache_len  # transformers/cache_utils.py:462 in get_mask_sizes
    - 0/0: tensor 'kwargs['input_ids']' size mismatch at index 1. expected 71, actual 1
...
torch._dynamo hit config.recompile_limit (8)

Reproduction / Example

The following snippet (from a GPT-OSS 20B generation example) shows the scenario where recompilation occurs. Every decode step triggers a recompile:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import StaticCache

batch_size = 1
max_cache_len = 512
max_tokens_to_generate = 64

tokenizer = AutoTokenizer.from_pretrained("openai/gpt-oss-20b")
model = AutoModelForCausalLM.from_pretrained(
    "openai/gpt-oss-20b",
    dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
    attn_implementation="eager",
)
# config has sliding_window and layer_types with "sliding_attention" → StaticSlidingWindowLayer
cache_config = model.config

static_cache = StaticCache(
    config=cache_config,
    max_batch_size=batch_size,
    max_cache_len=max_cache_len,
    device="cpu",
    dtype=torch.bfloat16,
)
static_cache.early_initialization(
    batch_size=batch_size,
    num_heads=model.config.num_key_value_heads,
    head_dim=model.config.head_dim,
    dtype=torch.bfloat16,
    device="cpu",
)

compiled_model = torch.compile(model)

inputs = tokenizer("Some prompt text", return_tensors="pt")  # the logs above used a 71-token prompt
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]

input_args = {
    "input_ids": input_ids,
    "past_key_values": static_cache,
    "cache_position": torch.arange(input_ids.shape[1]),
    "use_cache": True,
    "attention_mask": attention_mask,
}
output = compiled_model(**input_args)  # prefill

for step in range(max_tokens_to_generate - 1):
    next_token_id = output.logits[:, -1, :].argmax(dim=-1)
    input_args["input_ids"] = next_token_id.unsqueeze(-1)                 # shape (1, 1)
    input_args["cache_position"] = input_args["cache_position"][-1:] + 1  # next position, shape (1,)
    output = compiled_model(**input_args)   # cumulative_length changed → guard fail → recompile

Each subsequent compiled_model(**input_args) call sees a different past_key_values.layers[i].cumulative_length, so TorchDynamo recompiles the graph every time.

Environment

  • transformers 4.57.1 (the same issue almost certainly reproduces on the latest main branch)
  • torch 2.9.0 with torch.compile(model)
  • Models that use StaticSlidingWindowLayer (e.g. openai/gpt-oss-20b with default layer_types)

Motivation

Sliding-window attention is a common way to run long-context models with bounded memory and latency (e.g. Mistral, GPT-OSS). With StaticSlidingWindowLayer, decode is effectively uncompilable: every new token changes cumulative_length, torch._dynamo recompiles, and the recompile limit is reached after a few dozen steps. That removes most of the benefit of compilation for long generations and makes sliding-window caches hard to use in production with torch.compile. Fixing this would minimize the number of distinct decode graphs, making sliding-window + compile viable.

Your contribution

I have written a workaround that compiles and runs openai/gpt-oss-20b for 180+ forward passes with only two compiled graphs, one for the prefill pass and one for the decode passes:

  1. It replaces StaticSlidingWindowLayer with a custom class that:
    • Replaces the original update method with a branchless version that always rolls the KV buffer left by n positions and writes the new tokens into the rightmost slots. This produces a single straight-line graph with no mutable Python state, so TorchDynamo never sees a reason to recompile.
    • Makes get_mask_sizes return the constants (max_cache_len, 0) instead of values derived from cache_position and self.cumulative_length at runtime.
  2. It replaces create_sliding_window_causal_mask with a custom mask-building function that:
    • Builds the mask using a right-aligned mapping: it computes real_pos = total_tokens_seen - sliding_window + slot_index to determine which buffer slots hold which sequence positions, then applies validity, causality, and window constraints from there. This masking understands the always-roll layout and the constant get_mask_sizes return values.

This solution was written for openai/gpt-oss-20b, and I have also verified it on mistralai/Ministral-8B-Instruct-2410. I assume it would not be compatible as-is with all models that use sliding attention, but it could be a starting point for a non-dynamic StaticSlidingWindowLayer with multi-run torch.compile compatibility.
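The two core ideas can be sketched as follows (my own simplified reconstruction, not the exact patched class; the names rolling_update and slot_positions are illustrative):

```python
import torch

def rolling_update(kv_buffer, new_kv):
    # Branchless update: always shift the window left by n slots and write
    # the new tokens into the rightmost slots, whether or not the cache is
    # full yet. No Python branching, no mutable int state on the layer.
    n = new_kv.shape[-2]
    kv_buffer = torch.roll(kv_buffer, shifts=-n, dims=-2)
    kv_buffer[..., -n:, :] = new_kv
    return kv_buffer

def slot_positions(total_tokens_seen, sliding_window):
    # Right-aligned mapping used by the mask: slot i holds sequence position
    # real_pos = total_tokens_seen - sliding_window + i; a negative value
    # means the slot has not been written yet and must be masked out.
    return [total_tokens_seen - sliding_window + i for i in range(sliding_window)]
```

With this layout, get_mask_sizes can always report (max_cache_len, 0), and the mask function derives validity, causality, and the window constraint from real_pos alone, so the decode graph never depends on a changing Python int.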
