Skip to content

Commit f1f6ae4

Browse files
Fix gpt_bigcode input generator for transformers 4.54 (#2336)
* Fix gpt_bigcode input generator for transformers 4.54 * style --------- Co-authored-by: IlyasMoutawwakil <moutawwakil.ilyas.tsi@gmail.com>
1 parent 5a63cee commit f1f6ae4

1 file changed

Lines changed: 31 additions & 23 deletions

File tree

optimum/utils/input_generators.py

Lines changed: 31 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1128,27 +1128,41 @@ def __init__(
11281128
self.multi_query = normalized_config.multi_query
11291129

11301130
def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
1131-
if self.multi_query:
1132-
past_key_value_shape = (
1133-
self.batch_size,
1134-
self.sequence_length,
1135-
self.hidden_size // self.num_attention_heads * 2,
1136-
)
1137-
return [
1138-
self.random_float_tensor(past_key_value_shape, framework=framework, dtype=float_dtype)
1139-
for _ in range(self.num_layers)
1131+
if is_transformers_version("<", "4.54"):
1132+
if self.multi_query:
1133+
shape = (
1134+
self.batch_size,
1135+
self.sequence_length,
1136+
self.hidden_size // self.num_attention_heads * 2,
1137+
)
1138+
else:
1139+
shape = (
1140+
self.batch_size,
1141+
self.num_attention_heads,
1142+
self.sequence_length,
1143+
self.hidden_size // self.num_attention_heads * 2,
1144+
)
1145+
pkv = [
1146+
self.random_float_tensor(shape, framework=framework, dtype=float_dtype) for _ in range(self.num_layers)
11401147
]
1148+
11411149
else:
11421150
shape = (
11431151
self.batch_size,
1144-
self.num_attention_heads,
1152+
self.num_attention_heads if not self.multi_query else 1,
11451153
self.sequence_length,
1146-
self.hidden_size // self.num_attention_heads * 2,
1154+
self.hidden_size // self.num_attention_heads,
11471155
)
1148-
return [
1149-
self.random_float_tensor(shape, framework=framework, dtype=float_dtype) for _ in range(self.num_layers)
1156+
pkv = [
1157+
(
1158+
self.random_float_tensor(shape, framework=framework, dtype=float_dtype),
1159+
self.random_float_tensor(shape, framework=framework, dtype=float_dtype),
1160+
)
1161+
for _ in range(self.num_layers)
11501162
]
11511163

1164+
return pkv
1165+
11521166

11531167
class BloomDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
11541168
def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
@@ -1278,30 +1292,24 @@ def __init__(
12781292
random_sequence_length_range=random_sequence_length_range,
12791293
**kwargs,
12801294
)
1281-
self.num_kv_heads = self.num_kv_heads = (
1295+
self.num_kv_heads = (
12821296
normalized_config.num_kv_heads
12831297
if (normalized_config.new_decoder_architecture or not normalized_config.multi_query)
12841298
else 1
12851299
)
12861300
self.head_dim = self.hidden_size // self.num_attention_heads
12871301

12881302
def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
1289-
past_key_shape = (
1290-
self.batch_size,
1291-
self.num_kv_heads,
1292-
self.sequence_length,
1293-
self.head_dim,
1294-
)
1295-
past_value_shape = (
1303+
shape = (
12961304
self.batch_size,
12971305
self.num_kv_heads,
12981306
self.sequence_length,
12991307
self.head_dim,
13001308
)
13011309
return [
13021310
(
1303-
self.random_float_tensor(past_key_shape, framework=framework, dtype=float_dtype),
1304-
self.random_float_tensor(past_value_shape, framework=framework, dtype=float_dtype),
1311+
self.random_float_tensor(shape, framework=framework, dtype=float_dtype),
1312+
self.random_float_tensor(shape, framework=framework, dtype=float_dtype),
13051313
)
13061314
for _ in range(self.num_layers)
13071315
]

0 commit comments

Comments
 (0)