@@ -3020,10 +3020,122 @@ def aten_imag(self: TensorType) -> TensorType:
30203020 raise NotImplementedError ()
30213021
30223022
3023- def aten_index (self : TensorType , indices : Optional [Sequence [TensorType ]]) -> TensorType :
3024- """index.Tensor(Tensor self, Tensor?[] indices) -> Tensor"""
3023+ def _are_consecutive (sorted_list : Sequence [int ]) -> bool :
3024+ """Returns True if a sorted list contains consecutive numbers."""
3025+ if not sorted_list :
3026+ return True
30253027
3026- raise NotImplementedError ()
3028+ return sorted_list == list (range (min (sorted_list ), max (sorted_list ) + 1 ))
3029+
3030+
3031+ def _has_none_in_middle (indices ) -> bool :
3032+ """Returns True if there is a None in the middle of the list."""
3033+ not_none_indices = [i for i , idx in enumerate (indices ) if idx is not None ]
3034+ return not _are_consecutive (not_none_indices )
3035+
3036+
3037+ def _shape_of_broadcast_tensors (* args : TensorType ) -> INT64 :
3038+ """Returns the broadcasted shape of the given tensors."""
3039+ broadcasted = op .Max (* args )
3040+ return op .Shape (broadcasted )
3041+
3042+
3043+ @torch_op ("aten::index" , trace_only = True )
3044+ def aten_index (self : TensorType , indices : Sequence [Optional [INT64 ]]) -> TensorType :
3045+ """index.Tensor(Tensor self, Tensor?[] indices) -> Tensor
3046+
3047+ NOTE: Understanding `aten::index`
3048+ For `arg0` with shape `[7, 3, 4, 5, 6]`
3049+ The indexing operation `arg0[0, :, 1:2, tensor([[4,5]])]` will be translated to
3050+
3051+ ```
3052+ +> select: i64[3, 4, 5, 6] = torch.ops.aten.select.int(arg0, 0, 0);
3053+ +> slice_1: i64[3, 4, 5, 6] = torch.ops.aten.slice.Tensor(select, 0, 0, 9223372036854775807);
3054+ +> slice_2: i64[3, 1, 5, 6] = torch.ops.aten.slice.Tensor(slice_1, 1, 1, 2);
3055+ +> index: i64[3, 1, 1, 2, 6] = torch.ops.aten.index.Tensor(slice_2, [None, None, arg1]);
3056+ ```
3057+
3058+ Here,
3059+ - `indices = [None, None, arg1]` is equivalent to `indices = [None, None, arg1, None]`
3060+ - The operation `arg0[0, :, 1:2, tensor([[4,5]])]` is equivalent to `arg0[0, :, 1:2, tensor([[4,5]]), :]`
3061+
3062+ None in `indices` are like fillers for dimensions that cannot be removed in the process.
3063+ """
3064+
3065+ self_rank = len (self .shape )
3066+ index_ranks = [len (index .shape ) for index in indices if index is not None ]
3067+ advanced_indexing_rank = max (index_ranks )
3068+
3069+ # reordered_positions is the permutation of the index positions where
3070+ # positions with None are move to the end of the list
3071+ # For example, if indices = [None, 1, None, 2], then reordered_positions = [1, 3, 0, 2]
3072+ reordered_positions = sorted (range (len (indices )), key = lambda i : (indices [i ] is None , i ))
3073+ # Fill the list with the remaining indices up to the rank of the tensor self.
3074+ # For example, if indices = [None, 1, None, 2], and the rank of self is 6,
3075+ # then reordered_positions = [1, 3, 0, 2, 4, 5]
3076+ reordered_positions = [
3077+ * reordered_positions ,
3078+ * range (len (reordered_positions ), self_rank ),
3079+ ]
3080+ # Transpose self according to the reordered positions
3081+ self = op .Transpose (self , perm = reordered_positions )
3082+
3083+ # Broadcast the indices to the same shape then concatenate
3084+ not_none_indices = [idx for idx in indices if idx is not None ]
3085+ broadcast_shape = _shape_of_broadcast_tensors (* not_none_indices )
3086+ final_index = op .Concat (
3087+ * (op .Unsqueeze (op .Expand (idx , broadcast_shape ), - 1 ) for idx in not_none_indices ),
3088+ axis = - 1 ,
3089+ )
3090+
3091+ self = op .GatherND (self , final_index , batch_dims = 0 )
3092+
3093+ if _has_none_in_middle (indices ):
3094+ # If there is None in the middle, Advanced Indexing cannot decide where to put
3095+ # the new dimensions. So it places them in the front, like GatherND does.
3096+ return self
3097+
3098+ # When the indices are consecutive, Advanced Indexing will place the new dimensions
3099+ # (aka. the broadcasted shape) in the middle, replacing the original [x1, ..., xk] axes.
3100+ #
3101+ # Input index axes (three parts):
3102+ # [
3103+ # x_None_front_1, ... x_None_front_m,
3104+ # x1, ..., xk,
3105+ # x_None_back_1, ..., x_None_back_m
3106+ # ]
3107+ # GatherND result axes:
3108+ # [
3109+ # *broadcasted_shape(x1, x2, ..., xk),
3110+ # x_None_front_1, ... x_None_front_m,
3111+ # x_None_back_1, ..., x_None_back_m
3112+ # ]
3113+ # (Transpose here)
3114+ # Advanced indexing result axes:
3115+ # [
3116+ # x_None_front_1, ... x_None_front_m,
3117+ # *brocasted_shape(x1, x2, ..., xk),
3118+ # x_None_back_1, ..., x_None_back_m
3119+ # ]
3120+ #
3121+ # Need to transpose the result of GatherND to match this axes ordering.
3122+ first_not_none_position = reordered_positions [0 ] # x_None_front_m + 1
3123+ starting_position_of_none_in_back = (
3124+ advanced_indexing_rank + first_not_none_position
3125+ ) # x_None_back_1
3126+ result_rank = self_rank - len (not_none_indices ) + advanced_indexing_rank
3127+ perm = [
3128+ * range (
3129+ advanced_indexing_rank , starting_position_of_none_in_back
3130+ ), # None_front_1...x_None_back_1
3131+ * range (0 , advanced_indexing_rank ), # 0...len(broadcasted_shape)
3132+ * range (
3133+ starting_position_of_none_in_back ,
3134+ result_rank ,
3135+ ), # None_back_1...None_back_m
3136+ ]
3137+
3138+ return op .Transpose (self , perm = perm )
30273139
30283140
30293141def aten_index_add (
0 commit comments