@@ -3020,10 +3020,124 @@ def aten_imag(self: TensorType) -> TensorType:
30203020 raise NotImplementedError ()
30213021
30223022
3023- def aten_index (self : TensorType , indices : Optional [Sequence [TensorType ]]) -> TensorType :
3024- """index.Tensor(Tensor self, Tensor?[] indices) -> Tensor"""
3023+ def _are_consecutive (sorted_list : Sequence [int ]) -> bool :
3024+ """Returns True if a sorted list contains consecutive numbers."""
3025+ if not sorted_list :
3026+ return True
30253027
3026- raise NotImplementedError ()
3028+ return sorted_list == list (range (min (sorted_list ), max (sorted_list ) + 1 ))
3029+
3030+
3031+ def _has_none_in_middle (indices ) -> bool :
3032+ """Returns True if there is a None in the middle of the list."""
3033+ not_none_indices = [i for i , idx in enumerate (indices ) if idx is not None ]
3034+ return not _are_consecutive (not_none_indices )
3035+
3036+
3037+ def _shape_of_broadcast_tensors (* args : TensorType ) -> INT64 :
3038+ """Returns the broadcasted shape of the given tensors."""
3039+ broadcasted = op .Max (* args )
3040+ return op .Shape (broadcasted )
3041+
3042+
3043+ @torch_op ("aten::index" , trace_only = True )
3044+ def aten_index (self : TensorType , indices : Sequence [Optional [INT64 ]]) -> TensorType :
3045+ """index.Tensor(Tensor self, Tensor?[] indices) -> Tensor
3046+
3047+ NOTE: Understanding `aten::index`
3048+ For `arg0` with shape `[7, 3, 4, 5, 6]`
3049+ The indexing operation `arg0[0, :, 1:2, tensor([[4,5]])]` will be translated to
3050+
3051+ ```
3052+ +> select: i64[3, 4, 5, 6] = torch.ops.aten.select.int(arg0, 0, 0);
3053+ +> slice_1: i64[3, 4, 5, 6] = torch.ops.aten.slice.Tensor(select, 0, 0, 9223372036854775807);
3054+ +> slice_2: i64[3, 1, 5, 6] = torch.ops.aten.slice.Tensor(slice_1, 1, 1, 2);
3055+ +> index: i64[3, 1, 1, 2, 6] = torch.ops.aten.index.Tensor(slice_2, [None, None, arg1]);
3056+ ```
3057+
3058+ Here,
3059+ - `indices = [None, None, arg1]` is equivalent to `indices = [None, None, arg1, None]`
3060+ - The operation `arg0[0, :, 1:2, tensor([[4,5]])]` is equivalent to `arg0[0, :, 1:2, tensor([[4,5]]), :]`
3061+
3062+ None in `indices` are like fillers for dimensions that cannot be removed in the process.
3063+ """
3064+
3065+ self_rank = len (self .shape )
3066+ index_ranks = [len (index .shape ) for index in indices if index is not None ]
3067+ print ("index_ranks: " , index_ranks )
3068+ print ("indices: " , indices )
3069+ advanced_indexing_rank = max (index_ranks )
3070+
3071+ # reordered_positions is the permutation of the index positions where
3072+ # positions with None are move to the end of the list
3073+ # For example, if indices = [None, 1, None, 2], then reordered_positions = [1, 3, 0, 2]
3074+ reordered_positions = sorted (range (len (indices )), key = lambda i : (indices [i ] is None , i ))
3075+ # Fill the list with the remaining indices up to the rank of the tensor self.
3076+ # For example, if indices = [None, 1, None, 2], and the rank of self is 6,
3077+ # then reordered_positions = [1, 3, 0, 2, 4, 5]
3078+ reordered_positions = [
3079+ * reordered_positions ,
3080+ * range (len (reordered_positions ), self_rank ),
3081+ ]
3082+ # Transpose self according to the reordered positions
3083+ self = op .Transpose (self , perm = reordered_positions )
3084+
3085+ # Broadcast the indices to the same shape then concatenate
3086+ not_none_indices = [idx for idx in indices if idx is not None ]
3087+ broadcast_shape = _shape_of_broadcast_tensors (* not_none_indices )
3088+ final_index = op .Concat (
3089+ * (op .Unsqueeze (op .Expand (idx , broadcast_shape ), - 1 ) for idx in not_none_indices ),
3090+ axis = - 1 ,
3091+ )
3092+
3093+ self = op .GatherND (self , final_index , batch_dims = 0 )
3094+
3095+ if _has_none_in_middle (indices ):
3096+ # If there is None in the middle, Advanced Indexing cannot decide where to put
3097+ # the new dimensions. So it places them in the front, like GatherND does.
3098+ return self
3099+
3100+ # When the indices are consecutive, Advanced Indexing will place the new dimensions
3101+ # (aka. the broadcasted shape) in the middle, replacing the original [x1, ..., xk] axes.
3102+ #
3103+ # Input index axes (three parts):
3104+ # [
3105+ # x_None_front_1, ... x_None_front_m,
3106+ # x1, ..., xk,
3107+ # x_None_back_1, ..., x_None_back_m
3108+ # ]
3109+ # GatherND result axes:
3110+ # [
3111+ # *broadcasted_shape(x1, x2, ..., xk),
3112+ # x_None_front_1, ... x_None_front_m,
3113+ # x_None_back_1, ..., x_None_back_m
3114+ # ]
3115+ # (Transpose here)
3116+ # Advanced indexing result axes:
3117+ # [
3118+ # x_None_front_1, ... x_None_front_m,
3119+ # *brocasted_shape(x1, x2, ..., xk),
3120+ # x_None_back_1, ..., x_None_back_m
3121+ # ]
3122+ #
3123+ # Need to transpose the result of GatherND to match this axes ordering.
3124+ first_not_none_position = reordered_positions [0 ] # x_None_front_m + 1
3125+ starting_position_of_none_in_back = (
3126+ advanced_indexing_rank + first_not_none_position
3127+ ) # x_None_back_1
3128+ result_rank = self_rank - len (not_none_indices ) + advanced_indexing_rank
3129+ perm = [
3130+ * range (
3131+ advanced_indexing_rank , starting_position_of_none_in_back
3132+ ), # None_front_1...x_None_back_1
3133+ * range (0 , advanced_indexing_rank ), # 0...len(broadcasted_shape)
3134+ * range (
3135+ starting_position_of_none_in_back ,
3136+ result_rank ,
3137+ ), # None_back_1...None_back_m
3138+ ]
3139+
3140+ return op .Transpose (self , perm = perm )
30273141
30283142
30293143def aten_index_add (
0 commit comments