@@ -2991,8 +2991,8 @@ def _aten_embedding_bag_onnx(
2991
2991
indices_1d = op .Reshape (indices , neg_1 )
2992
2992
# Get weight out according to indices_1d,
2993
2993
new_weight = op .Gather (weight , indices_1d )
2994
- # This happends after first step of Gather. Because Shape(indices)==Shape(per_sample_weights)
2995
- new_weight = op .Mul (new_weight , op .Unsqueeze (per_sample_weights , axes = 1 ))
2994
+ # This happens after first step of Gather. Because Shape(indices)==Shape(per_sample_weights)
2995
+ new_weight = op .Mul (new_weight , op .Unsqueeze (per_sample_weights , axes = [ 1 ] ))
2996
2996
weight_dim_1 = op .Reshape (op .Shape (weight , start = 1 ), neg_1 )
2997
2997
indices_size = op .Shape (indices_1d )
2998
2998
@@ -3131,8 +3131,8 @@ def _aten_embedding_bag_1d_padding_idx_onnx(
3131
3131
# Get weight out according to indices,
3132
3132
# e.g. indices=[3,1,4,5,3] means get weight[[3,1,4,5,3]]
3133
3133
indices_weight = op .Gather (weight , indices )
3134
- # This happends after first step of Gather. Because Shape(indices)==Shape(per_sample_weights)
3135
- indices_weight = op .Mul (indices_weight , op .Unsqueeze (per_sample_weights , axes = 1 ))
3134
+ # This happens after first step of Gather. Because Shape(indices)==Shape(per_sample_weights)
3135
+ indices_weight = op .Mul (indices_weight , op .Unsqueeze (per_sample_weights , axes = [ 1 ] ))
3136
3136
3137
3137
# The element in sequence must be FLOAT32 dtype due to ORT bug
3138
3138
indices_weight = op .Cast (indices_weight , to = FLOAT .dtype )
@@ -4145,7 +4145,6 @@ def _shape_of_broadcast_tensors(*args: TensorType) -> INT64:
4145
4145
return op .Shape (broadcasted )
4146
4146
4147
4147
4148
- @torch_op ("aten::index.Tensor" , private = True , trace_only = True )
4149
4148
def _aten_index_onnx (
4150
4149
self : TensorType ,
4151
4150
indices : Sequence [Optional [INT64 ]],
@@ -4173,7 +4172,7 @@ def _aten_index_onnx(
4173
4172
not_none_indices = [idx for idx in indices if idx is not None ]
4174
4173
broadcast_shape = _shape_of_broadcast_tensors (* not_none_indices )
4175
4174
final_index = op .Concat (
4176
- * (op .Unsqueeze (op .Expand (idx , broadcast_shape ), - 1 ) for idx in not_none_indices ),
4175
+ * (op .Unsqueeze (op .Expand (idx , broadcast_shape ), [ - 1 ] ) for idx in not_none_indices ),
4177
4176
axis = - 1 ,
4178
4177
)
4179
4178
@@ -7706,13 +7705,13 @@ def aten_select_backward(
7706
7705
raise NotImplementedError ()
7707
7706
7708
7707
7709
- @torch_op ("aten::select_scatter" )
7708
+ @torch_op ("aten::select_scatter" , trace_only = True )
7710
7709
def aten_select_scatter (self : TensorType , src : TensorType , dim : int , index : int ) -> TensorType :
7711
7710
"""select_scatter(Tensor self, Tensor src, int dim, int index) -> Tensor"""
7712
7711
7713
7712
# Change src rank to self rank according to dim
7714
7713
# e.g. if self is [2,3,4], src is [2,4], dim=1, then update is [2,1,4]
7715
- update = op .Unsqueeze (src , axes = dim )
7714
+ update = op .Unsqueeze (src , axes = [ dim ] )
7716
7715
# Change index rank to the same as 'update' [2,1,4]
7717
7716
indices = op .Expand (index , op .Shape (update ))
7718
7717
return op .ScatterElements (self , indices , update , axis = dim , reduction = "none" )
@@ -7880,7 +7879,7 @@ def aten_slice_scatter(
7880
7879
zero ,
7881
7880
op .Unsqueeze (step , zero ),
7882
7881
)
7883
- index_base = op .Unsqueeze (index_base , - 1 )
7882
+ index_base = op .Unsqueeze (index_base , [ - 1 ] )
7884
7883
7885
7884
# Use trace only to construct the perm attribute in Transpose
7886
7885
dims = None
@@ -8623,7 +8622,7 @@ def aten_unfold(self: TTensor, dimension: int, size: int, step: int) -> TTensor:
8623
8622
8624
8623
self_rank = len (self .shape )
8625
8624
if self_rank == 0 :
8626
- result = op .Unsqueeze (self , 0 )
8625
+ result = op .Unsqueeze (self , [ 0 ] )
8627
8626
else :
8628
8627
# Handle negative dimension
8629
8628
if dimension < 0 :
@@ -8792,8 +8791,7 @@ def aten_unsafe_split_with_sizes(
8792
8791
def aten_unsqueeze (self : TTensor , dim : int ) -> TTensor :
8793
8792
"""unsqueeze(Tensor(a) self, int dim) -> Tensor(a)"""
8794
8793
8795
- dim = op .Cast (dim , to = INT64 .dtype )
8796
- return op .Unsqueeze (self , dim )
8794
+ return op .Unsqueeze (self , [dim ])
8797
8795
8798
8796
8799
8797
def aten_unsqueeze_copy (self : TensorType , dim : int ) -> TensorType :
0 commit comments