@@ -1345,87 +1345,95 @@ def __init__(
13451345 )
13461346 self .sliding_window_size = getattr (normalized_config , "sliding_window" , sequence_length )
13471347
1348- def generate (self , input_name : str , framework : str = "pt" , int_dtype : str = "int64" , float_dtype : str = "fp32" ):
1349- if input_name in ["input_ids" , "token_type_ids" , "position_ids" ]:
1350- return super ().generate (
1351- input_name = input_name , framework = framework , int_dtype = int_dtype , float_dtype = float_dtype
1352- )
1353- if input_name == "attention_mask" :
1354- return {
1355- "full_causal_mask" : self ._generate_full_causal_mask (framework , float_dtype ),
1356- "sliding_causal_mask" : self ._generate_sliding_causal_mask (framework , float_dtype ),
1357- }
1358- # if input_name == "full_causal_mask":
1359- # return self._generate_full_causal_mask(framework, float_dtype)
1360- # elif input_name == "sliding_causal_mask":
1361- # return self._generate_sliding_causal_mask(framework, float_dtype)
1362- else :
1363- raise ValueError (f"What happened? This is not supported and should not be here: { input_name } " )
1364-
1365- def _generate_full_causal_mask (self , framework : str = "pt" , float_dtype : str = "float32" ):
1366- if framework == "pt" :
1367- mask = torch .triu (
1368- torch .ones ((self .sequence_length , self .sequence_length ), dtype = DTYPE_MAPPER .pt (float_dtype )),
1369- diagonal = 1 ,
1370- )
1371- mask = mask .masked_fill (mask == 1 , float ("-inf" ))
1372- mask = mask .unsqueeze (0 ).expand (self .batch_size , - 1 , - 1 )
1373- return mask
1374- elif framework == "tf" :
1375- mask = tf .linalg .band_part (
1376- tf .ones ((self .sequence_length , self .sequence_length ), dtype = DTYPE_MAPPER .tf (float_dtype )), - 1 , 0
1377- )
1378- mask = tf .where (mask == 0 , float ("-inf" ), 0.0 )
1379- mask = tf .expand_dims (mask , 0 )
1380- mask = tf .tile (mask , [self .batch_size , 1 , 1 ])
1381- return mask
1382- else :
1383- mask = np .triu (
1384- np .ones ((self .sequence_length , self .sequence_length ), dtype = DTYPE_MAPPER .np (float_dtype )), k = 1
1385- )
1386- mask = np .where (mask == 1 , float ("-inf" ), 0.0 )
1387- mask = np .expand_dims (mask , 0 )
1388- mask = np .tile (mask , (self .batch_size , 1 , 1 ))
1389- return mask
1390-
1391- def _generate_sliding_causal_mask (self , framework : str = "pt" , float_dtype : str = "fp32" ):
1392- if framework == "pt" :
1393- mask = torch .full (
1394- (self .sequence_length , self .sequence_length ), float ("-inf" ), dtype = DTYPE_MAPPER .pt (float_dtype )
1395- )
1396- for i in range (self .sequence_length ):
1397- start = max (0 , i - self .sliding_window_size + 1 )
1398- mask [i , start : i + 1 ] = 0.0
1399- mask = mask .unsqueeze (0 ).expand (self .batch_size , - 1 , - 1 )
1400- return mask
1401- elif framework == "tf" :
1402- mask = tf .fill ((self .sequence_length , self .sequence_length ), float ("-inf" ))
1403- mask = tf .cast (mask , DTYPE_MAPPER .tf (float_dtype ))
1404-
1405- updates = []
1406- indices = []
1407- for i in range (self .sequence_length ):
1408- start = max (0 , i - self .sliding_window_size + 1 )
1409- for j in range (start , i + 1 ):
1410- indices .append ([i , j ])
1411- updates .append (0.0 )
1412- if indices :
1413- indices = tf .constant (indices )
1414- updates = tf .constant (updates , dtype = DTYPE_MAPPER .tf (float_dtype ))
1415- mask = tf .tensor_scatter_nd_update (mask , indices , updates )
1416- mask = tf .expand_dims (mask , 0 )
1417- mask = tf .tile (mask , [self .batch_size , 1 , 1 ])
1418- return mask
1419- else :
1420- mask = np .full (
1421- (self .sequence_length , self .sequence_length ), float ("-inf" ), dtype = DTYPE_MAPPER .np (float_dtype )
1422- )
1423- for i in range (self .sequence_length ):
1424- start = max (0 , i - self .sliding_window_size + 1 )
1425- mask [i , start : i + 1 ] = 0.0
1426- mask = np .expand_dims (mask , 0 )
1427- mask = np .tile (mask , (self .batch_size , 1 , 1 ))
1428- return mask
1348+ # def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
1349+ # if input_name in ["input_ids", "token_type_ids", "position_ids"]:
1350+ # return super().generate(
1351+ # input_name=input_name, framework=framework, int_dtype=int_dtype, float_dtype=float_dtype
1352+ # )
1353+ # if input_name == "attention_mask":
1354+ # return {
1355+ # "full_attention": self._generate_full_causal_mask(framework, float_dtype),
1356+ # "sliding_attention": self._generate_sliding_causal_mask(framework, float_dtype),
1357+ # }
1358+ # # if input_name == "full_causal_mask":
1359+ # # return self._generate_full_causal_mask(framework, float_dtype)
1360+ # # elif input_name == "sliding_causal_mask":
1361+ # # return self._generate_sliding_causal_mask(framework, float_dtype)
1362+ # else:
1363+ # raise ValueError(f"What happened? This is not supported and should not be here: {input_name}")
1364+
1365+ # def _generate_full_causal_mask(self, framework: str = "pt", float_dtype: str = "float32"):
1366+ # if framework == "pt":
1367+ # row_indices = torch.arange(self.sequence_length).view(-1, 1)
1368+ # col_indices = torch.arange(self.sequence_length).view(1, -1)
1369+ # causal_mask = row_indices >= col_indices
1370+ # dtype = getattr(torch, float_dtype)
1371+ # mask = torch.zeros((self.sequence_length, self.sequence_length), dtype=dtype)
1372+ # mask[~causal_mask] = float("-inf")
1373+ # mask = mask.unsqueeze(0).expand(self.batch_size, -1, -1)
1374+ # return mask
1375+ # elif framework == "tf":
1376+ # row_indices, col_indices = tf.meshgrid(
1377+ # tf.range(self.sequence_length), tf.range(self.sequence_length), indexing="ij"
1378+ # )
1379+ # causal_mask = row_indices >= col_indices
1380+ # dtype = getattr(tf, float_dtype)
1381+ # mask = tf.where(
1382+ # causal_mask,
1383+ # tf.zeros((self.sequence_length, self.sequence_length), dtype=dtype),
1384+ # tf.fill((self.sequence_length, self.sequence_length), float("-inf")),
1385+ # )
1386+ # mask = tf.expand_dims(mask, 0)
1387+ # mask = tf.tile(mask, [self.batch_size, 1, 1])
1388+ # return mask
1389+
1390+ # else:
1391+ # row_indices = np.arange(self.sequence_length).reshape(-1, 1)
1392+ # col_indices = np.arange(self.sequence_length).reshape(1, -1)
1393+ # causal_mask = row_indices >= col_indices
1394+ # dtype = getattr(np, float_dtype)
1395+ # mask = np.full((self.sequence_length, self.sequence_length), float("-inf"), dtype=dtype)
1396+ # mask[causal_mask] = 0.0
1397+ # mask = np.expand_dims(mask, 0)
1398+ # mask = np.repeat(mask, self.batch_size, axis=0)
1399+ # return mask
1400+
1401+ # def _generate_sliding_causal_mask(self, window_size: int, framework: str = "pt", float_dtype: str = "float32"):
1402+ # if framework == "pt":
1403+ # row_indices = torch.arange(self.sequence_length).view(-1, 1)
1404+ # col_indices = torch.arange(self.sequence_length).view(1, -1)
1405+ # causal_mask = (row_indices >= col_indices) & (row_indices - col_indices < window_size)
1406+ # dtype = getattr(torch, float_dtype)
1407+ # mask = torch.zeros((self.sequence_length, self.sequence_length), dtype=dtype)
1408+ # mask[~causal_mask] = float("-inf")
1409+ # mask = mask.unsqueeze(0).expand(self.batch_size, -1, -1)
1410+ # return mask
1411+ # elif framework == "tf":
1412+ # row_indices, col_indices = tf.meshgrid(
1413+ # tf.range(self.sequence_length), tf.range(self.sequence_length), indexing="ij"
1414+ # )
1415+ # causal_condition = row_indices >= col_indices
1416+ # window_condition = (row_indices - col_indices) < window_size
1417+ # sliding_mask = causal_condition & window_condition
1418+ # dtype = getattr(tf, float_dtype)
1419+ # mask = tf.where(
1420+ # sliding_mask,
1421+ # tf.zeros((self.sequence_length, self.sequence_length), dtype=dtype),
1422+ # tf.fill((self.sequence_length, self.sequence_length), float("-inf")),
1423+ # )
1424+ # mask = tf.expand_dims(mask, 0)
1425+ # mask = tf.tile(mask, [self.batch_size, 1, 1])
1426+ # return mask
1427+ # else:
1428+ # row_indices = np.arange(self.sequence_length).reshape(-1, 1)
1429+ # col_indices = np.arange(self.sequence_length).reshape(1, -1)
1430+ # causal_mask = (row_indices >= col_indices) & (row_indices - col_indices < window_size)
1431+ # dtype = getattr(np, float_dtype)
1432+ # mask = np.full((self.sequence_length, self.sequence_length), float("-inf"), dtype=dtype)
1433+ # mask[causal_mask] = 0.0
1434+ # mask = np.expand_dims(mask, 0)
1435+ # mask = np.repeat(mask, self.batch_size, axis=0)
1436+ # return mask
14291437
14301438
14311439class DummySpeechT5InputGenerator (DummyInputGenerator ):
0 commit comments