Skip to content

Commit d0f7449

Browse files
committed
feat: support aten.roll dynamo converter
1 parent 4b608f0 commit d0f7449

File tree

3 files changed

+126
-2
lines changed

3 files changed

+126
-2
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2649,3 +2649,27 @@ def aten_ops_flip(
26492649
args[0],
26502650
args[1],
26512651
)
2652+
2653+
2654+
@dynamo_tensorrt_converter(torch.ops.aten.roll.default)
@enforce_tensor_types(
    {
        0: (TRTTensor,),
    }
)
def aten_ops_roll(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert ``aten.roll.default`` by delegating to ``impl.permutation.roll``.

    ``args`` layout: (input tensor, shifts, optional dims). When ``dims`` is
    absent it defaults to ``[]``, which the impl treats as "flatten, roll,
    reshape back".
    """
    input_tensor = args[0]
    shifts = args[1]
    # Third positional arg is optional in the aten schema.
    dims = args_bounds_check(args, 2, [])
    return impl.permutation.roll(
        ctx,
        target,
        SourceIR.ATEN,
        name,
        input_tensor,
        shifts,
        dims,
    )

py/torch_tensorrt/dynamo/conversion/impl/permutation.py

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
11
from typing import Optional, Sequence
22

3+
import tensorrt as trt
34
from torch.fx.node import Target
45
from torch_tensorrt.dynamo._SourceIR import SourceIR
6+
from torch_tensorrt.dynamo.conversion import impl
57
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
6-
from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim
8+
from torch_tensorrt.dynamo.conversion.converter_utils import (
9+
flatten_dims,
10+
get_positive_dim,
11+
)
712
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
8-
from torch_tensorrt.fx.types import TRTTensor
13+
from torch_tensorrt.fx.types import Shape, TRTTensor
914

1015

1116
def permute(
@@ -27,3 +32,56 @@ def permute(
2732
layer.second_transpose = tuple(permutation)
2833
set_layer_name(layer, target, name, source_ir)
2934
return layer.get_output(0)
35+
36+
37+
def roll(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
    shifts: Shape,
    dims: Shape,
) -> TRTTensor:
    """Roll ``input`` along ``dims`` by ``shifts`` via a WRAP-mode slice.

    Mirrors ``torch.roll`` semantics: when ``dims`` is empty, the tensor is
    flattened, rolled by ``shifts[0]``, and reshaped back to its original
    shape.

    Args:
        ctx: Conversion context holding the TensorRT network.
        target: FX node target being converted (used for layer naming).
        source_ir: Source IR the op came from (used for layer naming).
        name: Base name for the generated TensorRT layers.
        input: Input TRT tensor. NOTE(review): ``input.shape`` is read at
            conversion time, so a static shape is assumed — confirm dynamic
            shapes are rejected upstream.
        shifts: Shift amount per entry of ``dims`` (or a single value when
            ``dims`` is empty). May be negative.
        dims: Axes to roll along; may be empty.

    Returns:
        The rolled TRT tensor.
    """
    shape = input.shape
    if dims:
        rank = len(shape)
        start = [0] * rank
        stride = [1] * rank
        # Accumulate (+=) rather than assign so that a dim repeated in
        # ``dims`` composes its shifts, matching torch.roll.
        for d, s in zip(dims, shifts):
            # Rolling by +s equals slicing from index (size - s) with
            # wrap-around; get_positive_dim normalizes the negated shift.
            start[d] += get_positive_dim(-s, shape[d])

        layer = ctx.net.add_slice(
            input,
            start=start,
            shape=shape,
            stride=stride,
        )
        layer.mode = trt.SliceMode.WRAP
        set_layer_name(layer, target, f"{name}_slice_wrap", source_ir)
        return layer.get_output(0)
    else:
        # No dims given: flatten to 1-D, roll the flat view, reshape back.
        flatten_shape = flatten_dims(input, 0, -1)
        output = impl.shuffle.reshape(
            ctx, target, source_ir, f"{name}_reshape", input, flatten_shape
        )
        start = [get_positive_dim(-shifts[0], output.shape[0])]
        stride = [1]
        layer = ctx.net.add_slice(
            output,
            start=start,
            shape=flatten_shape,
            stride=stride,
        )
        layer.mode = trt.SliceMode.WRAP
        set_layer_name(layer, target, f"{name}_slice_wrap", source_ir)
        output = layer.get_output(0)
        output = impl.shuffle.reshape(
            ctx, target, source_ir, f"{name}_reshape_back", output, shape
        )
        return output
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
import torch
2+
import torch.nn as nn
3+
from parameterized import parameterized
4+
from torch.testing._internal.common_utils import run_tests
5+
from torch_tensorrt import Input
6+
7+
from .harness import DispatchTestCase
8+
9+
10+
class TestRollConverter(DispatchTestCase):
    # (input_shape, shifts, dims) triples; an empty dims list exercises the
    # flatten-roll-reshape path of the converter.
    @parameterized.expand(
        [
            ((4,), [2], [0]),
            ((4,), [3], [0]),
            ((4,), [-3, 2], [0, 0]),
            ((4,), [-2], []),
            ((4, 2), [2, 1], [0, 1]),
            ((3, 3), [2, 1], [1, 1]),
            ((4, 2), [2, -1], [-2, -1]),
            ((4, 2), [4], []),
            ((3, 4, 2), [1, 0, 2], [2, 0, -2]),
            ((3, 4, 2), [1, -0, 2], [1, 1, 1]),
            ((3, 4, 2), [5], []),
        ]
    )
    def test_roll_list(self, shape, shifts, dims):
        """Compare the TRT-converted aten.roll against eager for each case."""

        class Roll(nn.Module):
            def forward(self, x):
                return torch.ops.aten.roll.default(x, shifts, dims)

        module = Roll()
        self.run_test(module, [torch.randn(shape)])
39+
40+
41+
# Allow running this test module directly (outside a pytest invocation).
if __name__ == "__main__":
    run_tests()

0 commit comments

Comments (0)