[CB] Ensure parallel decoding test passes using FA #43277

Merged: vasqu merged 2 commits into main from cb-flakyness on Jan 14, 2026

Conversation

@remi-or remi-or (Collaborator) commented Jan 14, 2026

This PR fixes a flaky test in continuous batching: the test checked that two parallel decoding streams produced matching outputs, but used sdpa as the attention implementation. Since sdpa is not batch invariant, the test could fail even with a well-formed attention mask; with flash attention, this is no longer the case. I measured batch invariance with the following snippet:

```python
import torch

# Implementation of compute_attention elided [...]

num_heads, head_dim, kv_per_query = 32, 64, 32

failures = 0
iters = 1000
for i in range(iters):
    # Create two identical query tokens
    Q = torch.randn(1, num_heads, 2, head_dim)
    Q[:, :, 1] = Q[:, :, 0]

    # Create an identical KV cache for both queries
    K_or_V = torch.randn(1, num_heads, 2 * kv_per_query, head_dim)
    K_or_V[:, :, kv_per_query:] = K_or_V[:, :, :kv_per_query]

    # Attention mask to avoid cross-query interactions: query 0 attends to the
    # first half of the cache, query 1 to the (identical) second half
    attention_mask = torch.zeros(1, 2, 2 * kv_per_query)
    attention_mask[0, 0, :kv_per_query] = 1
    attention_mask[0, 1, kv_per_query:] = 1

    # Compute attention with the implementation under test
    output = compute_attention(Q, K_or_V, K_or_V, attention_mask)

    # Both queries saw identical inputs, so their outputs should match
    query_0_output, query_1_output = output[:, :, 0], output[:, :, 1]
    is_ok = torch.isclose(query_0_output, query_1_output).all().item()
    failures += 1 - is_ok
```

and the number of failures out of the 1000 iterations for each implementation was:

- sdpa: 597
- eager: 156
- kernels-community/flash-attn: 0

So I went with flash attention for the test.
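
For reference, since the implementation of `compute_attention` is elided above, here is a minimal sketch of what the sdpa variant could look like, assuming PyTorch's built-in `scaled_dot_product_attention`; the actual helper used in the experiment is not shown in this PR:

```python
import torch.nn.functional as F

# Hypothetical sketch only; not the helper actually used in the experiment.
def compute_attention(Q, K, V, attention_mask):
    # SDPA interprets a boolean attn_mask as "True = attend"; a float mask
    # would be added to the attention logits instead, so the 0/1 mask from
    # the snippet is converted to bool here.
    # unsqueeze(1) broadcasts the mask over the head dimension:
    # (1, 2, 2 * kv_per_query) -> (1, 1, 2, 2 * kv_per_query)
    mask = attention_mask.bool().unsqueeze(1)
    return F.scaled_dot_product_attention(Q, K, V, attn_mask=mask)
```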

@remi-or remi-or requested a review from vasqu January 14, 2026 08:13
@HuggingFaceDocBuilderDev commented:

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@vasqu vasqu (Contributor) left a comment:

Thx, also double-checked locally, works 👍

```python
# with torch 2.9. This might be due to a GPU diff, so test might be flaky on the CI which runs on A10.
@parameterized.expand([True, False])
@require_torch_accelerator
@require_flash_attn  # otherwise the test can fail because attention bias has a very slight impact on SDPA and eager
```
A contributor replied:

Thx for the comment
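
For context, a minimal sketch of how these decorators might sit on the test method in the usual transformers test style; the class name, test name, and flag are hypothetical, not the actual test touched by this PR:

```python
import unittest

from parameterized import parameterized
from transformers.testing_utils import require_flash_attn, require_torch_accelerator


class ContinuousBatchingTest(unittest.TestCase):  # hypothetical class name
    @parameterized.expand([True, False])  # runs the test once per flag value
    @require_torch_accelerator  # skipped on machines without an accelerator
    @require_flash_attn  # sdpa and eager are not batch invariant, so FA is required
    def test_parallel_decoding(self, flag):  # hypothetical test name and flag
        ...  # decode two parallel streams and assert their outputs match
```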

@vasqu vasqu merged commit 61317f5 into main Jan 14, 2026
17 checks passed
@vasqu vasqu deleted the cb-flakyness branch January 14, 2026 13:32
SangbumChoi pushed a commit to SangbumChoi/transformers that referenced this pull request Jan 23, 2026