Skip to content

Improve masked_scatter implementation documentation and code clarity #2387

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 36 additions & 13 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5217,24 +5217,47 @@ def aten_masked_fill(self: TTensor, mask: BOOL, value: TTensor) -> TTensor:

@torch_op(("aten::masked_scatter"), trace_only=True)
def aten_masked_scatter(self: TTensor, mask: TTensor, source: TTensor) -> TTensor:
"""masked_scatter(Tensor self, Tensor mask, Tensor source) -> Tensor"""

"""masked_scatter(Tensor self, Tensor mask, Tensor source) -> Tensor

Scatters values from source tensor into self tensor at positions where mask is True.

Key behavior: Values from source are consumed sequentially (from flattened source),
not element-wise. This is why we cannot use the simpler approach:
`self * (1-mask) + mask * source` as it would assume element-wise correspondence.

Args:
self: Target tensor to scatter into
mask: Boolean mask indicating where to scatter
source: Source tensor (flattened and consumed sequentially)

Returns:
Tensor with values scattered from source into positions where mask is True
"""
# Ensure self and mask have compatible shapes through broadcasting
if len(mask.shape) < len(self.shape):
mask = op.Expand(mask, op.Shape(self))
else:
self = op.Expand(self, op.Shape(mask))
index = op.Transpose(op.NonZero(mask), perm=[1, 0])

# NOTE: source can have more elements than needed.
# It could also have arbitrary shape.
# This is not supported by ONNX::ScatterND, so we need to flatten and slice source tensor.
source = op.Reshape(source, op.Constant(value_ints=[-1]))
axes = op.Constant(value_ints=[0])
starts = op.Constant(value_ints=[0])
ends = op.Gather(op.Shape(index), op.Constant(value_ints=[0]), axis=0)
source = op.Slice(source, starts, ends, axes)

# Get indices where mask is True
# NonZero returns [num_dims, num_true_positions], transpose to [num_true_positions, num_dims]
true_indices = op.Transpose(op.NonZero(mask), perm=[1, 0])

# Prepare source values for scattering
# ScatterND requires source to have exactly the right number of elements,
# but source can have arbitrary shape and may contain more elements than needed.
flattened_source = op.Reshape(source, op.Constant(value_ints=[-1]))

# Slice source to get exactly the number of values we need
num_true_positions = op.Gather(op.Shape(true_indices), op.Constant(value_ints=[0]), axis=0)
source_values = op.Slice(
flattened_source,
starts=op.Constant(value_ints=[0]),
ends=num_true_positions,
axes=op.Constant(value_ints=[0])
)

return op.ScatterND(self, index, source)
return op.ScatterND(self, true_indices, source_values)


def aten_masked_select(self: TensorType, mask: TensorType) -> TensorType:
Expand Down