@torch_op(("aten::masked_scatter"), trace_only=True)
def aten_masked_scatter(self: TTensor, mask: TTensor, source: TTensor) -> TTensor:
    """masked_scatter(Tensor self, Tensor mask, Tensor source) -> Tensor

    Copy values from ``source`` into ``self`` at the positions where ``mask``
    is True.

    Values are taken sequentially from the flattened ``source`` (first True
    position gets the first element, and so on), not element-wise. That is why
    the simpler ``self * (1 - mask) + mask * source`` formulation cannot be
    used: it would assume positional correspondence between ``self`` and
    ``source``.

    Args:
        self: Target tensor to scatter into.
        mask: Mask selecting the positions to overwrite; broadcast together
            with ``self`` to a common shape.
        source: Values to scatter. May have any shape and may hold more
            elements than there are True positions; only the leading ones
            are used.

    Returns:
        The broadcast ``self`` with the selected positions replaced by values
        from ``source``.
    """
    # Broadcast in both directions. ONNX Expand applies bidirectional
    # broadcasting, so after the first call `mask` already has the common
    # shape and the second call brings `self` to that same shape. Choosing a
    # single direction based on rank alone leaves size-1 dimensions of the
    # other operand un-broadcast (e.g. mask [1, 3] against self [2, 3]).
    mask = op.Expand(mask, op.Shape(self))
    self = op.Expand(self, op.Shape(mask))

    # NonZero yields [rank, num_true]; ScatterND expects one index tuple per
    # row, i.e. [num_true, rank].
    true_indices = op.Transpose(op.NonZero(mask), perm=[1, 0])

    # ScatterND needs exactly one update per index, so flatten `source` and
    # keep only its first `num_true` elements.
    zero = op.Constant(value_ints=[0])
    flattened_source = op.Reshape(source, op.Constant(value_ints=[-1]))
    num_true_positions = op.Gather(op.Shape(true_indices), zero, axis=0)
    source_values = op.Slice(flattened_source, zero, num_true_positions, zero)

    return op.ScatterND(self, true_indices, source_values)