@@ -4058,12 +4058,9 @@ def aten_index_put_bool(
4058
4058
# change array([F,F,T,F,F]) to array([2])
4059
4059
index = op .ArgMax (index_int ) # assume index only have 1 True
4060
4060
# change array([2]) to array([2,2,2,2,2])
4061
- self_dim_1 = op .Gather (op .Shape (self ), 1 )
4062
- index_dim_0 = op .Gather (op .Shape (index ), 0 )
4063
- neg_1 = op .Constant (value_ints = [- 1 ])
4064
- shape = op .Concat (
4065
- op .Reshape (self_dim_1 , neg_1 ), op .Reshape (index_dim_0 , neg_1 ), axis = 0
4066
- )
4061
+ self_dim_1 = op .Shape (self , start = 1 , end = 2 )
4062
+ index_dim_0 = op .Shape (index , start = 0 , end = 1 )
4063
+ shape = op .Concat (self_dim_1 , index_dim_0 , axis = 0 )
4067
4064
new_ind = op .Expand (index , shape )
4068
4065
new_ind_t = op .Transpose (new_ind )
4069
4066
@@ -7512,7 +7509,7 @@ def _center_window_around_zeros_if_needed(
7512
7509
window : TFloatOrBFloat16 , n_fft : int
7513
7510
) -> TFloatOrBFloat16 :
7514
7511
# first dimension
7515
- n_win = op .Gather ( op . Shape (window ), 0 )
7512
+ n_win = op .Shape (window , start = 0 , end = 1 )
7516
7513
# Center window around zeros if needed (required by ONNX's STFT)
7517
7514
if n_win < n_fft :
7518
7515
left = (n_fft - n_win ) / 2
0 commit comments