Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions tests/generation/test_continuous_batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from transformers.testing_utils import (
Expectations,
require_deterministic_for_xpu,
require_flash_attn,
require_torch_accelerator,
slow,
torch_device,
Expand Down Expand Up @@ -601,10 +602,9 @@ def test_block_sharing_with_hybrid_model(self) -> None:

return self._test_block_sharing(model_id, num_layer_groups, input_msg, expected_generated_tokens)

# The test always passes on H100 with torch 2.9, but only passed case 0 on A100 with torch 2.6 and fails on A100
# with torch 2.9. This might be due to a GPU diff, so test might be flaky on the CI which runs on A10.
@parameterized.expand([True, False])
@require_torch_accelerator
@require_flash_attn # otherwise the test can fail because attention bias has a very slight impact on SDPA and eager
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thx for the comment

def test_num_return_sequences(self, allow_block_sharing: bool) -> None:
model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
Expand All @@ -616,7 +616,7 @@ def test_num_return_sequences(self, allow_block_sharing: bool) -> None:
input_ids = [(x if isinstance(x, list) else x["input_ids"]) for x in tokenized]

# Generation with continuous batching
model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="sdpa")
model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="flash_attention_2")
model = model.to(torch_device).eval()
model.generation_config.max_new_tokens = 30
model.generation_config.do_sample = False
Expand Down