diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index c3f14790f3..4a338ecc1d 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -5827,10 +5827,16 @@ def aten_select_backward( raise NotImplementedError() +@torch_op("aten::select_scatter") def aten_select_scatter(self: TensorType, src: TensorType, dim: int, index: int) -> TensorType: """select_scatter(Tensor self, Tensor src, int dim, int index) -> Tensor""" - raise NotImplementedError() + # Change src rank to self rank according to dim + # e.g. if self is [2,3,4], src is [2,4], dim=1, then update is [2,1,4] + update = op.Unsqueeze(src, axes=dim) + # Change index rank to the same as 'update' [2,1,4] + indices = op.Expand(index, op.Shape(update)) + return op.ScatterElements(self, indices, update, axis=dim, reduction="none") @torch_op("aten::selu") diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py index ffe3da3a23..4c7b027984 100644 --- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py +++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py @@ -1063,6 +1063,7 @@ def _where_input_wrangler( reason="fixme: ORT failed", ), TorchLibOpInfo("select", core_ops.aten_select), + TorchLibOpInfo("select_scatter", core_ops.aten_select_scatter), TorchLibOpInfo("sigmoid", core_ops.aten_sigmoid), TorchLibOpInfo("sign", core_ops.aten_sign), TorchLibOpInfo("sin", core_ops.aten_sin),