Skip to content

[torchlib] Fix index_put: handle None cases #2061

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 4, 2025
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 69 additions & 5 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4298,14 +4298,78 @@
<https://github.com/pytorch/pytorch/blob/main/torch/onnx/symbolic_opset11.py#L212>`_.
"""

# TODO(justinchuby): Handle when indicies has more than one element
index = indices[0]
new_index = op.Unsqueeze(index, [-1])
def _make_reshape_list_broadcastable(reshape_list, values_shape):
# Remove ones until the rank of reshape_list matches values_shape.
while len(reshape_list) > len(values_shape) and 1 in reshape_list:
reshape_list.remove(1)

Check warning on line 4304 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L4304

Added line #L4304 was not covered by tests

# Now ensure each dimension is broadcastable:
# This is mandatory when mixing basic and advanced indexing
# Example: data((10, 3, 4)), indices([[0, 1], :, [0, 1]]) values(2, 3)
# the reshape list should be : [[2, 1], [1, 3], [2, 1]]
for i, r in enumerate(reshape_list):
if r not in (1, values_shape[i]):
value_index = values_shape.index(r)

Check warning on line 4312 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L4312

Added line #L4312 was not covered by tests
# Swap elements
# For the example above the current reshape list is [1, 2] for last dim,
# to make it broadcastable, we swap the elements
reshape_list[value_index], reshape_list[i] = r, 1

Check warning on line 4316 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L4316

Added line #L4316 was not covered by tests

return reshape_list

# Ensure the number of indices matches the tensor rank.
self_rank = len(self.shape)
if len(indices) < self_rank:
indices = list(indices) + [None] * (self_rank - len(indices))

# Get values shape
values_shape = tuple(values.shape)

index_vectors = []
for i in range(self_rank):
if indices[i] is None:
# For a full slice along dim i, create a range index [0, self.shape[i]).
idx = op.Range(0, op.Shape(self, start=i, end=i + 1), 1)
reshape_update = self.shape[i]
else:
idx = indices[i]
reshape_update = math.prod(idx.shape)
# when Index is more than 1D, flatten it and also the values shape
# Example: self shape: (10, 3), indices[i] shape: (2, 4), values shape: (2, 4, 3)
# Indices -> (2*4,) and values shape (2*4, 32)
if len(idx.shape) > 1:
values_shape = (reshape_update,) + values_shape[len(idx.shape) :]

# Flatten index (always working with 1D index in each dim)
idx = op.Reshape(idx, [-1])

# Create a reshape pattern: one value per index dimension,
# with the current dimension set to the update size.
reshape_list = [1] * len(indices)
reshape_list[i] = reshape_update

# Adjust the reshape list to match the values shape.
reshape_list = _make_reshape_list_broadcastable(reshape_list, values_shape)

# Reshape and expand the index.
idx = op.Reshape(idx, reshape_list)
idx = op.Expand(idx, values_shape)

# Flatten the index to 1D and unsqueeze to form a column vector.
idx = op.Reshape(idx, [-1])
idx = op.Unsqueeze(idx, axes=[1])
index_vectors.append(idx)

# Concatenate the index vectors along axis=1 to form the final indices.
new_index = op.Concat(*index_vectors, axis=1)

# Flatten values to match the indices
flat_values = op.Reshape(values, [-1])

if accumulate:
result = op.ScatterND(self, new_index, values, reduction="add")
result = op.ScatterND(self, new_index, flat_values, reduction="add")
else:
result = op.ScatterND(self, new_index, values)
result = op.ScatterND(self, new_index, flat_values)

Check warning on line 4372 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L4372

Added line #L4372 was not covered by tests

return result

Expand Down
69 changes: 56 additions & 13 deletions tests/function_libs/torch_lib/extra_opinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -790,20 +790,63 @@ def sample_inputs_index_put(op_info, device, dtype, requires_grad, **kwargs):
del op_info
del kwargs

data = torch_testing.make_tensor(
(10, 3),
device=device,
dtype=dtype,
requires_grad=requires_grad,
)
indices = [torch.arange(8, dtype=torch.int64, device=device).reshape((-1, 4))]
values = torch_testing.make_tensor(
(2, 4, 3),
device=device,
dtype=dtype,
requires_grad=requires_grad,
make_arg = functools.partial(
torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
)
yield opinfo_core.SampleInput(data, indices, values)

cases = [
# Cases: one None
((1, 3, 4), [None, torch.arange(2, device=device), None], (1, 2, 4)),
((10, 3, 4), [torch.arange(5, device=device), None, None], (5, 3, 4)),
((10, 3, 4, 6), [None, None, None, torch.arange(3, device=device)], (10, 3, 4, 3)),
# Cases: two None
(
(10, 3, 4),
[None, torch.arange(3, device=device), torch.arange(3, device=device)],
(10, 3),
),
(
(10, 3, 4, 6),
[
torch.arange(2, device=device),
None,
torch.arange(2, device=device),
torch.arange(2, device=device),
],
(2, 3),
),
(
(10, 3, 4),
[torch.arange(2, device=device), torch.arange(2, device=device), None],
(2, 4),
),
# Cases: Single indexing
((10, 3, 4), [None, None, torch.tensor([0], device=device)], (10, 3, 1)),
((10, 3, 4), [torch.tensor([0], device=device), None, None], (1, 3, 4)),
((10, 3, 4, 6), [None, torch.tensor([0], device=device), None, None], (10, 1, 4, 6)),
# Cases: Single element
(
(10, 3, 4),
[
torch.tensor([0], device=device),
torch.tensor([0], device=device),
torch.tensor([0], device=device),
],
(1,),
),
# Cases: Multidimensional index
(
(10, 3),
[torch.arange(8, dtype=torch.int64, device=device).reshape((-1, 4))],
(2, 4, 3),
),
]

for data_shape, indices, values_shape in cases: # type: ignore[misc]
data = make_arg(data_shape)
values = make_arg(values_shape) # type: ignore[has-type]

yield opinfo_core.SampleInput(data, indices, values)


def sample_inputs_layer_norm(op_info, device, dtype, requires_grad, **kwargs):
Expand Down
Loading