@@ -1676,6 +1676,14 @@ def view(context, node):
     x = inputs[0]
     shape = inputs[1]

+    if np.prod(shape.shape) == 0:
+        # Reshaping to an empty shape (valid only for single-element tensors) is a no-op
+        assert (
+            np.prod(x.shape) <= 1
+        ), "Reshape to an empty shape works only for scalars and single-element tensors"
+        context.add(mb.identity(x=x, name=node.name))
+        return
+
     if isinstance(shape, ListVar):
         length = mb.list_length(ls=shape)
         indices = mb.range_1d(start=0, end=length, step=1)
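
For context on the new early-return branch: in PyTorch, reshaping with an empty shape tuple is legal only for single-element tensors, which is exactly the invariant the assert enforces. A minimal repro of the behavior being handled (standard torch semantics, not part of this diff):

    import torch

    x = torch.tensor([5.0])   # single-element tensor, shape (1,)
    y = x.reshape(())         # OK: y becomes a 0-d scalar tensor with the same data
    # torch.zeros(2, 3).reshape(()) would raise a RuntimeError,
    # matching the assert np.prod(x.shape) <= 1 above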
@@ -3759,39 +3767,180 @@ def _internal_op_tensor_inplace_fill(context, node):


 @register_torch_op
-def index_put(context, node):
+def select_scatter(context, node):
     inputs = _get_inputs(context, node, expected=4)
     x = inputs[0]
+    updates = inputs[1]
+    dim = inputs[2].val
+    if dim is None:
+        raise ValueError("Only a compile-time-known dim is supported yet")
+    index = inputs[3]
+
+    # mb.torch_tensor_assign handles multi-dim slicing,
+    # so we need to create slice specifications for all the other dimensions
+    begin = [0] * x.rank
+    begin[dim] = index
+    begin = mb.concat(values=begin, axis=0)
+    end = x.shape
+    # and squeeze dim to do pure indexing on it
+    squeeze_mask = [False] * x.rank
+    squeeze_mask[dim] = True
+
+    updated_x = _translate_torch_tensor_assign(
+        x=x,
+        updates=updates,
+        begin=begin,
+        end=end,
+        stride=None,
+        begin_mask=None,
+        end_mask=None,
+        squeeze_mask=squeeze_mask,
+        name=node.name,
+    )
+    context.add(updated_x)
+
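For readers unfamiliar with the op: torch.select_scatter writes a src tensor into one index of a given dim, out of place, which is what the begin/squeeze_mask construction above mirrors. A small usage sketch in plain PyTorch, for illustration only:

    import torch

    x = torch.zeros(2, 3)
    src = torch.ones(3)
    # out-of-place equivalent of: y = x.clone(); y[1] = src
    y = torch.select_scatter(x, src, dim=0, index=1)
    # y == tensor([[0., 0., 0.],
    #              [1., 1., 1.]])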
+@register_torch_op
+def slice_scatter(context, node):
+    inputs = _get_inputs(context, node, min_expected=2)
+    x, updates = promote_input_dtypes(inputs[0:2])
+    dim = 0 if len(inputs) <= 2 else inputs[2].val
+    if dim is None:
+        raise ValueError("Only a compile-time-known dim is supported yet")
+    start = 0 if len(inputs) <= 3 else inputs[3]
+    end = x.shape[dim] if len(inputs) <= 4 else mb.minimum(x=inputs[4], y=x.shape[dim])
+    step = 1 if len(inputs) <= 5 else inputs[5]
+
+    assert dim is not None, "slice dim must be known at compile time"
+    assert 0 <= dim < x.rank
+
+    # mb.torch_tensor_assign handles multi-dim slicing,
+    # so we need to pad start, end, and step from scalars to x.rank
+    starts = [0] * x.rank
+    starts[dim] = start
+    starts = mb.concat(values=starts, axis=0)
+    ends = list(x.shape)
+    ends[dim] = end
+    ends = mb.concat(values=ends, axis=0)
+    steps = [1] * x.rank
+    steps[dim] = step
+    steps = mb.concat(values=steps, axis=0)
+
+    updated_x = _translate_torch_tensor_assign(
+        x=x,
+        updates=updates,
+        begin=starts,
+        end=ends,
+        stride=steps,
+        begin_mask=None,
+        end_mask=None,
+        squeeze_mask=None,
+        name=node.name,
+    )
+    context.add(updated_x)
+
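Similarly, torch.slice_scatter writes src into a strided slice along one dim, out of place, which is why start/end/step are padded to full rank above. An illustration in plain PyTorch:

    import torch

    x = torch.zeros(8)
    src = torch.ones(4)
    # out-of-place equivalent of: y = x.clone(); y[0:8:2] = src
    y = torch.slice_scatter(x, src, dim=0, start=0, end=8, step=2)
    # y == tensor([1., 0., 1., 0., 1., 0., 1., 0.])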
+@register_torch_op
+def index_put(context, node):
+    inputs = _get_inputs(context, node, min_expected=3)
+    x = inputs[0]
     indices = inputs[1]
     values = inputs[2]
-    accumulate = inputs[3].val
-    rank = x.rank
+    accumulate = False if len(inputs) < 4 else inputs[3].val
     mode = "add" if accumulate else "update"

-    indices_type = indices[0].sym_type.get_primitive()
+    assert isinstance(indices, list), "indices must be a list of tensors"
+    # Usually indices is a list of non-None tensors, so we stack them and feed them to mb.scatter_nd.
+    # However, when there is a whole-slice index (i.e. :), that index is represented as None
+    if any(map(lambda index: index is None, indices)):
+        # We have 2 ways to translate such a torch.index_put, each with pros and cons:
+        # 1. mb.scatter_nd
+        #    * pro: can handle both accumulate and update
+        #    * con: whole slices are allowed only in the last dimensions
+        # 2. mb.torch_tensor_assign
+        #    * pro: whole slices are allowed in arbitrary dimensions
+        #    * con: can only handle update
+        # Here we use mb.torch_tensor_assign
+        # TODO: explore how we can cover as many torch.index_put cases as possible
+        if accumulate:
+            raise NotImplementedError(
+                "If there is any whole slice (e.g. : in x[:, 0]), "
+                "only torch.index_put(..., accumulate=False) is handled yet"
+            )
+
+        begin = [0] * x.rank
+        end = list(x.shape)
+        stride = [1] * x.rank
+        begin_mask = [True] * x.rank
+        end_mask = [True] * x.rank
+        # note: in this translation an indexed dim keeps size 1 rather than being squeezed, e.g.
+        #     x = torch.zeros((2, 3))
+        #     x[:, 1] = ...
+        # updates a region of shape (2, 1)
+        is_dim_unity = [False] * x.rank
+        for dim, index in enumerate(indices):
+            if index is not None:
+                if len(index.shape) > 0:
+                    index = mb.squeeze(x=index)
+                begin[dim] = index
+                end[dim] = mb.add(x=index, y=1)
+                begin_mask[dim] = False
+                end_mask[dim] = False
+                is_dim_unity[dim] = True
+        begin = mb.concat(values=begin, axis=0)
+        end = mb.concat(values=end, axis=0)
+
+        expected_values_shape = []
+        for dim in range(x.rank):
+            expected_values_shape.append(1 if is_dim_unity[dim] else x.shape[dim])
+        expected_values_shape = tuple(expected_values_shape)
+
+        if values.shape != expected_values_shape:
+            values = _broadcast(values.name + "_broadcasted", values, expected_values_shape)
+
+        updated_x = _translate_torch_tensor_assign(
+            x=x,
+            updates=values,
+            begin=begin,
+            end=end,
+            stride=stride,
+            begin_mask=begin_mask,
+            end_mask=end_mask,
+            squeeze_mask=[False] * x.rank,
+            name=node.name,
+        )
+        context.add(updated_x)
+        return

+    indices_type = indices[0].sym_type.get_primitive()
     if types.is_bool(indices_type):
+        # indices
         assert len(indices) == 1, "Unsupported index_put_ usage."
         indices = indices[0]
         assert (
             indices.shape == x.shape
         ), "indices shape must equal the input shape for the index_put operation."
         indices = mb.cast(x=indices, dtype="int32")
         indices = mb.non_zero(x=indices)
-
-    if types.is_int(indices_type):
+        # values
+        if values.shape == ():
+            values = mb.expand_dims(x=values, axes=[0])
+        if values.rank == 1 and values.shape[0] == 1:
+            reps = value_at(mb.shape(x=indices), 0)
+            reps = mb.expand_dims(x=reps, axes=[0])
+            values = mb.tile(x=values, reps=reps)
+    elif types.is_int(indices_type):
+        # indices
         if len(indices) > 1:
             indices = mb.stack(values=indices, axis=indices[0].rank)
         else:
             indices = mb.expand_dims(x=indices[0], axes=[-1])
-
-    if len(values.shape) == 0:
-        values = mb.expand_dims(x=values, axes=[0])
-
-    if values.rank == 1 and values.shape[0] == 1:
-        reps = value_at(mb.shape(x=indices), 0)
-        reps = mb.expand_dims(x=reps, axes=[0])
-        values = mb.tile(x=values, reps=reps)
+        # values
+        expected_values_shape = indices.shape[:-1] + x.shape[indices.shape[-1]:]
+        if values.shape != expected_values_shape:
+            values = _broadcast(values.name + "_broadcasted", values, expected_values_shape)
+    else:
+        raise ValueError(f"Only bool and int indices are handled yet, but got {indices_type}")

     if is_current_opset_version_compatible_with(target.iOS17):
         # IOS17 `scatter_nd` behaviour is undefined for negative indices.
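
The three index_put paths above correspond to distinct PyTorch calling patterns. A hypothetical sketch of each in plain PyTorch (torch.ops.aten.index_put is the op form the converter receives; exact lowering details may vary by torch version):

    import torch

    x = torch.zeros(2, 3)

    # 1. Whole-slice (None) index, taken by the mb.torch_tensor_assign branch:
    #    e.g. x[:, 1] = 7.0 lowers to index_put with indices == [None, index_tensor]
    y1 = torch.ops.aten.index_put(x, [None, torch.tensor([1])], torch.tensor(7.0))

    # 2. Boolean mask whose shape equals x.shape, lowered via cast + non_zero + scatter_nd:
    mask = torch.tensor([[True, False, True], [False, True, False]])
    y2 = torch.ops.aten.index_put(x, [mask], torch.tensor(1.0))

    # 3. Integer index tensors, stacked and fed to scatter_nd
    #    (accumulate=True selects the "add" mode):
    y3 = torch.ops.aten.index_put(x, [torch.tensor([0, 1])], torch.ones(2, 3), True)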
@@ -5598,7 +5747,20 @@ def std(context, node):
 @register_torch_op
 def copy(context, node):
     inputs = _get_inputs(context, node, expected=[2, 3])
-    context.add(mb.identity(x=inputs[0], name=node.name))
+    assert (
+        context.frontend != TorchFrontend.TORCHSCRIPT
+    ), (
+        "In the TorchScript frontend, the graph pass `generate_tensor_assignment_ops` "
+        "should have replaced `torch.copy_` with `_internal_op_tensor_inplace_copy`"
+    )
+    if context.frontend == TorchFrontend.EXIR:
+        src = inputs[1]
+        if inputs[0].shape != src.shape:
+            _, src = _broadcast_tensors(inputs[:2])
+        result = mb.identity(x=src, name=node.name)
+    else:
+        raise ValueError(f"Invalid PyTorch frontend {context.frontend}")
+    context.add(result)


 @register_torch_op
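
The new EXIR branch mirrors the broadcasting semantics of the underlying aten op: torch.Tensor.copy_ broadcasts src to the destination's shape, so the converter broadcasts before emitting the identity. A plain-PyTorch illustration:

    import torch

    x = torch.zeros(2, 3)
    src = torch.tensor([1.0, 2.0, 3.0])  # shape (3,) broadcasts to (2, 3)
    x.copy_(src)                         # appears as aten.copy in exported (EXIR) graphs
    # x == tensor([[1., 2., 3.],
    #              [1., 2., 3.]])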