Skip to content

Commit 858d31a

Browse files
author
AyoubMDL
committed
fix(index_put): handle None cases
1 parent 5c31a7e commit 858d31a

File tree

1 file changed

+38
-5
lines changed
  • onnxscript/function_libs/torch_lib/ops

1 file changed

+38
-5
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4259,15 +4259,48 @@ def aten_index_put(
42594259
See implementation of `torch.onnx.symbolic_opset11.index_put
42604260
<https://github.com/pytorch/pytorch/blob/main/torch/onnx/symbolic_opset11.py#L212>`_.
42614261
"""
4262+
# Pad indices with None so It has the same rank as self
4263+
self_rank = len(self.shape)
4264+
if len(indices) < self_rank:
4265+
indices = list(indices) + [None] * (self_rank - len(indices))
42624266

4263-
# TODO(justinchuby): Handle when indicies has more than one element
4264-
index = indices[0]
4265-
new_index = op.Unsqueeze(index, [-1])
4267+
values_shape = values.shape.numpy()
4268+
4269+
index_vectors = []
4270+
for i, index in enumerate(indices):
4271+
if index is None:
4272+
# For a full slice, create a range.
4273+
index_vector = op.Range(start=0, limit=values_shape[i], delta=1)
4274+
else:
4275+
index_vector = index
4276+
4277+
# Shape vector with 1s, except at axis i.
4278+
shape_vector = [1] * self_rank
4279+
shape_vector[i] = values_shape[i]
4280+
4281+
# Reshape index_vector so that only the i-th dimension matches values_shape[i]
4282+
reshaped_index_vector = op.Reshape(index_vector, shape_vector)
4283+
4284+
# Expand reshaped_index_vector to match the full shape of values
4285+
expanded_index_vector = op.Expand(reshaped_index_vector, values_shape)
4286+
4287+
# Flatten into a 1D vector
4288+
column_index_vector = op.Reshape(expanded_index_vector, [-1])
4289+
4290+
# Convert into a column vector to prepare for concatenation
4291+
column_index_vector = op.Unsqueeze(column_index_vector, axes=[1])
4292+
index_vectors.append(column_index_vector)
4293+
4294+
# Contains all indices to be upadated
4295+
new_index = op.Concat(*index_vectors, axis=1)
4296+
4297+
# Flatten values to match the indices
4298+
flat_values = op.Reshape(values, [-1])
42664299

42674300
if accumulate:
4268-
result = op.ScatterND(self, new_index, values, reduction="add")
4301+
result = op.ScatterND(self, new_index, flat_values, reduction="add")
42694302
else:
4270-
result = op.ScatterND(self, new_index, values)
4303+
result = op.ScatterND(self, new_index, flat_values)
42714304

42724305
return result
42734306

0 commit comments

Comments
 (0)