@@ -3020,10 +3020,122 @@ def aten_imag(self: TensorType) -> TensorType:
    raise NotImplementedError()


-def aten_index(self: TensorType, indices: Optional[Sequence[TensorType]]) -> TensorType:
-    """index.Tensor(Tensor self, Tensor?[] indices) -> Tensor"""
-
-    raise NotImplementedError()
+def _are_consecutive(sorted_list: Sequence[int]) -> bool:
+    """Returns True if a sorted list contains consecutive numbers."""
+    if not sorted_list:
+        return True
+
+    return sorted_list == list(range(min(sorted_list), max(sorted_list) + 1))
+
+
+def _has_none_in_middle(indices) -> bool:
+    """Returns True if there is a None in the middle of the list."""
+    not_none_indices = [i for i, idx in enumerate(indices) if idx is not None]
+    return not _are_consecutive(not_none_indices)
+
+
+def _shape_of_broadcast_tensors(*args: TensorType) -> INT64:
+    """Returns the broadcasted shape of the given tensors."""
+    broadcasted = op.Max(*args)
+    return op.Shape(broadcasted)
+
+
+@torch_op("aten::index", trace_only=True)
+def aten_index(self: TensorType, indices: Sequence[Optional[INT64]]) -> TensorType:
+    """index.Tensor(Tensor self, Tensor?[] indices) -> Tensor
+
+    NOTE: Understanding `aten::index`
+    For `arg0` with shape `[7, 3, 4, 5, 6]`, the indexing operation
+    `arg0[0, :, 1:2, tensor([[4,5]])]` will be translated to
+
+    ```
+    +> select: i64[3, 4, 5, 6] = torch.ops.aten.select.int(arg0, 0, 0);
+    +> slice_1: i64[3, 4, 5, 6] = torch.ops.aten.slice.Tensor(select, 0, 0, 9223372036854775807);
+    +> slice_2: i64[3, 1, 5, 6] = torch.ops.aten.slice.Tensor(slice_1, 1, 1, 2);
+    +> index: i64[3, 1, 1, 2, 6] = torch.ops.aten.index.Tensor(slice_2, [None, None, arg1]);
+    ```
+
+    Here,
+    - `indices = [None, None, arg1]` is equivalent to `indices = [None, None, arg1, None]`
+    - The operation `arg0[0, :, 1:2, tensor([[4,5]])]` is equivalent to `arg0[0, :, 1:2, tensor([[4,5]]), :]`
+
+    The `None`s in `indices` act as fillers for dimensions that cannot be removed in the process.
+    """
+
+    self_rank = len(self.shape)
+    index_ranks = [len(index.shape) for index in indices if index is not None]
+    advanced_indexing_rank = max(index_ranks)
+
+    # reordered_positions is the permutation of the index positions where
+    # positions with None are moved to the end of the list.
+    # For example, if indices = [None, 1, None, 2], then reordered_positions = [1, 3, 0, 2]
+    reordered_positions = sorted(range(len(indices)), key=lambda i: (indices[i] is None, i))
+    # Fill the list with the remaining indices up to the rank of the tensor self.
+    # For example, if indices = [None, 1, None, 2] and the rank of self is 6,
+    # then reordered_positions = [1, 3, 0, 2, 4, 5]
+    reordered_positions = [
+        *reordered_positions,
+        *range(len(reordered_positions), self_rank),
+    ]
+    # Transpose self according to the reordered positions
+    self = op.Transpose(self, perm=reordered_positions)
+
+    # Broadcast the indices to the same shape, then concatenate them
+    not_none_indices = [idx for idx in indices if idx is not None]
+    broadcast_shape = _shape_of_broadcast_tensors(*not_none_indices)
+    final_index = op.Concat(
+        *(op.Unsqueeze(op.Expand(idx, broadcast_shape), -1) for idx in not_none_indices),
+        axis=-1,
+    )
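+    # final_index now has shape [*broadcast_shape, len(not_none_indices)]; its
+    # last axis packs one coordinate per indexed dimension, as GatherND expects.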
+
+    self = op.GatherND(self, final_index, batch_dims=0)
+
+    if _has_none_in_middle(indices):
+        # If there is None in the middle, Advanced Indexing cannot decide where to put
+        # the new dimensions. So it places them in the front, like GatherND does.
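+        # This matches PyTorch: e.g. with 1-D index tensors i and j, x[i, :, j]
+        # places the broadcasted index axis first, giving shape [len(i), x.shape[1]].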
3096
+ return self
+
+    # When the indices are consecutive, Advanced Indexing will place the new dimensions
+    # (aka. the broadcasted shape) in the middle, replacing the original [x1, ..., xk] axes.
+    #
+    # Input index axes (three parts):
+    # [
+    #     x_None_front_1, ..., x_None_front_m,
+    #     x1, ..., xk,
+    #     x_None_back_1, ..., x_None_back_m
+    # ]
+    # GatherND result axes:
+    # [
+    #     *broadcasted_shape(x1, x2, ..., xk),
+    #     x_None_front_1, ..., x_None_front_m,
+    #     x_None_back_1, ..., x_None_back_m
+    # ]
+    # (Transpose here)
+    # Advanced indexing result axes:
+    # [
+    #     x_None_front_1, ..., x_None_front_m,
+    #     *broadcasted_shape(x1, x2, ..., xk),
+    #     x_None_back_1, ..., x_None_back_m
+    # ]
+    #
+    # We need to transpose the result of GatherND to match this axis ordering.
+    first_not_none_position = reordered_positions[0]  # x_None_front_m + 1
+    starting_position_of_none_in_back = (
+        advanced_indexing_rank + first_not_none_position
+    )  # x_None_back_1
+    result_rank = self_rank - len(not_none_indices) + advanced_indexing_rank
+    perm = [
+        *range(
+            advanced_indexing_rank, starting_position_of_none_in_back
+        ),  # x_None_front_1 ... x_None_front_m
+        *range(0, advanced_indexing_rank),  # 0 ... len(broadcasted_shape)
+        *range(
+            starting_position_of_none_in_back,
+            result_rank,
+        ),  # x_None_back_1 ... x_None_back_m
+    ]
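+    # Worked example (illustrative): for indices = [None, i, j] with 1-D i and j
+    # and a rank-4 self, reordered_positions = [1, 2, 0, 3], advanced_indexing_rank = 1,
+    # and starting_position_of_none_in_back = 2, so perm = [1, 0, 2]: the broadcasted
+    # axis moves from the front back to position 1, where Advanced Indexing puts it.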
+
+    return op.Transpose(self, perm=perm)


def aten_index_add(