-
Notifications
You must be signed in to change notification settings - Fork 72
feat(atenlib): add op(index put) #522
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
a3c77a7
f3840de
1253c78
ab3f3fa
8f18989
669aacb
f1d0b25
0e361f5
73f265e
8dc58c8
258e0b1
4e087ab
2ef68a4
40206a0
9d2b54b
a2508f8
5d424c6
f885ea6
3e751a9
b0e82c1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2700,15 +2700,74 @@ def aten_index_copy( | |
raise NotImplementedError() | ||
|
||
|
||
@torch_op("aten::index_put") | ||
def aten_index_put( | ||
self: TensorType, | ||
indices: Optional[Sequence[TensorType]], | ||
values: TensorType, | ||
self: TReal, | ||
indices: Sequence[INT64], | ||
values: TReal, | ||
accumulate: bool = False, | ||
) -> TensorType: | ||
) -> TReal: | ||
"""index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor""" | ||
|
||
raise NotImplementedError() | ||
index = op.SequenceAt(indices, 0) # assume indices only have 1 element | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I am a bit unclear about the assumption above. Is this just for now? Is the input type expected to be changed to There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Looks like indices is a list of 1d tensors, so INT64 may not do it? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. On this line, was there a reason we assume indices only has one element? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. From the test case, all the indices have only one element. If it contains multiple elements, we need to use for/while to do the logic. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Should we do a sequenceConcat or something like that? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I tried it in Torch; it actually cannot work like this: |
||
# change array([1,3]) to array([[1,1,1,1,1],[3,3,3,3,3]]) | ||
self_dim_1 = op.Gather(op.Shape(self), 1) | ||
xiaowuhu marked this conversation as resolved.
Show resolved
Hide resolved
|
||
index_dim_0 = op.Gather(op.Shape(index), 0) | ||
neg_1 = op.Constant(value_ints=[-1]) | ||
shape = op.Concat(op.Reshape(self_dim_1, neg_1), op.Reshape(index_dim_0, neg_1), axis=0) | ||
new_ind = op.Expand(index, shape) | ||
new_ind_t = op.Transpose(new_ind) | ||
xiaowuhu marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
if op.Cast(accumulate, to=BOOL.dtype): | ||
# put values into zeros array first, then add to input | ||
zeros = op.Expand(op.Constant(value_float=0.0), op.Shape(self)) | ||
result = op.ScatterElements(zeros, new_ind_t, values) | ||
result = op.Add(result, self) | ||
else: | ||
result = op.ScatterElements(self, new_ind_t, values) | ||
return result | ||
|
||
|
||
@torch_op("aten::index_put_bool", overload=True) | ||
def aten_index_put_bool( | ||
self: TReal, | ||
indices: Sequence[BOOL], | ||
values: TReal, | ||
accumulate: bool = False, | ||
) -> TReal: | ||
"""index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor""" | ||
|
||
index = op.SequenceAt(indices, 0) # assume indices only have 1 element | ||
# FIXME: ORT ArgMax fails on INT64 input even though ONNX allows it | ||
index_int = op.Cast(index, to=INT32.dtype) | ||
# if all False, return self | ||
if op.ReduceSum(index_int) == 0: | ||
result = self | ||
else: | ||
# change array([F,F,T,F,F]) to array([2]) | ||
index = op.ArgMax(index_int) # assume index only have 1 True | ||
# change array([2]) to array([2,2,2,2,2]) | ||
self_dim_1 = op.Gather(op.Shape(self), 1) | ||
index_dim_0 = op.Gather(op.Shape(index), 0) | ||
neg_1 = op.Constant(value_ints=[-1]) | ||
shape = op.Concat( | ||
op.Reshape(self_dim_1, neg_1), op.Reshape(index_dim_0, neg_1), axis=0 | ||
) | ||
new_ind = op.Expand(index, shape) | ||
new_ind_t = op.Transpose(new_ind) | ||
|
||
# values must have same rank with input(self) | ||
if op.Size(op.Shape(values)) < op.Size(op.Shape(self)): # type: ignore[operator] | ||
values = op.Unsqueeze(values, op.Constant(value_ints=[0])) | ||
|
||
if op.Cast(accumulate, to=BOOL.dtype): | ||
zeros = op.Expand(op.Constant(value_float=0.0), op.Shape(self)) | ||
result = op.ScatterElements(zeros, new_ind_t, values) | ||
result = op.Add(result, self) | ||
else: | ||
result = op.ScatterElements(self, new_ind_t, values) | ||
|
||
return result | ||
|
||
|
||
def aten_index_reduce( | ||
|
Uh oh!
There was an error while loading. Please reload this page.