Skip to content

Commit 25a0173

Browse files
author
yifan_shen3
committed
address review comment: remove unnecessary mask
1 parent e0f2377 commit 25a0173

File tree

1 file changed

+9
-21
lines changed
  • coremltools/converters/mil/frontend/torch

1 file changed

+9
-21
lines changed

coremltools/converters/mil/frontend/torch/ops.py

Lines changed: 9 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3782,12 +3782,7 @@ def select_scatter(context, node):
37823782
begin[dim] = index
37833783
begin = mb.concat(values=begin, axis=0)
37843784
end = x.shape
3785-
stride = [1] * x.rank
3786-
3787-
begin_mask = [True] * x.rank
3788-
if index.val not in (0, -x.rank):
3789-
begin_mask[dim] = False
3790-
end_mask = [True] * x.rank
3785+
# and squeeze dim to do pure indexing on it
37913786
squeeze_mask = [False] * x.rank
37923787
squeeze_mask[dim] = True
37933788

@@ -3796,9 +3791,9 @@ def select_scatter(context, node):
37963791
updates=updates,
37973792
begin=begin,
37983793
end=end,
3799-
stride=stride,
3800-
begin_mask=begin_mask,
3801-
end_mask=end_mask,
3794+
stride=None,
3795+
begin_mask=None,
3796+
end_mask=None,
38023797
squeeze_mask=squeeze_mask,
38033798
name=node.name,
38043799
)
@@ -3810,6 +3805,8 @@ def slice_scatter(context, node):
38103805
inputs = _get_inputs(context, node, min_expected=2)
38113806
x, updates = promote_input_dtypes(inputs[0:2])
38123807
dim = 0 if len(inputs) <= 2 else inputs[2].val
3808+
if dim is None:
3809+
raise ValueError("Only compile time known dim supported yet")
38133810
start = 0 if len(inputs) <= 3 else inputs[3]
38143811
end = x.shape[dim] if len(inputs) <= 4 else mb.minimum(x=inputs[4], y=x.shape[dim])
38153812
step = 1 if len(inputs) <= 5 else inputs[5]
@@ -3829,24 +3826,15 @@ def slice_scatter(context, node):
38293826
steps[dim] = step
38303827
steps = mb.concat(values=steps, axis=0)
38313828

3832-
# mb.torch_tensor_assign also have masks
3833-
begin_mask = [True] * x.rank
3834-
if start.val not in (0, -x.rank):
3835-
begin_mask[dim] = False
3836-
end_mask = [True] * x.rank
3837-
if end.val is None or end.val < x.shape[dim]:
3838-
end_mask[dim] = False
3839-
squeeze_mask = [False] * x.rank
3840-
38413829
updated_x = _translate_torch_tensor_assign(
38423830
x=x,
38433831
updates=updates,
38443832
begin=starts,
38453833
end=ends,
38463834
stride=steps,
3847-
begin_mask=begin_mask,
3848-
end_mask=end_mask,
3849-
squeeze_mask=squeeze_mask,
3835+
begin_mask=None,
3836+
end_mask=None,
3837+
squeeze_mask=None,
38503838
name=node.name,
38513839
)
38523840
context.add(updated_x)

0 commit comments

Comments
 (0)