@@ -2700,15 +2700,74 @@ def aten_index_copy(
27002700 raise NotImplementedError ()
27012701
27022702
2703+ @torch_op ("aten::index_put" )
27032704def aten_index_put (
2704- self : TensorType ,
2705- indices : Optional [ Sequence [TensorType ] ],
2706- values : TensorType ,
2705+ self : TReal ,
2706+ indices : Sequence [INT64 ],
2707+ values : TReal ,
27072708 accumulate : bool = False ,
2708- ) -> TensorType :
2709+ ) -> TReal :
27092710 """index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor"""
27102711
2711- raise NotImplementedError ()
2712+ index = op .SequenceAt (indices , 0 ) # assume indices only have 1 element
2713+ # change array([1,3]) to array([[1,1,1,1,1],[3,3,3,3,3]])
2714+ self_dim_1 = op .Gather (op .Shape (self ), 1 )
2715+ index_dim_0 = op .Gather (op .Shape (index ), 0 )
2716+ neg_1 = op .Constant (value_ints = [- 1 ])
2717+ shape = op .Concat (op .Reshape (self_dim_1 , neg_1 ), op .Reshape (index_dim_0 , neg_1 ), axis = 0 )
2718+ new_ind = op .Expand (index , shape )
2719+ new_ind_t = op .Transpose (new_ind )
2720+
2721+ if op .Cast (accumulate , to = BOOL .dtype ):
2722+ # put values into zeros array first, then add to input
2723+ zeros = op .Expand (op .Constant (value_float = 0.0 ), op .Shape (self ))
2724+ result = op .ScatterElements (zeros , new_ind_t , values )
2725+ result = op .Add (result , self )
2726+ else :
2727+ result = op .ScatterElements (self , new_ind_t , values )
2728+ return result
2729+
2730+
2731+ @torch_op ("aten::index_put_bool" , overload = True )
2732+ def aten_index_put_bool (
2733+ self : TReal ,
2734+ indices : Sequence [BOOL ],
2735+ values : TReal ,
2736+ accumulate : bool = False ,
2737+ ) -> TReal :
2738+ """index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor"""
2739+
2740+ index = op .SequenceAt (indices , 0 ) # assume indices only have 1 element
2741+ # FIXME: ORT ArgMax fails on INT64 input even though ONNX allows it
2742+ index_int = op .Cast (index , to = INT32 .dtype )
2743+ # if all False, return self
2744+ if op .ReduceSum (index_int ) == 0 :
2745+ result = self
2746+ else :
2747+ # change array([F,F,T,F,F]) to array([2])
2748+ index = op .ArgMax (index_int ) # assume index only have 1 True
2749+ # change array([2]) to array([2,2,2,2,2])
2750+ self_dim_1 = op .Gather (op .Shape (self ), 1 )
2751+ index_dim_0 = op .Gather (op .Shape (index ), 0 )
2752+ neg_1 = op .Constant (value_ints = [- 1 ])
2753+ shape = op .Concat (
2754+ op .Reshape (self_dim_1 , neg_1 ), op .Reshape (index_dim_0 , neg_1 ), axis = 0
2755+ )
2756+ new_ind = op .Expand (index , shape )
2757+ new_ind_t = op .Transpose (new_ind )
2758+
2759+ # values must have same rank with input(self)
2760+ if op .Size (op .Shape (values )) < op .Size (op .Shape (self )): # type: ignore[operator]
2761+ values = op .Unsqueeze (values , op .Constant (value_ints = [0 ]))
2762+
2763+ if op .Cast (accumulate , to = BOOL .dtype ):
2764+ zeros = op .Expand (op .Constant (value_float = 0.0 ), op .Shape (self ))
2765+ result = op .ScatterElements (zeros , new_ind_t , values )
2766+ result = op .Add (result , self )
2767+ else :
2768+ result = op .ScatterElements (self , new_ind_t , values )
2769+
2770+ return result
27122771
27132772
27142773def aten_index_reduce (
0 commit comments