Skip to content

Commit 94d820a

Browse files
author
AyoubMDL
committed
fix(index_put): handle None cases
1 parent 5c31a7e commit 94d820a

File tree

2 files changed

+63
-19
lines changed

2 files changed

+63
-19
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 43 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1647,7 +1647,8 @@ def aten_chunk(self: TTensor, chunks: int, dim: int = 0) -> Sequence[TTensor]:
16471647
dim_size = op.Gather(self_shape, dim, axis=0)
16481648
# Compute size/chunk to get the number of data in one chunk
16491649
num_per_chunk = op.Div(dim_size, chunks)
1650-
num_per_chunk = op.Cast(op.Mod(dim_size, chunks) > 0, to=INT64.dtype) + num_per_chunk # type: ignore[operator]
1650+
num_per_chunk = op.Cast(op.Mod(dim_size, chunks) > 0, to=INT64.dtype) + \
1651+
num_per_chunk # type: ignore[operator]
16511652

16521653
# Compute real chunk number
16531654
num_chunk = op.Div(dim_size, num_per_chunk)
@@ -4259,15 +4260,51 @@ def aten_index_put(
42594260
See implementation of `torch.onnx.symbolic_opset11.index_put
42604261
<https://github.com/pytorch/pytorch/blob/main/torch/onnx/symbolic_opset11.py#L212>`_.
42614262
"""
4263+
# Pad indices with None so it has the same rank as self
4264+
self_rank = len(self.shape)
4265+
if len(indices) < self_rank:
4266+
indices = list(indices) + [None] * (self_rank - len(indices))
4267+
4268+
values_shape = values.shape.numpy()
4269+
# Pad values_shape with 1s so it has the same rank as self
4270+
if len(values_shape) < self_rank:
4271+
values_shape = (1,) * (self_rank - len(values_shape)) + values_shape
4272+
4273+
index_vectors = []
4274+
for i, index in enumerate(indices):
4275+
if index is None:
4276+
# For a full slice, create a range.
4277+
index_vector = op.Range(start=0, limit=values_shape[i], delta=1)
4278+
else:
4279+
index_vector = index
42624280

4263-
# TODO(justinchuby): Handle when indicies has more than one element
4264-
index = indices[0]
4265-
new_index = op.Unsqueeze(index, [-1])
4281+
# Shape vector with 1s, except at axis i.
4282+
shape_vector = [1] * self_rank
4283+
shape_vector[i] = values_shape[i]
4284+
4285+
# Reshape index_vector so that only the i-th dimension matches values_shape[i]
4286+
reshaped_index_vector = op.Reshape(index_vector, shape_vector)
4287+
4288+
# Expand reshaped_index_vector to match the full shape of values
4289+
expanded_index_vector = op.Expand(reshaped_index_vector, values_shape)
4290+
4291+
# Flatten into a 1D vector
4292+
column_index_vector = op.Reshape(expanded_index_vector, [-1])
4293+
4294+
# Convert into a column vector to prepare for concatenation
4295+
column_index_vector = op.Unsqueeze(column_index_vector, axes=[1])
4296+
index_vectors.append(column_index_vector)
4297+
4298+
# Contains all indices to be updated
4299+
new_index = op.Concat(*index_vectors, axis=1)
4300+
4301+
# Flatten values to match the indices
4302+
flat_values = op.Reshape(values, [-1])
42664303

42674304
if accumulate:
4268-
result = op.ScatterND(self, new_index, values, reduction="add")
4305+
result = op.ScatterND(self, new_index, flat_values, reduction="add")
42694306
else:
4270-
result = op.ScatterND(self, new_index, values)
4307+
result = op.ScatterND(self, new_index, flat_values)
42714308

42724309
return result
42734310

tests/function_libs/torch_lib/extra_opinfo.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -790,20 +790,27 @@ def sample_inputs_index_put(op_info, device, dtype, requires_grad, **kwargs):
790790
del op_info
791791
del kwargs
792792

793-
data = torch_testing.make_tensor(
794-
(10, 3),
795-
device=device,
796-
dtype=dtype,
797-
requires_grad=requires_grad,
798-
)
799-
indices = [torch.arange(8, dtype=torch.int64, device=device).reshape((-1, 4))]
800-
values = torch_testing.make_tensor(
801-
(2, 4, 3),
802-
device=device,
803-
dtype=dtype,
804-
requires_grad=requires_grad,
793+
make_arg = functools.partial(
794+
torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
805795
)
806-
yield opinfo_core.SampleInput(data, indices, values)
796+
797+
data = make_arg((10, 3, 4))
798+
799+
cases = [
800+
# Case 1: Full slices in dims 0 and 2, tensor index in dim 1
801+
([None, torch.arange(2, device=device), None], (10, 2, 4)),
802+
# Case 2: Tensor index in dim 0, full slices in dims 1 and 2
803+
([torch.arange(5, device=device), None, None], (5, 3, 4)),
804+
# Case 3: Full slices in dims 0 and 1, tensor index in dim 2
805+
([None, None, torch.arange(3, device=device)], (10, 3, 3)),
806+
# Case 4: Single index in last dimension
807+
([None, None, torch.tensor([0], device=device)], (10, 3, 1)),
808+
]
809+
810+
for indices, values_shape in cases: # type: ignore[misc]
811+
values = make_arg(values_shape) # type: ignore[has-type]
812+
813+
yield opinfo_core.SampleInput(data, indices, values)
807814

808815

809816
def sample_inputs_layer_norm(op_info, device, dtype, requires_grad, **kwargs):

0 commit comments

Comments
 (0)