diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 2bdea7ca5f..d2648d94a4 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -5202,10 +5202,26 @@ def aten_masked_fill(self: TTensor, mask: BOOL, value: TTensor) -> TTensor: return op.Where(mask, value_cast, self) -def aten_masked_scatter(self: TensorType, mask: TensorType, source: TensorType) -> TensorType: +@torch_op(("aten::masked_scatter"), trace_only=True) +def aten_masked_scatter(self: TTensor, mask: TTensor, source: TTensor) -> TTensor: """masked_scatter(Tensor self, Tensor mask, Tensor source) -> Tensor""" - raise NotImplementedError() + if len(mask.shape) < len(self.shape): + mask = op.Expand(mask, op.Shape(self)) + else: + self = op.Expand(self, op.Shape(mask)) + index = op.Transpose(op.NonZero(mask), perm=[1, 0]) + + # NOTE: source can have more elements than needed. + # It could also have arbitrary shape. + # This is not supported by ONNX::ScatterND, so we need to flatten and slice source tensor. + source = op.Reshape(source, op.Constant(value_ints=[-1])) + axes = op.Constant(value_ints=[0]) + starts = op.Constant(value_ints=[0]) + ends = op.Gather(op.Shape(index), op.Constant(value_ints=[0]), axis=0) + source = op.Slice(source, starts, ends, axes) + + return op.ScatterND(self, index, source) def aten_masked_select(self: TensorType, mask: TensorType) -> TensorType: @@ -6429,7 +6445,7 @@ def aten_nextafter(self: TensorType, other: TensorType) -> TensorType: raise NotImplementedError() -@torch_op("aten::nonzero") +@torch_op("aten::nonzero", trace_only=True) def aten_nonzero(self: TTensor) -> INT64: """nonzero(Tensor self) -> Tensor""" # NOTE: In torch the return shape is [n, d], while in onnx [d, n], diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index e3be105839..e8ccc87aea 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -932,6 +932,7 @@ def _where_input_wrangler( dtypes=(torch.bool,), reason="fixme: ORT does not have an implementation for Where with bool inputs.", ), + TorchLibOpInfo("masked_scatter", core_ops.aten_masked_scatter), TorchLibOpInfo( "matmul", core_ops.aten_matmul,