@@ -3782,12 +3782,7 @@ def select_scatter(context, node):
37823782 begin [dim ] = index
37833783 begin = mb .concat (values = begin , axis = 0 )
37843784 end = x .shape
3785- stride = [1 ] * x .rank
3786-
3787- begin_mask = [True ] * x .rank
3788- if index .val not in (0 , - x .rank ):
3789- begin_mask [dim ] = False
3790- end_mask = [True ] * x .rank
3785+ # and squeeze dim to do pure indexing on it
37913786 squeeze_mask = [False ] * x .rank
37923787 squeeze_mask [dim ] = True
37933788
@@ -3796,9 +3791,9 @@ def select_scatter(context, node):
37963791 updates = updates ,
37973792 begin = begin ,
37983793 end = end ,
3799- stride = stride ,
3800- begin_mask = begin_mask ,
3801- end_mask = end_mask ,
3794+ stride = None ,
3795+ begin_mask = None ,
3796+ end_mask = None ,
38023797 squeeze_mask = squeeze_mask ,
38033798 name = node .name ,
38043799 )
@@ -3810,6 +3805,8 @@ def slice_scatter(context, node):
38103805 inputs = _get_inputs (context , node , min_expected = 2 )
38113806 x , updates = promote_input_dtypes (inputs [0 :2 ])
38123807 dim = 0 if len (inputs ) <= 2 else inputs [2 ].val
3808+ if dim is None :
3809+ raise ValueError ("Only compile time known dim supported yet" )
38133810 start = 0 if len (inputs ) <= 3 else inputs [3 ]
38143811 end = x .shape [dim ] if len (inputs ) <= 4 else mb .minimum (x = inputs [4 ], y = x .shape [dim ])
38153812 step = 1 if len (inputs ) <= 5 else inputs [5 ]
@@ -3829,24 +3826,15 @@ def slice_scatter(context, node):
38293826 steps [dim ] = step
38303827 steps = mb .concat (values = steps , axis = 0 )
38313828
3832- # mb.torch_tensor_assign also have masks
3833- begin_mask = [True ] * x .rank
3834- if start .val not in (0 , - x .rank ):
3835- begin_mask [dim ] = False
3836- end_mask = [True ] * x .rank
3837- if end .val is None or end .val < x .shape [dim ]:
3838- end_mask [dim ] = False
3839- squeeze_mask = [False ] * x .rank
3840-
38413829 updated_x = _translate_torch_tensor_assign (
38423830 x = x ,
38433831 updates = updates ,
38443832 begin = starts ,
38453833 end = ends ,
38463834 stride = steps ,
3847- begin_mask = begin_mask ,
3848- end_mask = end_mask ,
3849- squeeze_mask = squeeze_mask ,
3835+ begin_mask = None ,
3836+ end_mask = None ,
3837+ squeeze_mask = None ,
38503838 name = node .name ,
38513839 )
38523840 context .add (updated_x )
0 commit comments