@@ -4259,15 +4259,48 @@ def aten_index_put(
4259
4259
See implementation of `torch.onnx.symbolic_opset11.index_put
4260
4260
<https://github.com/pytorch/pytorch/blob/main/torch/onnx/symbolic_opset11.py#L212>`_.
4261
4261
"""
4262
+ # Pad indices with None so It has the same rank as self
4263
+ self_rank = len (self .shape )
4264
+ if len (indices ) < self_rank :
4265
+ indices = list (indices ) + [None ] * (self_rank - len (indices ))
4262
4266
4263
- # TODO(justinchuby): Handle when indicies has more than one element
4264
- index = indices [0 ]
4265
- new_index = op .Unsqueeze (index , [- 1 ])
4267
+ values_shape = values .shape .numpy ()
4268
+
4269
+ index_vectors = []
4270
+ for i , index in enumerate (indices ):
4271
+ if index is None :
4272
+ # For a full slice, create a range.
4273
+ index_vector = op .Range (start = 0 , limit = values_shape [i ], delta = 1 )
4274
+ else :
4275
+ index_vector = index
4276
+
4277
+ # Shape vector with 1s, except at axis i.
4278
+ shape_vector = [1 ] * self_rank
4279
+ shape_vector [i ] = values_shape [i ]
4280
+
4281
+ # Reshape index_vector so that only the i-th dimension matches values_shape[i]
4282
+ reshaped_index_vector = op .Reshape (index_vector , shape_vector )
4283
+
4284
+ # Expand reshaped_index_vector to match the full shape of values
4285
+ expanded_index_vector = op .Expand (reshaped_index_vector , values_shape )
4286
+
4287
+ # Flatten into a 1D vector
4288
+ column_index_vector = op .Reshape (expanded_index_vector , [- 1 ])
4289
+
4290
+ # Convert into a column vector to prepare for concatenation
4291
+ column_index_vector = op .Unsqueeze (column_index_vector , axes = [1 ])
4292
+ index_vectors .append (column_index_vector )
4293
+
4294
+ # Contains all indices to be upadated
4295
+ new_index = op .Concat (* index_vectors , axis = 1 )
4296
+
4297
+ # Flatten values to match the indices
4298
+ flat_values = op .Reshape (values , [- 1 ])
4266
4299
4267
4300
if accumulate :
4268
- result = op .ScatterND (self , new_index , values , reduction = "add" )
4301
+ result = op .ScatterND (self , new_index , flat_values , reduction = "add" )
4269
4302
else :
4270
- result = op .ScatterND (self , new_index , values )
4303
+ result = op .ScatterND (self , new_index , flat_values )
4271
4304
4272
4305
return result
4273
4306
0 commit comments