Skip to content

Commit d06e498

Browse files
authored
Merge branch 'main' into rama/if_loop
2 parents 4c8da73 + d77f001 commit d06e498

File tree

2 files changed

+82
-4
lines changed

2 files changed

+82
-4
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 58 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5692,12 +5692,66 @@ def aten_rnn_tanh_cell(
56925692
raise NotImplementedError()
56935693

56945694

5695-
def aten_roll(
5696-
self: TensorType, shifts: Sequence[int], dims: Optional[Sequence[int]] = None
5697-
) -> TensorType:
5695+
@torch_op("aten::roll", trace_only=True)
def aten_roll(self: TTensor, shifts: INT64, dims: Sequence[int] = ()) -> TTensor:
    """roll(Tensor self, int[1] shifts, int[1] dims=[]) -> Tensor

    Dispatcher for aten::roll. Because this is trace_only, all ``if``
    statements and the ``for`` loop below run in Python at trace time
    (they do not become ONNX control-flow nodes).
    """
    self_rank = len(self.shape)
    if self_rank == 0:
        # Rolling a 0-d (scalar) tensor is a no-op.
        return self
    elif self.shape[0] == 0:  # empty tensor
        # NOTE(review): only dimension 0 is checked for emptiness — assumes
        # that is sufficient for the inputs exercised here; confirm for
        # tensors that are empty along a non-leading dimension.
        return self
    else:
        if isinstance(dims, tuple) and len(dims) == 0:  # Empty list
            # No dims given: PyTorch semantics are flatten -> roll -> reshape.
            # assert isinstance(shifts, int)
            return _aten_roll_shift_no_dim_onnx(self, shifts)
        else:
            # assert len(shifts) == len(dims), but shifts is a tensor, dims is a list
            # Roll one dimension at a time: pass i shifts along dims[i] by
            # shifts[i], feeding the result into the next pass.
            result = self
            for i in range(len(shifts)):  # pylint: disable=consider-using-enumerate
                shift = op.Gather(shifts, i, axis=0)
                dim = dims[i]
                result = _aten_roll_shift_and_dim_onnx(result, shift, dim)
            return result
5716+
5717+
5718+
@torch_op("aten::roll", private=True)
def _aten_roll_shift_no_dim_onnx(self: TTensor, shift: INT64) -> TTensor:
    """Roll a flattened view of `self` by `shift`, then restore the shape.

    Implements aten::roll when no `dims` are given: flatten to 1-D, rotate
    the sequence, and reshape back to the original shape. Scripted mode:
    the if/else below is converted into an ONNX If node by onnxscript.
    """
    neg_1 = op.Constant(value_ints=[-1])
    # flatten the self tensor: from [[A,B],[C,D]] to [A,B,C,D]
    self_flatten = op.Reshape(self, neg_1)
    # Compute slice length; shift is reshaped to a 1-element 1-D tensor so it
    # can be used directly as Slice starts/ends.
    shift_tensor = op.Reshape(shift, neg_1)
    if shift_tensor < 0:
        # For [A,B,C,D], if shift is -1, slice_length = -(-1) = 1, means move [A] to the end
        slice_length = -shift_tensor
    else:
        # For [A,B,C,D], if shift is 1, slice_length = 4 - 1 = 3, means move [A,B,C] to the end
        # The effect equals to move [D] to the beginning
        slice_length = op.Size(self_flatten) - shift_tensor
    # Get second part of the tensor, e.g. [A,B,C]
    suffix = op.Slice(self_flatten, op.Constant(value_ints=[0]), slice_length)
    # Get first part of the tensor, e.g. [D]
    prefix = op.Slice(self_flatten, slice_length, op.Reshape(op.Size(self_flatten), neg_1))
    # Concat first+second together, e.g. [D,A,B,C]
    result = op.Concat(prefix, suffix, axis=0)
    # Restore the original shape.
    return op.Reshape(result, op.Shape(self))
5739+
5740+
5741+
@torch_op("aten::roll", private=True)
def _aten_roll_shift_and_dim_onnx(self: TTensor, shift: INT64, dim: int) -> TTensor:
    """Roll `self` by `shift` positions along the single dimension `dim`.

    Scripted mode: the if/else below is converted into an ONNX If node.
    """
    neg_1 = op.Constant(value_ints=[-1])
    # dim as a 1-element 1-D tensor, usable as the `axes` input of Slice.
    dim_tensor = op.Reshape(op.Constant(value_int=dim), neg_1)
    shift_tensor = op.Reshape(shift, neg_1)
    if shift_tensor < 0:
        # Negative shift: the first |shift| elements move to the end.
        slice_length = -shift_tensor
    else:
        # Positive shift: the last `shift` elements move to the beginning,
        # i.e. the first (dim_size - shift) elements become the suffix.
        slice_length = op.Gather(op.Shape(self), dim_tensor, axis=0) - shift_tensor
    # from [A,B,C,D] -> [D,A,B,C], [D] is prefix, [A,B,C] is suffix
    suffix = op.Slice(self, op.Constant(value_ints=[0]), slice_length, axes=dim_tensor)
    # Size(self) (total element count) over-shoots the dim length; ONNX Slice
    # clamps out-of-range `ends`, so this selects "slice_length .. end of dim".
    prefix = op.Slice(self, slice_length, op.Reshape(op.Size(self), neg_1), axes=dim_tensor)
    result = op.Concat(prefix, suffix, axis=dim)
    return result
57015755

57025756

57035757
def aten_rot90(self: TensorType, k: int = 1, dims: Sequence[int] = (0, 1)) -> TensorType:

onnxscript/tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,24 @@ def _replication_pad3d_input_wrangler(
359359
return args, kwargs
360360

361361

362+
def _roll_input_wrangler(
363+
args: list[Any], kwargs: dict[str, Any]
364+
) -> tuple[list[Any], dict[str, Any]]:
365+
if len(args) >= 3:
366+
if isinstance(args[2], np.ndarray): # convert dims to list[int]
367+
# Change dims from args to kwargs to keep tuple/list type
368+
dims = args.pop(2)
369+
kwargs["dims"] = dims.tolist()
370+
elif isinstance(args[2], int): # convert dims to list[int]
371+
dims = args.pop(2)
372+
kwargs["dims"] = []
373+
kwargs["dims"].append(dims)
374+
if len(args) >= 2:
375+
if isinstance(args[1], int): # convert shift to tensor
376+
args[1] = np.array([args[1]], dtype=np.int64)
377+
return args, kwargs
378+
379+
362380
def _scatter_reduce_input_wrangler(
363381
args: list[Any], kwargs: dict[str, Any]
364382
) -> tuple[list[Any], dict[str, Any]]:
@@ -1630,6 +1648,12 @@ def _where_input_wrangler(
16301648
reason="fixme: the scale_factor tests",
16311649
),
16321650
TorchLibOpInfo("ones_like", core_ops.aten_ones_like, trace_only=True),
1651+
TorchLibOpInfo(
1652+
"roll",
1653+
core_ops.aten_roll,
1654+
trace_only=True,
1655+
input_wrangler=_roll_input_wrangler,
1656+
),
16331657
TorchLibOpInfo(
16341658
"scatter_reduce",
16351659
core_ops.aten_scatter_reduce,

0 commit comments

Comments
 (0)