Skip to content

Commit bb129e6

Browse files
feat(atenlib): add op(index put) (#522)
Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
1 parent bff0b52 commit bb129e6

File tree

2 files changed

+78
-5
lines changed

2 files changed

+78
-5
lines changed

onnxscript/function_libs/torch_aten/ops/core.py

Lines changed: 64 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2700,15 +2700,74 @@ def aten_index_copy(
27002700
raise NotImplementedError()
27012701

27022702

2703+
@torch_op("aten::index_put")
27032704
def aten_index_put(
2704-
self: TensorType,
2705-
indices: Optional[Sequence[TensorType]],
2706-
values: TensorType,
2705+
self: TReal,
2706+
indices: Sequence[INT64],
2707+
values: TReal,
27072708
accumulate: bool = False,
2708-
) -> TensorType:
2709+
) -> TReal:
27092710
"""index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor"""
27102711

2711-
raise NotImplementedError()
2712+
index = op.SequenceAt(indices, 0) # assume indices only have 1 element
2713+
# change array([1,3]) to array([[1,1,1,1,1],[3,3,3,3,3]])
2714+
self_dim_1 = op.Gather(op.Shape(self), 1)
2715+
index_dim_0 = op.Gather(op.Shape(index), 0)
2716+
neg_1 = op.Constant(value_ints=[-1])
2717+
shape = op.Concat(op.Reshape(self_dim_1, neg_1), op.Reshape(index_dim_0, neg_1), axis=0)
2718+
new_ind = op.Expand(index, shape)
2719+
new_ind_t = op.Transpose(new_ind)
2720+
2721+
if op.Cast(accumulate, to=BOOL.dtype):
2722+
# put values into zeros array first, then add to input
2723+
zeros = op.Expand(op.Constant(value_float=0.0), op.Shape(self))
2724+
result = op.ScatterElements(zeros, new_ind_t, values)
2725+
result = op.Add(result, self)
2726+
else:
2727+
result = op.ScatterElements(self, new_ind_t, values)
2728+
return result
2729+
2730+
2731+
@torch_op("aten::index_put_bool", overload=True)
2732+
def aten_index_put_bool(
2733+
self: TReal,
2734+
indices: Sequence[BOOL],
2735+
values: TReal,
2736+
accumulate: bool = False,
2737+
) -> TReal:
2738+
"""index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor"""
2739+
2740+
index = op.SequenceAt(indices, 0) # assume indices only have 1 element
2741+
# FIXME: ORT ArgMax fails on INT64 input even though ONNX allows it
2742+
index_int = op.Cast(index, to=INT32.dtype)
2743+
# if all False, return self
2744+
if op.ReduceSum(index_int) == 0:
2745+
result = self
2746+
else:
2747+
# change array([F,F,T,F,F]) to array([2])
2748+
index = op.ArgMax(index_int) # assume index only have 1 True
2749+
# change array([2]) to array([2,2,2,2,2])
2750+
self_dim_1 = op.Gather(op.Shape(self), 1)
2751+
index_dim_0 = op.Gather(op.Shape(index), 0)
2752+
neg_1 = op.Constant(value_ints=[-1])
2753+
shape = op.Concat(
2754+
op.Reshape(self_dim_1, neg_1), op.Reshape(index_dim_0, neg_1), axis=0
2755+
)
2756+
new_ind = op.Expand(index, shape)
2757+
new_ind_t = op.Transpose(new_ind)
2758+
2759+
# values must have same rank with input(self)
2760+
if op.Size(op.Shape(values)) < op.Size(op.Shape(self)): # type: ignore[operator]
2761+
values = op.Unsqueeze(values, op.Constant(value_ints=[0]))
2762+
2763+
if op.Cast(accumulate, to=BOOL.dtype):
2764+
zeros = op.Expand(op.Constant(value_float=0.0), op.Shape(self))
2765+
result = op.ScatterElements(zeros, new_ind_t, values)
2766+
result = op.Add(result, self)
2767+
else:
2768+
result = op.ScatterElements(self, new_ind_t, values)
2769+
2770+
return result
27122771

27132772

27142773
def aten_index_reduce(

onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,8 @@ def _where_input_wrangler(
404404
"full_like": core_ops.aten_full_like,
405405
"ge": core_ops.aten_ge,
406406
"gt": core_ops.aten_gt,
407+
"index_put_bool": core_ops.aten_index_put_bool,
408+
"index_put": core_ops.aten_index_put,
407409
"isclose": core_ops.aten_isclose,
408410
"isfinite": core_ops.aten_isfinite,
409411
"isinf": core_ops.aten_isinf,
@@ -680,6 +682,16 @@ def _where_input_wrangler(
680682
matcher=lambda sample: sample.kwargs.get("rounding_mode") is not None,
681683
reason="rounding_mode is not yet supported",
682684
),
685+
skip(
686+
"index_put",
687+
matcher=lambda sample: not (sample.args[0][0].dtype == torch.int64),
688+
reason="this Aten overload only support tensor(int) as args",
689+
),
690+
skip(
691+
"index_put_bool",
692+
matcher=lambda sample: not (sample.args[0][0].dtype == torch.bool),
693+
reason="this Aten overload only support tensor(bool) as args",
694+
),
683695
skip(
684696
"min", # aten_mean
685697
matcher=lambda sample: len(sample.args) > 0,
@@ -797,6 +809,8 @@ def _where_input_wrangler(
797809
),
798810
)
799811

812+
duplicate_opinfo(OPS_DB, "index_put", ("index_put_bool",))
813+
800814
duplicate_opinfo(OPS_DB, "nn.functional.nll_loss", ("nn.functional.nll_loss_weight",))
801815

802816
duplicate_opinfo(

0 commit comments

Comments
 (0)