
Commit 1317cdb

YifanShenSZ and yifan_shen3 authored

[ExecuTorch] Support Several Indexing Ops (#2190)

* polish the mb.scatter_nd doc: fix some errors, de-duplicate symbols, and add an example to make it clearer
* support several indexing ops: add select_scatter, add slice_scatter, and improve index_put; polish the copy and reshape ops along the way

Co-authored-by: yifan_shen3 <[email protected]>

1 parent 1083cf5 commit 1317cdb
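For context on the mb.scatter_nd op whose doc this commit polishes: a minimal sketch of its semantics using the documented MIL Builder program pattern (the shapes and values here are illustrative assumptions, not taken from the commit):

    import numpy as np
    from coremltools.converters.mil import Builder as mb

    @mb.program(input_specs=[mb.TensorSpec(shape=(3, 3))])
    def prog(data):
        # indices[i] selects one element of data; updates[i] is written (or, with
        # mode="add", accumulated) at that position
        indices = np.array([[0, 0], [2, 2]], dtype=np.int32)
        updates = np.array([5.0, 10.0], dtype=np.float32)
        return mb.scatter_nd(data=data, indices=indices, updates=updates, mode="add")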

File tree

4 files changed: 571 additions & 107 deletions


coremltools/converters/mil/frontend/torch/ops.py

Lines changed: 177 additions & 15 deletions
@@ -1676,6 +1676,14 @@ def view(context, node):
     x = inputs[0]
     shape = inputs[1]
 
+    if np.prod(shape.shape) == 0:
+        # Reshape to an empty shape is a no-op (valid only for scalars and single-element tensors)
+        assert (
+            np.prod(x.shape) <= 1
+        ), "Reshape to an empty shape works only for scalars and single-element tensors"
+        context.add(mb.identity(x=x, name=node.name))
+        return
+
     if isinstance(shape, ListVar):
         length = mb.list_length(ls=shape)
         indices = mb.range_1d(start=0, end=length, step=1)
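The new early return mirrors PyTorch's own rule that only single-element tensors may be reshaped to the empty shape; a minimal illustration in plain PyTorch (not part of the diff):

    import torch

    x = torch.tensor(5.0)        # 0-d scalar tensor
    y = x.reshape(())            # legal: reshape to the empty shape () is a no-op here
    assert y.shape == torch.Size([])
    # torch.zeros(2).reshape(())  # would raise: 2 elements cannot fit in shape ()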
@@ -3759,39 +3767,180 @@ def _internal_op_tensor_inplace_fill(context, node):
 
 
 @register_torch_op
-def index_put(context, node):
+def select_scatter(context, node):
     inputs = _get_inputs(context, node, expected=4)
     x = inputs[0]
+    updates = inputs[1]
+    dim = inputs[2].val
+    if dim is None:
+        raise ValueError("Only compile-time-known dim is supported yet")
+    index = inputs[3]
+
+    # mb.torch_tensor_assign handles multi-dimensional slicing,
+    # so we need to create slice specifications for all the other dimensions
+    begin = [0] * x.rank
+    begin[dim] = index
+    begin = mb.concat(values=begin, axis=0)
+    end = x.shape
+    # and squeeze dim to do pure indexing on it
+    squeeze_mask = [False] * x.rank
+    squeeze_mask[dim] = True
+
+    updated_x = _translate_torch_tensor_assign(
+        x=x,
+        updates=updates,
+        begin=begin,
+        end=end,
+        stride=None,
+        begin_mask=None,
+        end_mask=None,
+        squeeze_mask=squeeze_mask,
+        name=node.name,
+    )
+    context.add(updated_x)
+
+
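For reference, the PyTorch semantics being translated: torch.select_scatter writes src into the slice of input selected by indexing dim at index. A small example in plain PyTorch (values illustrative):

    import torch

    x = torch.zeros(2, 3)
    src = torch.ones(2)                        # x's shape with dim 1 removed
    y = torch.select_scatter(x, src, dim=1, index=0)
    # y == tensor([[1., 0., 0.],
    #              [1., 0., 0.]])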
+@register_torch_op
+def slice_scatter(context, node):
+    inputs = _get_inputs(context, node, min_expected=2)
+    x, updates = promote_input_dtypes(inputs[0:2])
+    dim = 0 if len(inputs) <= 2 else inputs[2].val
+    if dim is None:
+        raise ValueError("Only compile-time-known dim is supported yet")
+    start = 0 if len(inputs) <= 3 else inputs[3]
+    end = x.shape[dim] if len(inputs) <= 4 else mb.minimum(x=inputs[4], y=x.shape[dim])
+    step = 1 if len(inputs) <= 5 else inputs[5]
+
+    assert 0 <= dim < x.rank, "slice dim must be within the input rank"
+
+    # mb.torch_tensor_assign handles multi-dimensional slicing,
+    # so we need to pad start, end, and step from scalars to x.rank
+    starts = [0] * x.rank
+    starts[dim] = start
+    starts = mb.concat(values=starts, axis=0)
+    ends = list(x.shape)
+    ends[dim] = end
+    ends = mb.concat(values=ends, axis=0)
+    steps = [1] * x.rank
+    steps[dim] = step
+    steps = mb.concat(values=steps, axis=0)
+
+    updated_x = _translate_torch_tensor_assign(
+        x=x,
+        updates=updates,
+        begin=starts,
+        end=ends,
+        stride=steps,
+        begin_mask=None,
+        end_mask=None,
+        squeeze_mask=None,
+        name=node.name,
+    )
+    context.add(updated_x)
+
+
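Likewise, torch.slice_scatter embeds src into the strided slice input[start:end:step] along dim; the converter pads the scalar start/end/step out to every dimension before calling _translate_torch_tensor_assign. A small example in plain PyTorch:

    import torch

    x = torch.zeros(8)
    src = torch.ones(4)
    y = torch.slice_scatter(x, src, dim=0, start=0, end=8, step=2)
    # y == tensor([1., 0., 1., 0., 1., 0., 1., 0.])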
+@register_torch_op
+def index_put(context, node):
+    inputs = _get_inputs(context, node, min_expected=3)
+    x = inputs[0]
     indices = inputs[1]
     values = inputs[2]
-    accumulate = inputs[3].val
-    rank = x.rank
+    accumulate = False if len(inputs) < 4 else inputs[3].val
     mode = "add" if accumulate else "update"
 
-    indices_type = indices[0].sym_type.get_primitive()
+    assert isinstance(indices, list), "indices must be a list of tensors"
+    # Usually indices is a list of non-None tensors, so we stack them and feed them to mb.scatter_nd.
+    # However, when a whole slice (i.e. :) is present, that index is represented as None
+    if any(map(lambda index: index is None, indices)):
+        # We have 2 ways to translate such a torch.index_put; both have pros and cons:
+        # 1. mb.scatter_nd
+        #    * pro: can handle accumulate or update
+        #    * con: can only have whole slices at the last dimensions
+        # 2. mb.torch_tensor_assign
+        #    * pro: can have whole slices at arbitrary dimensions
+        #    * con: can only handle update
+        # Here we use mb.torch_tensor_assign
+        # TODO: explore how we can cover as many torch.index_put cases as possible
+        if accumulate:
+            raise NotImplementedError(
+                "When a whole slice is present (e.g. : in x[:, 0]), "
+                "only torch.index_put(..., accumulate=False) is handled yet"
+            )
+
+        begin = [0] * x.rank
+        end = list(x.shape)
+        stride = [1] * x.rank
+        begin_mask = [True] * x.rank
+        end_mask = [True] * x.rank
+        # note: in this slice-based translation, an indexed dim is kept as size 1
+        # rather than squeezed, e.g. for
+        #     x = torch.zeros((2, 3))
+        # the assignment x[:, 1] = v updates a region of shape (2, 1)
+        is_dim_unity = [False] * x.rank
+        for dim, index in enumerate(indices):
+            if index is not None:
+                if len(index.shape) > 0:
+                    index = mb.squeeze(x=index)
+                begin[dim] = index
+                end[dim] = mb.add(x=index, y=1)
+                begin_mask[dim] = False
+                end_mask[dim] = False
+                is_dim_unity[dim] = True
+        begin = mb.concat(values=begin, axis=0)
+        end = mb.concat(values=end, axis=0)
+
+        expected_values_shape = []
+        for dim in range(x.rank):
+            expected_values_shape.append(1 if is_dim_unity[dim] else x.shape[dim])
+        expected_values_shape = tuple(expected_values_shape)
+
+        if values.shape != expected_values_shape:
+            values = _broadcast(values.name + "_broadcasted", values, expected_values_shape)
+
+        updated_x = _translate_torch_tensor_assign(
+            x=x,
+            updates=values,
+            begin=begin,
+            end=end,
+            stride=stride,
+            begin_mask=begin_mask,
+            end_mask=end_mask,
+            squeeze_mask=[False] * x.rank,
+            name=node.name,
+        )
+        context.add(updated_x)
+        return
 
+    indices_type = indices[0].sym_type.get_primitive()
     if types.is_bool(indices_type):
+        # indices
         assert len(indices) == 1, "Unsupported index_put_ usage."
         indices = indices[0]
         assert (
             indices.shape == x.shape
         ), "indices shape must equal the input shape for the index_put operation."
         indices = mb.cast(x=indices, dtype="int32")
         indices = mb.non_zero(x=indices)
-
-    if types.is_int(indices_type):
+        # values
+        if values.shape == ():
+            values = mb.expand_dims(x=values, axes=[0])
+        if values.rank == 1 and values.shape[0] == 1:
+            reps = value_at(mb.shape(x=indices), 0)
+            reps = mb.expand_dims(x=reps, axes=[0])
+            values = mb.tile(x=values, reps=reps)
+    elif types.is_int(indices_type):
+        # indices
         if len(indices) > 1:
             indices = mb.stack(values=indices, axis=indices[0].rank)
         else:
             indices = mb.expand_dims(x=indices[0], axes=[-1])
-
-    if len(values.shape) == 0:
-        values = mb.expand_dims(x=values, axes=[0])
-
-    if values.rank == 1 and values.shape[0] == 1:
-        reps = value_at(mb.shape(x=indices), 0)
-        reps = mb.expand_dims(x=reps, axes=[0])
-        values = mb.tile(x=values, reps=reps)
+        # values
+        expected_values_shape = indices.shape[:-1] + x.shape[indices.shape[-1] :]
+        if values.shape != expected_values_shape:
+            values = _broadcast(values.name + "_broadcasted", values, expected_values_shape)
+    else:
+        raise ValueError(f"Only bool and int indices are handled yet, but got {indices_type}")
 
     if is_current_opset_version_compatible_with(target.iOS17):
         # IOS17 `scatter_nd` behaviour is undefined for negative indices.
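The two index_put paths above correspond to the two ways the op shows up in traced graphs: a list of index tensors (stacked and fed to mb.scatter_nd), or a list containing None for whole-slice dimensions (routed to mb.torch_tensor_assign, update mode only). A small example of the tensor-indices case in plain PyTorch:

    import torch

    x = torch.zeros(2, 3)
    idx = (torch.tensor([0, 1]), torch.tensor([1, 2]))
    y = torch.index_put(x, idx, torch.tensor([1.0, 2.0]))
    # y == tensor([[0., 1., 0.],
    #              [0., 0., 2.]])
    # accumulate=True adds into the existing values instead of overwriting them;
    # a whole-slice assignment such as x[:, 0] = v may trace to index_put with
    # indices == [None, ...], which is the mb.torch_tensor_assign path above.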
@@ -5598,7 +5747,20 @@ def std(context, node):
 @register_torch_op
 def copy(context, node):
     inputs = _get_inputs(context, node, expected=[2, 3])
-    context.add(mb.identity(x=inputs[0], name=node.name))
+    assert (
+        context.frontend != TorchFrontend.TORCHSCRIPT
+    ), (
+        "In the TorchScript frontend, the graph pass `generate_tensor_assignment_ops` "
+        "should have replaced `torch.copy_` with `_internal_op_tensor_inplace_copy`"
+    )
+    if context.frontend == TorchFrontend.EXIR:
+        src = inputs[1]
+        if inputs[0].shape != src.shape:
+            _, src = _broadcast_tensors(inputs[:2])
+        result = mb.identity(x=src, name=node.name)
+    else:
+        raise ValueError(f"Invalid PyTorch frontend {context.frontend}")
+    context.add(result)
 
 
 @register_torch_op
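In the EXIR frontend, the in-place aten.copy_(dst, src) has been functionalized by the time it reaches the converter, so it suffices to broadcast src to dst's shape and pass it through mb.identity. The PyTorch behavior being matched, in plain PyTorch:

    import torch

    dst = torch.zeros(2, 3)
    src = torch.tensor([1.0, 2.0, 3.0])    # broadcastable to dst's shape
    dst.copy_(src)                          # broadcasts src, then copies it into dst
    # dst == tensor([[1., 2., 3.],
    #                [1., 2., 3.]])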
