From 0ee872736838e0c36e286519b89a602321829031 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 25 Jul 2023 18:08:33 +0000 Subject: [PATCH] NVM --- .../function_libs/torch_lib/ops/core.py | 55 ++++++++++--------- 1 file changed, 28 insertions(+), 27 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 651d0c63e0..a7448bf9b1 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -5693,7 +5693,7 @@ def aten_rnn_tanh_cell( @torch_op("aten::roll", trace_only=True) -def aten_roll(self: TTensor, shifts: INT64, dims: Optional[Sequence[int]] = None) -> TTensor: +def aten_roll(self: TTensor, shifts: INT64, dims: Sequence[int]) -> TTensor: """roll(Tensor self, int[1] shifts, int[1] dims=[]) -> Tensor""" self_rank = len(self.shape) @@ -5702,27 +5702,19 @@ def aten_roll(self: TTensor, shifts: INT64, dims: Optional[Sequence[int]] = None elif self.shape[0] == 0: # empty tensor return self else: - if dims is None: - return _aten_roll_shift_no_dim_onnx(self, shifts) - elif isinstance(shifts, int) and isinstance(dims, int): - return _aten_roll_shift_and_dim_onnx(self, shifts, dims) - else: # Below condition was skipped because we cannot handle it in OnnxScript - assert len(shifts) == len(dims) - result = self - for i in range(len(shifts)): # pylint: disable=consider-using-enumerate - shift = op.Gather(shifts, i, axis=0) - dim = dims[i] - result = _aten_roll_shift_and_dim_onnx(result, shift, dim) - return result + if not dims: + return _aten_roll_no_dim_onnx(self, shifts) + else: + return _aten_roll_onnx(self, shifts, dims) @torch_op("aten::roll", private=True) -def _aten_roll_shift_no_dim_onnx(self: TTensor, shift: INT64) -> TTensor: +def _aten_roll_no_dim_onnx(self: TTensor, shifts: INT64) -> TTensor: neg_1 = op.Constant(value_ints=[-1]) # flatten the self tensor self_flatten = op.Reshape(self, neg_1) # Compute slice length - shift_tensor = op.Reshape(shift, neg_1) + shift_tensor = op.Reshape(shifts, neg_1) if shift_tensor < 0: slice_length = -shift_tensor else: @@ -5737,18 +5729,27 @@ def _aten_roll_shift_no_dim_onnx(self: TTensor, shift: INT64) -> TTensor: @torch_op("aten::roll", private=True) -def _aten_roll_shift_and_dim_onnx(self: TTensor, shift: INT64, dim: int) -> TTensor: - neg_1 = op.Constant(value_ints=[-1]) - dim_tensor = op.Reshape(op.Constant(value_int=dim), neg_1) - shift_tensor = op.Reshape(shift, neg_1) - if shift_tensor < 0: - slice_length = -shift_tensor - else: - slice_length = op.Gather(op.Shape(self), dim_tensor, axis=0) - shift_tensor - # from [A,B,C,D,E] -> [E,A,B,C,D], [E] is prefix, [A,B,C,D] is suffix - suffix = op.Slice(self, op.Constant(value_ints=[0]), slice_length, axes=dim_tensor) - prefix = op.Slice(self, slice_length, op.Reshape(op.Size(self), neg_1), axes=dim_tensor) - result = op.Concat(prefix, suffix, axis=dim) +def _aten_roll_onnx(self: TTensor, shifts: INT64, dims: Sequence[int]) -> TTensor: + result = self + + for i in range(op.Size(shifts)): # pylint: disable=consider-using-enumerate + shift = op.Gather(shifts, i, axis=0) + dim = dims[i] + + # Shift dimension i + neg_1 = op.Constant(value_ints=[-1]) + dim_tensor = op.Reshape(dim, neg_1) + shift_tensor = op.Reshape(shift, neg_1) + if shift_tensor < 0: + slice_length = -shift_tensor + else: + slice_length = op.Gather(op.Shape(self), dim_tensor, axis=0) - shift_tensor + # from [A,B,C,D,E] -> [E,A,B,C,D], [E] is prefix, [A,B,C,D] is suffix + suffix = op.Slice(self, op.Constant(value_ints=[0]), slice_length, axes=dim_tensor) + prefix = op.Slice(self, slice_length, op.Reshape(op.Size(self), neg_1), axes=dim_tensor) + # Concat requires a static axis + result = op.Concat(prefix, suffix, axis=dim) + return result