diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py
index dc4de470dc..fce6124972 100644
--- a/onnxscript/function_libs/torch_aten/ops/core.py
+++ b/onnxscript/function_libs/torch_aten/ops/core.py
@@ -2700,15 +2700,74 @@ def aten_index_copy(
     raise NotImplementedError()
 
 
+@torch_op("aten::index_put")
 def aten_index_put(
-    self: TensorType,
-    indices: Optional[Sequence[TensorType]],
-    values: TensorType,
+    self: TReal,
+    indices: Sequence[INT64],
+    values: TReal,
     accumulate: bool = False,
-) -> TensorType:
+) -> TReal:
     """index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor"""
 
-    raise NotImplementedError()
+    index = op.SequenceAt(indices, 0)  # assume indices contains only one index tensor
+    # expand array([1,3]) to array([[1,1,1,1,1],[3,3,3,3,3]])
+    self_dim_1 = op.Gather(op.Shape(self), 1)
+    index_dim_0 = op.Gather(op.Shape(index), 0)
+    neg_1 = op.Constant(value_ints=[-1])
+    shape = op.Concat(op.Reshape(self_dim_1, neg_1), op.Reshape(index_dim_0, neg_1), axis=0)
+    new_ind = op.Expand(index, shape)
+    new_ind_t = op.Transpose(new_ind)
+
+    if op.Cast(accumulate, to=BOOL.dtype):
+        # scatter values into a zeros array first, then add it to the input
+        zeros = op.Expand(op.Constant(value_float=0.0), op.Shape(self))
+        result = op.ScatterElements(zeros, new_ind_t, values)
+        result = op.Add(result, self)
+    else:
+        result = op.ScatterElements(self, new_ind_t, values)
+    return result
+
+
+@torch_op("aten::index_put_bool", overload=True)
+def aten_index_put_bool(
+    self: TReal,
+    indices: Sequence[BOOL],
+    values: TReal,
+    accumulate: bool = False,
+) -> TReal:
+    """index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor"""
+
+    index = op.SequenceAt(indices, 0)  # assume indices contains only one mask tensor
+    # FIXME: ORT ArgMax fails on INT64 input even though ONNX allows it
+    index_int = op.Cast(index, to=INT32.dtype)
+    # if the mask is all False, return self unchanged
+    if op.ReduceSum(index_int) == 0:
+        result = self
+    else:
+        # change array([F,F,T,F,F]) to array([2])
+        index = op.ArgMax(index_int)  # assume the mask has exactly one True
+        # expand array([2]) to array([2,2,2,2,2])
+        self_dim_1 = op.Gather(op.Shape(self), 1)
+        index_dim_0 = op.Gather(op.Shape(index), 0)
+        neg_1 = op.Constant(value_ints=[-1])
+        shape = op.Concat(
+            op.Reshape(self_dim_1, neg_1), op.Reshape(index_dim_0, neg_1), axis=0
+        )
+        new_ind = op.Expand(index, shape)
+        new_ind_t = op.Transpose(new_ind)
+
+        # values must have the same rank as the input (self)
+        if op.Size(op.Shape(values)) < op.Size(op.Shape(self)):  # type: ignore[operator]
+            values = op.Unsqueeze(values, op.Constant(value_ints=[0]))
+
+        if op.Cast(accumulate, to=BOOL.dtype):
+            zeros = op.Expand(op.Constant(value_float=0.0), op.Shape(self))
+            result = op.ScatterElements(zeros, new_ind_t, values)
+            result = op.Add(result, self)
+        else:
+            result = op.ScatterElements(self, new_ind_t, values)
+
+    return result
 
 
 def aten_index_reduce(
diff --git a/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py b/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py
index 8b2a424458..dc7bede1dd 100644
--- a/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py
+++ b/onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py
@@ -404,6 +404,8 @@ def _where_input_wrangler(
     "full_like": core_ops.aten_full_like,
     "ge": core_ops.aten_ge,
     "gt": core_ops.aten_gt,
+    "index_put_bool": core_ops.aten_index_put_bool,
+    "index_put": core_ops.aten_index_put,
     "isclose": core_ops.aten_isclose,
     "isfinite": core_ops.aten_isfinite,
     "isinf": core_ops.aten_isinf,
@@ -680,6 +682,16 @@ def _where_input_wrangler(
         matcher=lambda sample: sample.kwargs.get("rounding_mode") is not None,
         reason="rounding_mode is not yet supported",
     ),
+    skip(
+        "index_put",
+        matcher=lambda sample: not (sample.args[0][0].dtype == torch.int64),
+        reason="this ATen overload only supports tensor(int) as indices",
+    ),
+    skip(
+        "index_put_bool",
+        matcher=lambda sample: not (sample.args[0][0].dtype == torch.bool),
+        reason="this ATen overload only supports tensor(bool) as indices",
+    ),
     skip(
         "min",  # aten_mean
         matcher=lambda sample: len(sample.args) > 0,
@@ -797,6 +809,8 @@ def _where_input_wrangler(
     ),
 )
 
+duplicate_opinfo(OPS_DB, "index_put", ("index_put_bool",))
+
 duplicate_opinfo(OPS_DB, "nn.functional.nll_loss", ("nn.functional.nll_loss_weight",))
 duplicate_opinfo(