@@ -1539,92 +1539,133 @@ def _prepare_model_kwargs(input_ids, attention_mask, signature):
             torch.testing.assert_close(next_logits_wo_padding, next_logits_with_padding, rtol=1e-5, atol=1e-5)
 
     @pytest.mark.generate
-    def test_past_key_values_format(self):
-        # Test that the KV cache is formatted correctly. Exceptions need to explicitly overwrite this test. Having a
-        # standard KV cache format is important for a consistent API (and for advanced generation methods).
+    def test_past_key_values_format(self, custom_all_cache_shapes=None):
+        """
+        Test that the KV cache is formatted correctly. Exceptions need to explicitly overwrite this test, or pass the
+        expected cache shapes.
+        Having a standard KV cache format is important for a consistent API (and for advanced generation methods).
+        """
         for model_class in self.all_generative_model_classes:
             config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
 
-            # If it doesn't support cache, pass the test
+            # 1. If it doesn't support cache, skip the test
             if not hasattr(config.get_text_config(), "use_cache"):
                 self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
 
             model = model_class(config).to(torch_device)
+            model = model.eval()
             if "use_cache" not in inputs:
                 inputs["use_cache"] = True
             outputs = model(**inputs)
 
-            # If "past_key_values" is not returned, pass the test (e.g. RWKV uses a different cache name and format)
             if "past_key_values" not in outputs:
                 self.skipTest(reason="This model doesn't return `past_key_values`")
 
+            # 2. retrieve the KV cache and compute its default expected shapes (if no custom shapes are provided)
+            past_kv = outputs["past_key_values"]
+            is_legacy_cache = not isinstance(past_kv, Cache)
+
             text_config = config.get_text_config()
-            num_hidden_layers = (
+            num_decoder_layers = (
                 getattr(text_config, "decoder_layers", None)
                 or getattr(text_config, "num_decoder_layers", None)
                 or text_config.num_hidden_layers
             )
-            num_attention_heads = getattr(text_config, "decoder_attention_heads", text_config.num_attention_heads)
-            embed_dim = getattr(text_config, "d_model", text_config.hidden_size)
-            per_head_embed_dim = embed_dim // num_attention_heads
-
-            # some models have different num-head for query vs key/value so we need to assign correct value
-            # BUT only after `per_head_embed_dim` is set
-            num_attention_heads = (
-                text_config.num_key_value_heads
-                if getattr(text_config, "num_key_value_heads", None) is not None
-                else num_attention_heads
-            )
 
-            past_kv = outputs["past_key_values"]
-            self.assertEqual(len(past_kv), num_hidden_layers)
+            if custom_all_cache_shapes is None:
+                num_query_attention_heads = getattr(
+                    text_config, "decoder_attention_heads", text_config.num_attention_heads
+                )
+                embed_dim = getattr(text_config, "d_model", text_config.hidden_size)
+                per_head_embed_dim = embed_dim // num_query_attention_heads
+                num_key_value_heads = (
+                    text_config.num_key_value_heads
+                    if getattr(text_config, "num_key_value_heads", None) is not None
+                    else num_query_attention_heads
+                )
+                if config.is_encoder_decoder:
+                    encoder_num_attention_heads = (
+                        text_config.encoder_attention_heads
+                        if hasattr(text_config, "encoder_attention_heads")
+                        else text_config.num_attention_heads
+                    )
+                    encoder_per_head_embed_dim = embed_dim // encoder_num_attention_heads
+                    batch_size, seq_length = inputs["decoder_input_ids"].shape
+                    # The sequence length for the encoder K V depends on the model. Since it is not manipulated in
+                    # autoregressive generation, we're keeping the test general and not checking the 3rd dim
+                    default_cross_attention_shape = (
+                        batch_size,
+                        encoder_num_attention_heads,
+                        encoder_per_head_embed_dim,
+                    )
+                    default_self_attention_shape = (batch_size, num_key_value_heads, seq_length, per_head_embed_dim)
+                    all_cache_shapes = [
+                        [
+                            default_self_attention_shape,
+                            default_self_attention_shape,
+                            default_cross_attention_shape,
+                            default_cross_attention_shape,
+                        ]
+                        for _ in range(num_decoder_layers)
+                    ]
+                else:
+                    batch_size, seq_length = inputs["input_ids"].shape
+                    default_self_attention_shape = (batch_size, num_key_value_heads, seq_length, per_head_embed_dim)
+                    all_cache_shapes = [
+                        [default_self_attention_shape, default_self_attention_shape] for _ in range(num_decoder_layers)
+                    ]
 
-            # Encoder-Decoder checks
+            else:
+                all_cache_shapes = custom_all_cache_shapes
+
+            # 3. Check cache shapes
+            # 3.1. Encoder-Decoder checks
             if config.is_encoder_decoder:
-                # encoder-decoder models usually don't have text config
-                # below is needed only for Pix2Struct which we cannot modify now due to BC
-                config = config.get_text_config()
-                encoder_num_attention_heads = (
-                    config.encoder_attention_heads
-                    if hasattr(config, "encoder_attention_heads")
-                    else config.num_attention_heads
+                num_cache_decoder_layers = (
+                    len(past_kv) if is_legacy_cache else len(past_kv.self_attention_cache.key_cache)
                 )
-                encoder_per_head_embed_dim = embed_dim // encoder_num_attention_heads
-                batch_size, seq_length = inputs["decoder_input_ids"].shape
-                for i in range(num_hidden_layers):
-                    self.assertEqual(len(past_kv[i]), 4)  # K V for the decoder + K V for the encoder = 4
-                    self.assertEqual(
-                        past_kv[i][0].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim)
+                self.assertEqual(num_cache_decoder_layers, num_decoder_layers)
+
+                for i in range(num_decoder_layers):
+                    if is_legacy_cache:
+                        self.assertEqual(len(past_kv[0]), 4)  # legacy check: confirm number of elements in tuple
+
+                    # Self attention
+                    self_attention_layer_key_cache = (
+                        past_kv[i][0] if is_legacy_cache else past_kv.self_attention_cache.key_cache[i]
                     )
-                    self.assertEqual(
-                        past_kv[i][1].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim)
+                    self_attention_layer_value_cache = (
+                        past_kv[i][1] if is_legacy_cache else past_kv.self_attention_cache.value_cache[i]
                     )
-                    # The sequence length for the encoder K V depends on the model. Since it is not manipulated in
-                    # autoregressive generation, I'm keeping the test general and not checking the 3rd dim
-                    self.assertEqual(
-                        (past_kv[i][2].shape[0], past_kv[i][2].shape[1], past_kv[i][2].shape[3]),
-                        (batch_size, encoder_num_attention_heads, encoder_per_head_embed_dim),
+                    self.assertEqual(self_attention_layer_key_cache.shape, all_cache_shapes[i][0])
+                    self.assertEqual(self_attention_layer_value_cache.shape, all_cache_shapes[i][1])
+
+                    # Cross attention (ignore 3rd dim, see default shape preparation)
+                    cross_attention_layer_key_cache = (
+                        past_kv[i][2] if is_legacy_cache else past_kv.cross_attention_cache.key_cache[i]
                     )
-                    self.assertEqual(
-                        (past_kv[i][3].shape[0], past_kv[i][3].shape[1], past_kv[i][3].shape[3]),
-                        (batch_size, encoder_num_attention_heads, encoder_per_head_embed_dim),
+                    cross_attention_layer_value_cache = (
+                        past_kv[i][3] if is_legacy_cache else past_kv.cross_attention_cache.value_cache[i]
                     )
+                    cross_attention_layer_key_cache = cross_attention_layer_key_cache[:, :, 0, :]
+                    cross_attention_layer_value_cache = cross_attention_layer_value_cache[:, :, 0, :]
+                    self.assertEqual(cross_attention_layer_key_cache.shape, all_cache_shapes[i][2])
+                    self.assertEqual(cross_attention_layer_value_cache.shape, all_cache_shapes[i][3])
 
-            # Decoder-only checks
+            # 3.2. Decoder-only checks
             else:
-                # TODO: this line is only needed because of imagegpt, where "pixel_values" = "input_ids". Fix the
-                # tests in imagegpt such that `prepare_config_and_inputs_for_common` returns the later (and the other
-                # tests use it)
-                key = "input_ids" if "input_ids" in inputs else "pixel_values"
-                batch_size, seq_length = inputs[key].shape
-                for i in range(num_hidden_layers):
-                    self.assertEqual(len(past_kv[0]), 2)  # K V for the decoder = 2
-                    self.assertEqual(
-                        past_kv[i][0].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim)
-                    )
-                    self.assertEqual(
-                        past_kv[i][1].shape, (batch_size, num_attention_heads, seq_length, per_head_embed_dim)
-                    )
+                num_cache_decoder_layers = len(past_kv) if is_legacy_cache else len(past_kv.key_cache)
+                self.assertEqual(num_cache_decoder_layers, num_decoder_layers)
+
+                for i in range(num_decoder_layers):
+                    if is_legacy_cache:
+                        self.assertEqual(len(past_kv[0]), 2)  # legacy check: confirm number of elements in tuple
+
+                    # Self attention
+                    self_attention_layer_key_cache = past_kv[i][0] if is_legacy_cache else past_kv.key_cache[i]
+                    self_attention_layer_value_cache = past_kv[i][1] if is_legacy_cache else past_kv.value_cache[i]
+                    self.assertEqual(self_attention_layer_key_cache.shape, all_cache_shapes[i][0])
+                    self.assertEqual(self_attention_layer_value_cache.shape, all_cache_shapes[i][1])
 
     @pytest.mark.generate
     @parameterized.expand([("greedy", 1), ("beam search", 2)])
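The new `custom_all_cache_shapes` argument lets a model-specific test class describe its expected per-layer cache shapes instead of rewriting the whole check. Below is a minimal sketch of such an override, not taken from the diff: it assumes a decoder-only model test module where `GenerationTesterMixin` and a `self.model_tester` fixture are already available (as in any transformers model test), and the `num_key_value_heads` / `head_dim` config attributes are illustrative assumptions that a real test would replace with the model's actual config fields. For encoder-decoder models each inner list would hold four entries per layer, with 3-tuples for the cross-attention caches since the encoder sequence length is not checked.

import pytest


class MyModelGenerationTest(GenerationTesterMixin):  # hypothetical; mixed into a real model test class
    @pytest.mark.generate
    def test_past_key_values_format(self):
        # Build the expected per-layer [key, value] shapes for a model whose KV head count or
        # head dim differ from the mixin's defaults, then delegate to the shared test above.
        config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
        text_config = config.get_text_config()
        batch_size, seq_length = inputs["input_ids"].shape

        # `num_key_value_heads` and `head_dim` are assumed attributes for this sketch.
        self_attention_shape = (
            batch_size,
            text_config.num_key_value_heads,
            seq_length,
            text_config.head_dim,
        )
        all_cache_shapes = [
            [self_attention_shape, self_attention_shape] for _ in range(text_config.num_hidden_layers)
        ]
        super().test_past_key_values_format(custom_all_cache_shapes=all_cache_shapes)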