@@ -1647,7 +1647,8 @@ def aten_chunk(self: TTensor, chunks: int, dim: int = 0) -> Sequence[TTensor]:
1647
1647
dim_size = op .Gather (self_shape , dim , axis = 0 )
1648
1648
# Compute size/chunk to get the number of data in one chunk
1649
1649
num_per_chunk = op .Div (dim_size , chunks )
1650
- num_per_chunk = op .Cast (op .Mod (dim_size , chunks ) > 0 , to = INT64 .dtype ) + num_per_chunk # type: ignore[operator]
1650
+ num_per_chunk = op .Cast (op .Mod (dim_size , chunks ) > 0 , to = INT64 .dtype ) + \
1651
+ num_per_chunk # type: ignore[operator]
1651
1652
1652
1653
# Compute real chunk number
1653
1654
num_chunk = op .Div (dim_size , num_per_chunk )
@@ -4259,15 +4260,51 @@ def aten_index_put(
4259
4260
See implementation of `torch.onnx.symbolic_opset11.index_put
4260
4261
<https://github.com/pytorch/pytorch/blob/main/torch/onnx/symbolic_opset11.py#L212>`_.
4261
4262
"""
4263
+ # Pad indices with None so It has the same rank as self
4264
+ self_rank = len (self .shape )
4265
+ if len (indices ) < self_rank :
4266
+ indices = list (indices ) + [None ] * (self_rank - len (indices ))
4267
+
4268
+ values_shape = values .shape .numpy ()
4269
+ # Pad values_shape with 1 so It has the same rank as self
4270
+ if len (values_shape ) < self_rank :
4271
+ values_shape = (1 ,) * (self_rank - len (values_shape )) + values_shape
4272
+
4273
+ index_vectors = []
4274
+ for i , index in enumerate (indices ):
4275
+ if index is None :
4276
+ # For a full slice, create a range.
4277
+ index_vector = op .Range (start = 0 , limit = values_shape [i ], delta = 1 )
4278
+ else :
4279
+ index_vector = index
4262
4280
4263
- # TODO(justinchuby): Handle when indicies has more than one element
4264
- index = indices [0 ]
4265
- new_index = op .Unsqueeze (index , [- 1 ])
4281
+ # Shape vector with 1s, except at axis i.
4282
+ shape_vector = [1 ] * self_rank
4283
+ shape_vector [i ] = values_shape [i ]
4284
+
4285
+ # Reshape index_vector so that only the i-th dimension matches values_shape[i]
4286
+ reshaped_index_vector = op .Reshape (index_vector , shape_vector )
4287
+
4288
+ # Expand reshaped_index_vector to match the full shape of values
4289
+ expanded_index_vector = op .Expand (reshaped_index_vector , values_shape )
4290
+
4291
+ # Flatten into a 1D vector
4292
+ column_index_vector = op .Reshape (expanded_index_vector , [- 1 ])
4293
+
4294
+ # Convert into a column vector to prepare for concatenation
4295
+ column_index_vector = op .Unsqueeze (column_index_vector , axes = [1 ])
4296
+ index_vectors .append (column_index_vector )
4297
+
4298
+ # Contains all indices to be upadated
4299
+ new_index = op .Concat (* index_vectors , axis = 1 )
4300
+
4301
+ # Flatten values to match the indices
4302
+ flat_values = op .Reshape (values , [- 1 ])
4266
4303
4267
4304
if accumulate :
4268
- result = op .ScatterND (self , new_index , values , reduction = "add" )
4305
+ result = op .ScatterND (self , new_index , flat_values , reduction = "add" )
4269
4306
else :
4270
- result = op .ScatterND (self , new_index , values )
4307
+ result = op .ScatterND (self , new_index , flat_values )
4271
4308
4272
4309
return result
4273
4310
0 commit comments