@@ -5692,12 +5692,66 @@ def aten_rnn_tanh_cell(
5692
5692
raise NotImplementedError ()
5693
5693
5694
5694
5695
- def aten_roll (
5696
- self : TensorType , shifts : Sequence [int ], dims : Optional [Sequence [int ]] = None
5697
- ) -> TensorType :
5695
+ @torch_op ("aten::roll" , trace_only = True )
5696
+ def aten_roll (self : TTensor , shifts : INT64 , dims : Sequence [int ] = ()) -> TTensor :
5698
5697
"""roll(Tensor self, int[1] shifts, int[1] dims=[]) -> Tensor"""
5699
5698
5700
- raise NotImplementedError ()
5699
+ self_rank = len (self .shape )
5700
+ if self_rank == 0 :
5701
+ return self
5702
+ elif self .shape [0 ] == 0 : # empty tensor
5703
+ return self
5704
+ else :
5705
+ if isinstance (dims , tuple ) and len (dims ) == 0 : # Empty list
5706
+ # assert isinstance(shifts, int)
5707
+ return _aten_roll_shift_no_dim_onnx (self , shifts )
5708
+ else :
5709
+ # assert len(shifts) == len(dims), but shifts is a tensor, dims is a list
5710
+ result = self
5711
+ for i in range (len (shifts )): # pylint: disable=consider-using-enumerate
5712
+ shift = op .Gather (shifts , i , axis = 0 )
5713
+ dim = dims [i ]
5714
+ result = _aten_roll_shift_and_dim_onnx (result , shift , dim )
5715
+ return result
5716
+
5717
+
5718
+ @torch_op ("aten::roll" , private = True )
5719
+ def _aten_roll_shift_no_dim_onnx (self : TTensor , shift : INT64 ) -> TTensor :
5720
+ neg_1 = op .Constant (value_ints = [- 1 ])
5721
+ # flatten the self tensor: from [[A,B],[C,D]] to [A,B,C,D]
5722
+ self_flatten = op .Reshape (self , neg_1 )
5723
+ # Compute slice length
5724
+ shift_tensor = op .Reshape (shift , neg_1 )
5725
+ if shift_tensor < 0 :
5726
+ # For [A,B,C,D], if shift is -1, slice_length = -(-1) = 1, means move [A] to the end
5727
+ slice_length = - shift_tensor
5728
+ else :
5729
+ # For [A,B,C,D], if shift is 1, slice_length = 4 - 1 = 3, means move [A,B,C] to the end
5730
+ # The effect equals to move [D] to the beginning
5731
+ slice_length = op .Size (self_flatten ) - shift_tensor
5732
+ # Get second part of the tensor, e.g. [A,B,C]
5733
+ suffix = op .Slice (self_flatten , op .Constant (value_ints = [0 ]), slice_length )
5734
+ # Get first part of the tensor, e.g. [D]
5735
+ prefix = op .Slice (self_flatten , slice_length , op .Reshape (op .Size (self_flatten ), neg_1 ))
5736
+ # Concat first+second together, e.g. [D,A,B,C]
5737
+ result = op .Concat (prefix , suffix , axis = 0 )
5738
+ return op .Reshape (result , op .Shape (self ))
5739
+
5740
+
5741
+ @torch_op ("aten::roll" , private = True )
5742
+ def _aten_roll_shift_and_dim_onnx (self : TTensor , shift : INT64 , dim : int ) -> TTensor :
5743
+ neg_1 = op .Constant (value_ints = [- 1 ])
5744
+ dim_tensor = op .Reshape (op .Constant (value_int = dim ), neg_1 )
5745
+ shift_tensor = op .Reshape (shift , neg_1 )
5746
+ if shift_tensor < 0 :
5747
+ slice_length = - shift_tensor
5748
+ else :
5749
+ slice_length = op .Gather (op .Shape (self ), dim_tensor , axis = 0 ) - shift_tensor
5750
+ # from [A,B,C,D] -> [D,A,B,C], [D] is prefix, [A,B,C] is suffix
5751
+ suffix = op .Slice (self , op .Constant (value_ints = [0 ]), slice_length , axes = dim_tensor )
5752
+ prefix = op .Slice (self , slice_length , op .Reshape (op .Size (self ), neg_1 ), axes = dim_tensor )
5753
+ result = op .Concat (prefix , suffix , axis = dim )
5754
+ return result
5701
5755
5702
5756
5703
5757
def aten_rot90 (self : TensorType , k : int = 1 , dims : Sequence [int ] = (0 , 1 )) -> TensorType :
0 commit comments