
feat(atenlib): add op(index put) #522

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged 20 commits on Mar 16, 2023
Changes from all commits
69 changes: 64 additions & 5 deletions onnxscript/function_libs/torch_aten/ops/core.py
@@ -2700,15 +2700,74 @@ def aten_index_copy(
raise NotImplementedError()


@torch_op("aten::index_put")
def aten_index_put(
self: TensorType,
indices: Optional[Sequence[TensorType]],
values: TensorType,
self: TReal,
indices: Sequence[INT64],
values: TReal,
accumulate: bool = False,
) -> TensorType:
) -> TReal:
"""index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor"""

raise NotImplementedError()
index = op.SequenceAt(indices, 0) # assume indices only have 1 element
Collaborator: I am a bit unclear about the assumption above. Is this just for now? Is the input type expected to be changed to INT64?

Collaborator: Looks like indices is a list of 1-D tensors, so INT64 may not do it?

Collaborator: On this line, was there a reason we assume indices only has one element?

Contributor (author): From the test cases, all the indices have only one element. If indices contained multiple elements, we would need a for/while loop to implement the logic. But for this aten op, multiple elements can be combined into one; for example, indices=[[0,1], [2,3]] can be combined into indices=[[0,1,2,3]].

Collaborator (@justinchuby, Mar 16, 2023): Should we do a sequenceConcat or something like that?

Contributor (@xiaowuhu, Mar 16, 2023): I tried in Torch; it actually cannot work like this:

    torch.index_put(self=tensor(5x5), indices=(tensor([0]), tensor([1])), values=tensor(2x5))

so we don't know how to use more than one element in the indices tuple.
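
For context, a minimal PyTorch sketch (not from the PR) of the single-index-tensor case the implementation assumes; the shapes are illustrative:

import torch

# One index tensor in the indices tuple: rows 0 and 2 of `self_tensor`
# are replaced by the two rows of `values`.
self_tensor = torch.zeros(5, 5)
indices = (torch.tensor([0, 2]),)  # a single element, as assumed above
values = torch.ones(2, 5)
result = torch.index_put(self_tensor, indices, values, accumulate=False)
# result[0] and result[2] are now all ones; the other rows stay zero.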

    # change array([1,3]) to array([[1,1,1,1,1],[3,3,3,3,3]])
    self_dim_1 = op.Gather(op.Shape(self), 1)
    index_dim_0 = op.Gather(op.Shape(index), 0)
    neg_1 = op.Constant(value_ints=[-1])
    shape = op.Concat(op.Reshape(self_dim_1, neg_1), op.Reshape(index_dim_0, neg_1), axis=0)
    new_ind = op.Expand(index, shape)
    new_ind_t = op.Transpose(new_ind)

    if op.Cast(accumulate, to=BOOL.dtype):
        # put values into zeros array first, then add to input
        zeros = op.Expand(op.Constant(value_float=0.0), op.Shape(self))
        result = op.ScatterElements(zeros, new_ind_t, values)
        result = op.Add(result, self)
    else:
        result = op.ScatterElements(self, new_ind_t, values)
    return result
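
To see what the Expand/Transpose pair builds, here is a rough NumPy analogue of the non-accumulate path (illustrative only; it assumes a 5x5 input and row indices [1, 3]):

import numpy as np

self_np = np.zeros((5, 5))
index = np.array([1, 3])  # row indices
values = np.arange(10.0).reshape(2, 5)

# Expand index to shape [width, num_indices], then transpose:
# broadcast_to repeats [1, 3] down 5 rows; .T gives [[1,1,1,1,1],[3,3,3,3,3]]
new_ind_t = np.broadcast_to(index, (5, 2)).T

# ScatterElements along axis 0 writes values[i, j] into row new_ind_t[i, j]:
result = self_np.copy()
for i in range(2):
    for j in range(5):
        result[new_ind_t[i, j], j] = values[i, j]
# rows 1 and 3 of result now hold the two rows of values

With accumulate=True the function instead scatters values into an all-zeros tensor and then adds self, so the target rows end up as self + values and every other position is unchanged.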


@torch_op("aten::index_put_bool", overload=True)
def aten_index_put_bool(
self: TReal,
indices: Sequence[BOOL],
values: TReal,
accumulate: bool = False,
) -> TReal:
"""index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor"""

index = op.SequenceAt(indices, 0) # assume indices only have 1 element
# FIXME: ORT ArgMax fails on INT64 input even though ONNX allows it
index_int = op.Cast(index, to=INT32.dtype)
# if all False, return self
if op.ReduceSum(index_int) == 0:
result = self
else:
# change array([F,F,T,F,F]) to array([2])
index = op.ArgMax(index_int) # assume index only have 1 True
# change array([2]) to array([2,2,2,2,2])
self_dim_1 = op.Gather(op.Shape(self), 1)
index_dim_0 = op.Gather(op.Shape(index), 0)
neg_1 = op.Constant(value_ints=[-1])
shape = op.Concat(
op.Reshape(self_dim_1, neg_1), op.Reshape(index_dim_0, neg_1), axis=0
)
new_ind = op.Expand(index, shape)
new_ind_t = op.Transpose(new_ind)

# values must have same rank with input(self)
if op.Size(op.Shape(values)) < op.Size(op.Shape(self)): # type: ignore[operator]
values = op.Unsqueeze(values, op.Constant(value_ints=[0]))

if op.Cast(accumulate, to=BOOL.dtype):
zeros = op.Expand(op.Constant(value_float=0.0), op.Shape(self))
result = op.ScatterElements(zeros, new_ind_t, values)
result = op.Add(result, self)
else:
result = op.ScatterElements(self, new_ind_t, values)

return result
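
A rough NumPy analogue of the boolean path (illustrative; it assumes the mask holds at most one True, per the comment above):

import numpy as np

mask = np.array([False, False, True, False, False])
mask_int = mask.astype(np.int32)  # cast before ArgMax, as the FIXME notes
if mask_int.sum() == 0:
    # all-False mask: return self unchanged; ArgMax on an all-zero
    # array would otherwise wrongly select row 0
    pass
else:
    # change array([F,F,T,F,F]) to array([2])
    index = np.array([np.argmax(mask_int)])
    # from here the flow is the same Expand/Transpose/ScatterElements
    # pattern as in aten_index_put above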


def aten_index_reduce(
14 changes: 14 additions & 0 deletions onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py
@@ -404,6 +404,8 @@ def _where_input_wrangler(
"full_like": core_ops.aten_full_like,
"ge": core_ops.aten_ge,
"gt": core_ops.aten_gt,
"index_put_bool": core_ops.aten_index_put_bool,
"index_put": core_ops.aten_index_put,
"isclose": core_ops.aten_isclose,
"isfinite": core_ops.aten_isfinite,
"isinf": core_ops.aten_isinf,
Expand Down Expand Up @@ -680,6 +682,16 @@ def _where_input_wrangler(
        matcher=lambda sample: sample.kwargs.get("rounding_mode") is not None,
        reason="rounding_mode is not yet supported",
    ),
    skip(
        "index_put",
        matcher=lambda sample: not (sample.args[0][0].dtype == torch.int64),
        reason="this Aten overload only supports tensor(int) as args",
    ),
    skip(
        "index_put_bool",
        matcher=lambda sample: not (sample.args[0][0].dtype == torch.bool),
        reason="this Aten overload only supports tensor(bool) as args",
    ),
    skip(
        "min",  # aten_mean
        matcher=lambda sample: len(sample.args) > 0,

@@ -797,6 +809,8 @@ def _where_input_wrangler(
    ),
)
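
For reference, a sketch of the kind of sample those matchers inspect; the values below are made up for illustration, not the real OpInfo data:

import torch

# sample.args[0] is the `indices` tuple; the dtype of its first tensor
# decides which overload a sample is routed to: int64 indices are kept
# for "index_put", bool masks are kept for "index_put_bool".
int64_sample_args = ((torch.tensor([0, 2]),), torch.ones(2, 5))
bool_sample_args = ((torch.tensor([True, False, False]),), torch.ones(1, 5))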

duplicate_opinfo(OPS_DB, "index_put", ("index_put_bool",))
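(Here duplicate_opinfo appears to register a copy of the torch index_put OpInfo under the extra name index_put_bool, so the boolean-mask overload gets its own test entry for the dtype-based skips above to filter.)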

duplicate_opinfo(OPS_DB, "nn.functional.nll_loss", ("nn.functional.nll_loss_weight",))

duplicate_opinfo(