Skip to content

Commit 8e53070

Browse files
AyoubMDL and justinchuby
authored
[torchlib] Fix index_put: handle None cases (#2061)
This PR introduces support for `None` indices in the `index_put` function. If an index is None, it acts as a full slice (`:`). ### index_put Logic: 1. Construct index grid that contains all the indices to be updated 2. Reshapes the update values to match the computed indices. --------- Co-authored-by: AyoubMDL <[email protected]> Co-authored-by: Justin Chu <[email protected]>
1 parent 6edcfd5 commit 8e53070

File tree

2 files changed

+125
-18
lines changed

2 files changed

+125
-18
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 69 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4298,14 +4298,78 @@ def aten_index_put(
42984298
<https://github.com/pytorch/pytorch/blob/main/torch/onnx/symbolic_opset11.py#L212>`_.
42994299
"""
43004300

4301-
# TODO(justinchuby): Handle when indices has more than one element
4302-
index = indices[0]
4303-
new_index = op.Unsqueeze(index, [-1])
4301+
def _make_reshape_list_broadcastable(reshape_list, values_shape):
4302+
# Remove ones until the rank of reshape_list matches values_shape.
4303+
while len(reshape_list) > len(values_shape) and 1 in reshape_list:
4304+
reshape_list.remove(1)
4305+
4306+
# Now ensure each dimension is broadcastable:
4307+
# This is mandatory when mixing basic and advanced indexing
4308+
# Example: data((10, 3, 4)), indices([[0, 1], :, [0, 1]]) values(2, 3)
4309+
# the reshape list should be : [[2, 1], [1, 3], [2, 1]]
4310+
for i, r in enumerate(reshape_list):
4311+
if r not in (1, values_shape[i]):
4312+
value_index = values_shape.index(r)
4313+
# Swap elements
4314+
# For the example above the current reshape list is [1, 2] for last dim,
4315+
# to make it broadcastable, we swap the elements
4316+
reshape_list[value_index], reshape_list[i] = r, 1
4317+
4318+
return reshape_list
4319+
4320+
# Ensure the number of indices matches the tensor rank.
4321+
self_rank = len(self.shape)
4322+
if len(indices) < self_rank:
4323+
indices = list(indices) + [None] * (self_rank - len(indices))
4324+
4325+
# Get values shape
4326+
values_shape = tuple(values.shape)
4327+
4328+
index_vectors = []
4329+
for i in range(self_rank):
4330+
if indices[i] is None:
4331+
# For a full slice along dim i, create a range index [0, self.shape[i]).
4332+
idx = op.Range(0, self.shape[i], 1)
4333+
reshape_update = self.shape[i]
4334+
else:
4335+
idx = indices[i]
4336+
reshape_update = math.prod(idx.shape)
4337+
# when Index is more than 1D, flatten it and also the values shape
4338+
# Example: self shape: (10, 3), indices[i] shape: (2, 4), values shape: (2, 4, 3)
4339+
# Indices -> (2*4,) and values shape (2*4, 3)
4340+
if len(idx.shape) > 1:
4341+
values_shape = (reshape_update,) + values_shape[len(idx.shape) :]
4342+
4343+
# Flatten index (always working with 1D index in each dim)
4344+
idx = op.Reshape(idx, [-1])
4345+
4346+
# Create a reshape pattern: one value per index dimension,
4347+
# with the current dimension set to the update size.
4348+
reshape_list = [1] * len(indices)
4349+
reshape_list[i] = reshape_update
4350+
4351+
# Adjust the reshape list to match the values shape.
4352+
reshape_list = _make_reshape_list_broadcastable(reshape_list, values_shape)
4353+
4354+
# Reshape and expand the index.
4355+
idx = op.Reshape(idx, reshape_list)
4356+
idx = op.Expand(idx, values_shape)
4357+
4358+
# Flatten the index to 1D and unsqueeze to form a column vector.
4359+
idx = op.Reshape(idx, [-1])
4360+
idx = op.Unsqueeze(idx, axes=[1])
4361+
index_vectors.append(idx)
4362+
4363+
# Concatenate the index vectors along axis=1 to form the final indices.
4364+
new_index = op.Concat(*index_vectors, axis=1)
4365+
4366+
# Flatten values to match the indices
4367+
flat_values = op.Reshape(values, [-1])
43044368

43054369
if accumulate:
4306-
result = op.ScatterND(self, new_index, values, reduction="add")
4370+
result = op.ScatterND(self, new_index, flat_values, reduction="add")
43074371
else:
4308-
result = op.ScatterND(self, new_index, values)
4372+
result = op.ScatterND(self, new_index, flat_values)
43094373

43104374
return result
43114375

tests/function_libs/torch_lib/extra_opinfo.py

Lines changed: 56 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -790,20 +790,63 @@ def sample_inputs_index_put(op_info, device, dtype, requires_grad, **kwargs):
790790
del op_info
791791
del kwargs
792792

793-
data = torch_testing.make_tensor(
794-
(10, 3),
795-
device=device,
796-
dtype=dtype,
797-
requires_grad=requires_grad,
798-
)
799-
indices = [torch.arange(8, dtype=torch.int64, device=device).reshape((-1, 4))]
800-
values = torch_testing.make_tensor(
801-
(2, 4, 3),
802-
device=device,
803-
dtype=dtype,
804-
requires_grad=requires_grad,
793+
make_arg = functools.partial(
794+
torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
805795
)
806-
yield opinfo_core.SampleInput(data, indices, values)
796+
797+
cases = [
798+
# Cases: one None
799+
((1, 3, 4), [None, torch.arange(2, device=device), None], (1, 2, 4)),
800+
((10, 3, 4), [torch.arange(5, device=device), None, None], (5, 3, 4)),
801+
((10, 3, 4, 6), [None, None, None, torch.arange(3, device=device)], (10, 3, 4, 3)),
802+
# Cases: two None
803+
(
804+
(10, 3, 4),
805+
[None, torch.arange(3, device=device), torch.arange(3, device=device)],
806+
(10, 3),
807+
),
808+
(
809+
(10, 3, 4, 6),
810+
[
811+
torch.arange(2, device=device),
812+
None,
813+
torch.arange(2, device=device),
814+
torch.arange(2, device=device),
815+
],
816+
(2, 3),
817+
),
818+
(
819+
(10, 3, 4),
820+
[torch.arange(2, device=device), torch.arange(2, device=device), None],
821+
(2, 4),
822+
),
823+
# Cases: Single indexing
824+
((10, 3, 4), [None, None, torch.tensor([0], device=device)], (10, 3, 1)),
825+
((10, 3, 4), [torch.tensor([0], device=device), None, None], (1, 3, 4)),
826+
((10, 3, 4, 6), [None, torch.tensor([0], device=device), None, None], (10, 1, 4, 6)),
827+
# Cases: Single element
828+
(
829+
(10, 3, 4),
830+
[
831+
torch.tensor([0], device=device),
832+
torch.tensor([0], device=device),
833+
torch.tensor([0], device=device),
834+
],
835+
(1,),
836+
),
837+
# Cases: Multidimensional index
838+
(
839+
(10, 3),
840+
[torch.arange(8, dtype=torch.int64, device=device).reshape((-1, 4))],
841+
(2, 4, 3),
842+
),
843+
]
844+
845+
for data_shape, indices, values_shape in cases: # type: ignore[misc]
846+
data = make_arg(data_shape)
847+
values = make_arg(values_shape) # type: ignore[has-type]
848+
849+
yield opinfo_core.SampleInput(data, indices, values)
807850

808851

809852
def sample_inputs_layer_norm(op_info, device, dtype, requires_grad, **kwargs):

0 commit comments

Comments
 (0)