@@ -1128,27 +1128,41 @@ def __init__(
11281128 self .multi_query = normalized_config .multi_query
11291129
11301130 def generate (self , input_name : str , framework : str = "pt" , int_dtype : str = "int64" , float_dtype : str = "fp32" ):
1131- if self .multi_query :
1132- past_key_value_shape = (
1133- self .batch_size ,
1134- self .sequence_length ,
1135- self .hidden_size // self .num_attention_heads * 2 ,
1136- )
1137- return [
1138- self .random_float_tensor (past_key_value_shape , framework = framework , dtype = float_dtype )
1139- for _ in range (self .num_layers )
1131+ if is_transformers_version ("<" , "4.54" ):
1132+ if self .multi_query :
1133+ shape = (
1134+ self .batch_size ,
1135+ self .sequence_length ,
1136+ self .hidden_size // self .num_attention_heads * 2 ,
1137+ )
1138+ else :
1139+ shape = (
1140+ self .batch_size ,
1141+ self .num_attention_heads ,
1142+ self .sequence_length ,
1143+ self .hidden_size // self .num_attention_heads * 2 ,
1144+ )
1145+ pkv = [
1146+ self .random_float_tensor (shape , framework = framework , dtype = float_dtype ) for _ in range (self .num_layers )
11401147 ]
1148+
11411149 else :
11421150 shape = (
11431151 self .batch_size ,
1144- self .num_attention_heads ,
1152+ self .num_attention_heads if not self . multi_query else 1 ,
11451153 self .sequence_length ,
1146- self .hidden_size // self .num_attention_heads * 2 ,
1154+ self .hidden_size // self .num_attention_heads ,
11471155 )
1148- return [
1149- self .random_float_tensor (shape , framework = framework , dtype = float_dtype ) for _ in range (self .num_layers )
1156+ pkv = [
1157+ (
1158+ self .random_float_tensor (shape , framework = framework , dtype = float_dtype ),
1159+ self .random_float_tensor (shape , framework = framework , dtype = float_dtype ),
1160+ )
1161+ for _ in range (self .num_layers )
11501162 ]
11511163
1164+ return pkv
1165+
11521166
11531167class BloomDummyPastKeyValuesGenerator (DummyPastKeyValuesGenerator ):
11541168 def generate (self , input_name : str , framework : str = "pt" , int_dtype : str = "int64" , float_dtype : str = "fp32" ):
@@ -1278,30 +1292,24 @@ def __init__(
12781292 random_sequence_length_range = random_sequence_length_range ,
12791293 ** kwargs ,
12801294 )
1281- self .num_kv_heads = self . num_kv_heads = (
1295+ self .num_kv_heads = (
12821296 normalized_config .num_kv_heads
12831297 if (normalized_config .new_decoder_architecture or not normalized_config .multi_query )
12841298 else 1
12851299 )
12861300 self .head_dim = self .hidden_size // self .num_attention_heads
12871301
12881302 def generate (self , input_name : str , framework : str = "pt" , int_dtype : str = "int64" , float_dtype : str = "fp32" ):
1289- past_key_shape = (
1290- self .batch_size ,
1291- self .num_kv_heads ,
1292- self .sequence_length ,
1293- self .head_dim ,
1294- )
1295- past_value_shape = (
1303+ shape = (
12961304 self .batch_size ,
12971305 self .num_kv_heads ,
12981306 self .sequence_length ,
12991307 self .head_dim ,
13001308 )
13011309 return [
13021310 (
1303- self .random_float_tensor (past_key_shape , framework = framework , dtype = float_dtype ),
1304- self .random_float_tensor (past_value_shape , framework = framework , dtype = float_dtype ),
1311+ self .random_float_tensor (shape , framework = framework , dtype = float_dtype ),
1312+ self .random_float_tensor (shape , framework = framework , dtype = float_dtype ),
13051313 )
13061314 for _ in range (self .num_layers )
13071315 ]
0 commit comments