
Commit 362fa37

[test] update test_past_key_values_format (#37614)
allow custom shapes
1 parent 1cd110c commit 362fa37

19 files changed: +133 -165 lines
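
Illustrative sketch (not part of the commit): the common test in tests/generation/test_utils.py gains an optional `custom_all_cache_shapes` argument, so model tests with non-standard caches (GQA, MLA, ...) can pass their expected per-layer shapes instead of skipping the test. The class name below is hypothetical and the snippet assumes the usual transformers test scaffolding (a `GenerationTesterMixin`-based test class with a `model_tester`); it mirrors the Deepseek-V3 override shown further down.

# Hypothetical override, mirroring the Deepseek-V3 change below (not part of this diff)
class MyModelTest(GenerationTesterMixin, unittest.TestCase):  # assumed scaffolding
    def test_past_key_values_format(self):
        config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
        batch_size, seq_length = inputs["input_ids"].shape
        head_dim = config.hidden_size // config.num_attention_heads
        # one [key_shape, value_shape] pair per decoder layer
        layer_shapes = [(batch_size, config.num_key_value_heads, seq_length, head_dim)] * 2
        all_cache_shapes = [layer_shapes for _ in range(config.num_hidden_layers)]
        super().test_past_key_values_format(custom_all_cache_shapes=all_cache_shapes)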

tests/generation/test_utils.py

Lines changed: 98 additions & 57 deletions
@@ -1539,92 +1539,133 @@ def _prepare_model_kwargs(input_ids, attention_mask, signature):
             torch.testing.assert_close(next_logits_wo_padding, next_logits_with_padding, rtol=1e-5, atol=1e-5)

     @pytest.mark.generate
-    def test_past_key_values_format(self):
-        # Test that the KV cache is formatted correctly. Exceptions need to explicitly overwrite this test. Having a
-        # standard KV cache format is important for a consistent API (and for advanced generation methods).
+    def test_past_key_values_format(self, custom_all_cache_shapes=None):
+        """
+        Test that the KV cache is formatted correctly. Exceptions need to explicitly overwrite this test, or pass the
+        expected cache shapes.
+        Having a standard KV cache format is important for a consistent API (and for advanced generation methods).
+        """
         for model_class in self.all_generative_model_classes:
             config, inputs = self.model_tester.prepare_config_and_inputs_for_common()

-            # If it doesn't support cache, pass the test
+            # 1. If it doesn't support cache, skip the test
             if not hasattr(config.get_text_config(), "use_cache"):
                 self.skipTest(reason=f"{model_class.__name__} doesn't support caching")

             model = model_class(config).to(torch_device)
+            model = model.eval()
             if "use_cache" not in inputs:
                 inputs["use_cache"] = True
             outputs = model(**inputs)

-            # If "past_key_values" is not returned, pass the test (e.g. RWKV uses a different cache name and format)
             if "past_key_values" not in outputs:
                 self.skipTest(reason="This model doesn't return `past_key_values`")

+            # 2. retrieve the KV cache and compute its default expected shapes (if no custom shapes are provided)
+            past_kv = outputs["past_key_values"]
+            is_legacy_cache = not isinstance(past_kv, Cache)
+
             text_config = config.get_text_config()
-            num_hidden_layers = (
+            num_decoder_layers = (
                 getattr(text_config, "decoder_layers", None)
                 or getattr(text_config, "num_decoder_layers", None)
                 or text_config.num_hidden_layers
             )
-            num_attention_heads = getattr(text_config, "decoder_attention_heads", text_config.num_attention_heads)
-            embed_dim = getattr(text_config, "d_model", text_config.hidden_size)
-            per_head_embed_dim = embed_dim // num_attention_heads
-
-            # some models have different num-head for query vs key/value so we need to assign correct value
-            # BUT only after `per_head_embed_dim` is set
-            num_attention_heads = (
-                text_config.num_key_value_heads
-                if getattr(text_config, "num_key_value_heads", None) is not None
-                else num_attention_heads
-            )

-            past_kv = outputs["past_key_values"]
-            self.assertEqual(len(past_kv), num_hidden_layers)
+            if custom_all_cache_shapes is None:
+                num_query_attention_heads = getattr(
+                    text_config, "decoder_attention_heads", text_config.num_attention_heads
+                )
+                embed_dim = getattr(text_config, "d_model", text_config.hidden_size)
+                per_head_embed_dim = embed_dim // num_query_attention_heads
+                num_key_value_heads = (
+                    text_config.num_key_value_heads
+                    if getattr(text_config, "num_key_value_heads", None) is not None
+                    else num_query_attention_heads
+                )
+                if config.is_encoder_decoder:
+                    encoder_num_attention_heads = (
+                        text_config.encoder_attention_heads
+                        if hasattr(text_config, "encoder_attention_heads")
+                        else text_config.num_attention_heads
+                    )
+                    encoder_per_head_embed_dim = embed_dim // encoder_num_attention_heads
+                    batch_size, seq_length = inputs["decoder_input_ids"].shape
+                    # The sequence length for the encoder K V depends on the model. Since it is not manipulated in
+                    # autoregressive generation, we're keeping the test general and not checking the 3rd dim
+                    default_cross_attention_shape = (
+                        batch_size,
+                        encoder_num_attention_heads,
+                        encoder_per_head_embed_dim,
+                    )
+                    default_self_attention_shape = (batch_size, num_key_value_heads, seq_length, per_head_embed_dim)
+                    all_cache_shapes = [
+                        [
+                            default_self_attention_shape,
+                            default_self_attention_shape,
+                            default_cross_attention_shape,
+                            default_cross_attention_shape,
+                        ]
+                        for _ in range(num_decoder_layers)
+                    ]
+                else:
+                    batch_size, seq_length = inputs["input_ids"].shape
+                    default_self_attention_shape = (batch_size, num_key_value_heads, seq_length, per_head_embed_dim)
+                    all_cache_shapes = [
+                        [default_self_attention_shape, default_self_attention_shape] for _ in range(num_decoder_layers)
+                    ]

-            # Encoder-Decoder checks
+            else:
+                all_cache_shapes = custom_all_cache_shapes
+
+            # 3. Check cache shapes
+            # 3.1. Encoder-Decoder checks
             if config.is_encoder_decoder:
-                # encoder-decoder models usually don't have text config
-                # below is needed only for Pix2Struct which we cannot modify now due to BC
-                config = config.get_text_config()
-                encoder_num_attention_heads = (
-                    config.encoder_attention_heads
-                    if hasattr(config, "encoder_attention_heads")
-                    else config.num_attention_heads
+                num_cache_decoder_layers = (
+                    len(past_kv) if is_legacy_cache else len(past_kv.self_attention_cache.key_cache)
                 )
-                encoder_per_head_embed_dim = embed_dim // encoder_num_attention_heads
-                batch_size, seq_length = inputs["decoder_input_ids"].shape
-                for i in range(num_hidden_layers):
-                    self.assertEqual(len(past_kv[i]), 4)  # K V for the decoder + K V for the encoder = 4
-                    self.assertEqual(
-                        past_kv[i][0].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim)
+                self.assertEqual(num_cache_decoder_layers, num_decoder_layers)
+
+                for i in range(num_decoder_layers):
+                    if is_legacy_cache:
+                        self.assertEqual(len(past_kv[0]), 4)  # legacy check: confirm number of elements in tuple
+
+                    # Self attention
+                    self_attention_layer_key_cache = (
+                        past_kv[i][0] if is_legacy_cache else past_kv.self_attention_cache.key_cache[i]
                     )
-                    self.assertEqual(
-                        past_kv[i][1].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim)
+                    self_attention_layer_value_cache = (
+                        past_kv[i][1] if is_legacy_cache else past_kv.self_attention_cache.value_cache[i]
                     )
-                    # The sequence length for the encoder K V depends on the model. Since it is not manipulated in
-                    # autoregressive generation, I'm keeping the test general and not checking the 3rd dim
-                    self.assertEqual(
-                        (past_kv[i][2].shape[0], past_kv[i][2].shape[1], past_kv[i][2].shape[3]),
-                        (batch_size, encoder_num_attention_heads, encoder_per_head_embed_dim),
+                    self.assertEqual(self_attention_layer_key_cache.shape, all_cache_shapes[i][0])
+                    self.assertEqual(self_attention_layer_value_cache.shape, all_cache_shapes[i][1])
+
+                    # Cross attention (ignore 3rd dim, see default shape preparation)
+                    cross_attention_layer_key_cache = (
+                        past_kv[i][2] if is_legacy_cache else past_kv.cross_attention_cache.key_cache[i]
                     )
-                    self.assertEqual(
-                        (past_kv[i][3].shape[0], past_kv[i][3].shape[1], past_kv[i][3].shape[3]),
-                        (batch_size, encoder_num_attention_heads, encoder_per_head_embed_dim),
+                    cross_attention_layer_value_cache = (
+                        past_kv[i][3] if is_legacy_cache else past_kv.cross_attention_cache.value_cache[i]
                     )
+                    cross_attention_layer_key_cache = cross_attention_layer_key_cache[:, :, 0, :]
+                    cross_attention_layer_value_cache = cross_attention_layer_value_cache[:, :, 0, :]
+                    self.assertEqual(cross_attention_layer_key_cache.shape, all_cache_shapes[i][2])
+                    self.assertEqual(cross_attention_layer_value_cache.shape, all_cache_shapes[i][3])

-            # Decoder-only checks
+            # 3.2. Decoder-only checks
             else:
-                # TODO: this line is only needed because of imagegpt, where "pixel_values" = "input_ids". Fix the
-                # tests in imagegpt such that `prepare_config_and_inputs_for_common` returns the later (and the other
-                # tests use it)
-                key = "input_ids" if "input_ids" in inputs else "pixel_values"
-                batch_size, seq_length = inputs[key].shape
-                for i in range(num_hidden_layers):
-                    self.assertEqual(len(past_kv[0]), 2)  # K V for the decoder = 2
-                    self.assertEqual(
-                        past_kv[i][0].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim)
-                    )
-                    self.assertEqual(
-                        past_kv[i][1].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim)
-                    )
+                num_cache_decoder_layers = len(past_kv) if is_legacy_cache else len(past_kv.key_cache)
+                self.assertEqual(num_cache_decoder_layers, num_decoder_layers)
+
+                for i in range(num_decoder_layers):
+                    if is_legacy_cache:
+                        self.assertEqual(len(past_kv[0]), 2)  # legacy check: confirm number of elements in tuple
+
+                    # Self attention
+                    self_attention_layer_key_cache = past_kv[i][0] if is_legacy_cache else past_kv.key_cache[i]
+                    self_attention_layer_value_cache = past_kv[i][1] if is_legacy_cache else past_kv.value_cache[i]
+                    self.assertEqual(self_attention_layer_key_cache.shape, all_cache_shapes[i][0])
+                    self.assertEqual(self_attention_layer_value_cache.shape, all_cache_shapes[i][1])

     @pytest.mark.generate
     @parameterized.expand([("greedy", 1), ("beam search", 2)])
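
For reference, the default expectation built by the updated test for a decoder-only model is one [key, value] pair of identical shapes per layer, `(batch_size, num_key_value_heads, seq_length, per_head_embed_dim)`, with the head dim derived from the number of query heads. A minimal, self-contained sketch with made-up config values (not taken from any real model or from this diff):

import torch

# Hypothetical config values, chosen only to illustrate the GQA shape convention
batch_size, seq_length = 2, 7
num_attention_heads, num_key_value_heads, hidden_size = 8, 2, 64
per_head_embed_dim = hidden_size // num_attention_heads  # derived from the query heads

default_self_attention_shape = (batch_size, num_key_value_heads, seq_length, per_head_embed_dim)
num_decoder_layers = 4
all_cache_shapes = [[default_self_attention_shape, default_self_attention_shape] for _ in range(num_decoder_layers)]

# A stand-in KV cache entry with the layout a decoder-only model would return
fake_key = torch.zeros(default_self_attention_shape)
assert fake_key.shape == all_cache_shapes[0][0]  # torch.Size compares equal to a plain tuple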

tests/models/deepseek_v3/test_modeling_deepseek_v3.py

Lines changed: 16 additions & 2 deletions
@@ -429,9 +429,23 @@ def test_model_rope_scaling(self):
         with self.assertRaises(AssertionError):
             torch.testing.assert_close(yarn_sin_long, original_sin_long)

-    @unittest.skip(reason="Deepseek-V3 uses MLA on all models so the KV cache is a non standard format")
     def test_past_key_values_format(self):
-        pass
+        """
+        Overwritting to pass the expected cache shapes (Deepseek-V3 uses MLA so the cache shapes are non-standard)
+        """
+        config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
+        batch_size, seq_length = inputs["input_ids"].shape
+        # difference: last dim
+        k_embed_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
+        v_embed_dim = config.v_head_dim
+        self_attention_key_cache_shape = (batch_size, config.num_key_value_heads, seq_length, k_embed_dim)
+        self_attention_value_cache_shape = (batch_size, config.num_key_value_heads, seq_length, v_embed_dim)
+        # build the full cache shapes
+        num_hidden_layers = config.num_hidden_layers
+        all_cache_shapes = [
+            [self_attention_key_cache_shape, self_attention_value_cache_shape] for _ in range(num_hidden_layers)
+        ]
+        super().test_past_key_values_format(custom_all_cache_shapes=all_cache_shapes)

     @require_torch_sdpa
     @slow
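
The Deepseek-V3 override above shows why a single default shape cannot cover MLA: the key and value cache entries end up with different last dimensions. A small standalone illustration with assumed head dimensions (not asserted to be the real Deepseek-V3 values):

# Assumed head dimensions, for illustration only
qk_nope_head_dim, qk_rope_head_dim, v_head_dim = 128, 64, 128
k_embed_dim = qk_nope_head_dim + qk_rope_head_dim  # key cache last dim: 192
v_embed_dim = v_head_dim                           # value cache last dim: 128

batch_size, num_key_value_heads, seq_length = 2, 4, 7
key_shape = (batch_size, num_key_value_heads, seq_length, k_embed_dim)
value_shape = (batch_size, num_key_value_heads, seq_length, v_embed_dim)
print(key_shape, value_shape)  # the shapes differ in the last dim, hence the per-entry custom list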

tests/models/falcon/test_modeling_falcon.py

Lines changed: 0 additions & 45 deletions
@@ -264,51 +264,6 @@ def test_falcon_sequence_classification_model_for_multi_label(self):
         result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
         self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))

-    def test_past_key_values_format(self):
-        # Falcon can have different numbers of KV-heads than the number of query heads, so we need
-        # to override this test to use the right head counts.
-        for model_class in self.all_generative_model_classes:
-            config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
-
-            # If it doesn't support cache, pass the test
-            if not hasattr(config, "use_cache"):
-                self.skipTest(reason="Model does not support cache")
-
-            model = model_class(config).to(torch_device)
-            if "use_cache" not in inputs:
-                inputs["use_cache"] = True
-            outputs = model(**inputs)
-
-            # If "past_key_values" is not returned, pass the test (e.g. RWKV uses a different cache name and format)
-            if "past_key_values" not in outputs:
-                self.skipTest(reason="Model does not return past_key_values")
-
-            num_hidden_layers = (
-                getattr(config, "decoder_layers", None)
-                or getattr(config, "num_decoder_layers", None)
-                or config.num_hidden_layers
-            )
-            num_attention_heads = getattr(config, "num_kv_heads", config.num_attention_heads)
-            embed_dim = getattr(config, "d_model", config.hidden_size)
-            per_head_embed_dim = embed_dim // num_attention_heads
-
-            past_kv = outputs["past_key_values"]
-            self.assertEqual(len(past_kv), num_hidden_layers)
-
-            batch_size, seq_length = inputs["input_ids"].shape
-            for i in range(num_hidden_layers):
-                if config.new_decoder_architecture:
-                    num_attention_heads = config.num_attention_heads
-                elif config.multi_query:
-                    num_attention_heads = 1
-                self.assertEqual(len(past_kv[0]), 2)  # K V for the decoder = 2
-                self.assertEqual(
-                    past_kv[i][0].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim)
-                )
-                self.assertEqual(
-                    past_kv[i][1].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim)
-                )
-
     @parameterized.expand([("linear",), ("dynamic",)])
     # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_rope_scaling_from_config with Llama->Falcon
     def test_model_rope_scaling_from_config(self, scaling_type):

tests/models/gemma/test_modeling_gemma.py

Lines changed: 0 additions & 4 deletions
@@ -296,10 +296,6 @@ def test_Gemma_token_classification_model(self):
             (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
         )

-    @unittest.skip(reason="Gemma uses GQA on all models so the KV cache is a non standard format")
-    def test_past_key_values_format(self):
-        pass
-
     @require_flash_attn
     @require_torch_gpu
     @pytest.mark.flash_attn_test

tests/models/glm/test_modeling_glm.py

Lines changed: 0 additions & 4 deletions
@@ -264,10 +264,6 @@ def test_Glm_token_classification_model(self):
             (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
         )

-    @unittest.skip(reason="Glm uses GQA on all models so the KV cache is a non standard format")
-    def test_past_key_values_format(self):
-        pass
-
     @is_flaky()
     def test_custom_4d_attention_mask(self):
         """Overwrite the common test to use atol=1e-3 instead of 1e-4. Can still rarely fail, thus flaky."""

tests/models/got_ocr2/test_modeling_got_ocr2.py

Lines changed: 0 additions & 6 deletions
@@ -222,12 +222,6 @@ def test_inputs_embeds_matches_input_ids(self):
     def test_generate_from_inputs_embeds_with_static_cache(self):
         pass

-    @unittest.skip(
-        reason="GotOcr2's language backbone is Qwen2 which uses GQA so the KV cache is a non standard format"
-    )
-    def test_past_key_values_format(self):
-        pass
-
     @unittest.skip("FlashAttention only support fp16 and bf16 data type")
     def test_flash_attn_2_fp32_ln(self):
         pass

tests/models/imagegpt/test_modeling_imagegpt.py

Lines changed: 4 additions & 0 deletions
@@ -319,6 +319,10 @@ def test_forward_signature(self):
     def test_left_padding_compatibility(self):
         pass

+    @unittest.skip(reason="Model inputs don't fit test pattern")  # and it's not used enough to be worth fixing :)
+    def test_past_key_values_format(self):
+        pass
+

 # We will verify our results on an image of cute cats
 def prepare_img():

tests/models/jetmoe/test_modeling_jetmoe.py

Lines changed: 0 additions & 4 deletions
@@ -251,10 +251,6 @@ def test_jetmoe_sequence_classification_model_for_multi_label(self):
         result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
         self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))

-    @unittest.skip(reason="JetMoe uses MoA on all models so the KV cache is a non standard format")
-    def test_past_key_values_format(self):
-        pass
-
     @require_flash_attn
     @require_torch_gpu
     @pytest.mark.flash_attn_test

tests/models/mistral/test_modeling_mistral.py

Lines changed: 0 additions & 4 deletions
@@ -292,10 +292,6 @@ def test_Mistral_token_classification_model(self):
             (self.model_tester.batch_size, self.model_tester.seq_length, self.model_tester.num_labels),
         )

-    @unittest.skip(reason="Mistral uses GQA on all models so the KV cache is a non standard format")
-    def test_past_key_values_format(self):
-        pass
-
     @require_flash_attn
     @require_torch_gpu
     @pytest.mark.flash_attn_test

tests/models/mistral/test_modeling_tf_mistral.py

Lines changed: 0 additions & 4 deletions
@@ -324,10 +324,6 @@ def test_Mistral_sequence_classification_model_for_multi_label(self):
         result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
         self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))

-    @unittest.skip("Mistral uses GQA on all models so the KV cache is a non standard format")
-    def test_past_key_values_format(self):
-        pass
-
     @unittest.skip("Vocab resizing is not supported")
     def test_save_load_after_resize_token_embeddings(self):
         pass
