Skip to content

Commit 6e19c81

Browse files
committed
add integer dims in test
1 parent d0f7449 commit 6e19c81

File tree

2 files changed

+11
-5
lines changed

2 files changed

+11
-5
lines changed

py/torch_tensorrt/dynamo/conversion/impl/permutation.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Optional, Sequence
1+
from typing import Optional, Sequence, Union
22

33
import tensorrt as trt
44
from torch.fx.node import Target
@@ -10,7 +10,7 @@
1010
get_positive_dim,
1111
)
1212
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
13-
from torch_tensorrt.fx.types import Shape, TRTTensor
13+
from torch_tensorrt.fx.types import TRTTensor
1414

1515

1616
def permute(
@@ -40,10 +40,15 @@ def roll(
4040
source_ir: Optional[SourceIR],
4141
name: str,
4242
input: TRTTensor,
43-
shifts: Shape,
44-
dims: Shape,
43+
shifts: Union[int, Sequence[int]],
44+
dims: Union[int, Sequence[int]],
4545
) -> TRTTensor:
4646
shape = input.shape
47+
if isinstance(shifts, int):
48+
shifts = [shifts]
49+
if isinstance(dims, int):
50+
dims = [dims]
51+
4752
if dims != []:
4853
rank = len(shape)
4954
start = [0] * rank

tests/py/dynamo/conversion/test_roll_aten.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
class TestRollConverter(DispatchTestCase):
1111
@parameterized.expand(
1212
[
13+
((4,), (2,), 0),
1314
((4,), [2], [0]),
1415
((4,), [3], [0]),
1516
((4,), [-3, 2], [0, 0]),
@@ -29,7 +30,7 @@ class TestRollConverter(DispatchTestCase):
2930
),
3031
]
3132
)
32-
def test_roll_list(self, shape, shifts, dims):
33+
def test_roll(self, shape, shifts, dims):
3334
class Roll(nn.Module):
3435
def forward(self, x):
3536
return torch.ops.aten.roll.default(x, shifts, dims)

0 commit comments

Comments
 (0)