Skip to content

Commit bf0bc00

Browse files
authored
feat: add dynamic support for eq/ne/lt/le (#2979)
1 parent e6ab7f8 commit bf0bc00

File tree

5 files changed

+348
-8
lines changed

5 files changed

+348
-8
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2276,8 +2276,8 @@ def aten_ops_bitwise_not(
22762276
)
22772277

22782278

2279-
@dynamo_tensorrt_converter(torch.ops.aten.eq.Tensor)
2280-
@dynamo_tensorrt_converter(torch.ops.aten.eq.Scalar)
2279+
@dynamo_tensorrt_converter(torch.ops.aten.eq.Tensor, supports_dynamic_shapes=True)
2280+
@dynamo_tensorrt_converter(torch.ops.aten.eq.Scalar, supports_dynamic_shapes=True)
22812281
@enforce_tensor_types(
22822282
{
22832283
0: (TRTTensor,),
@@ -2300,8 +2300,8 @@ def aten_ops_eq(
23002300
)
23012301

23022302

2303-
@dynamo_tensorrt_converter(torch.ops.aten.ne.Tensor)
2304-
@dynamo_tensorrt_converter(torch.ops.aten.ne.Scalar)
2303+
@dynamo_tensorrt_converter(torch.ops.aten.ne.Tensor, supports_dynamic_shapes=True)
2304+
@dynamo_tensorrt_converter(torch.ops.aten.ne.Scalar, supports_dynamic_shapes=True)
23052305
@enforce_tensor_types(
23062306
{
23072307
0: (TRTTensor,),
@@ -2372,8 +2372,8 @@ def aten_ops_ge(
23722372
)
23732373

23742374

2375-
@dynamo_tensorrt_converter(torch.ops.aten.lt.Tensor)
2376-
@dynamo_tensorrt_converter(torch.ops.aten.lt.Scalar)
2375+
@dynamo_tensorrt_converter(torch.ops.aten.lt.Tensor, supports_dynamic_shapes=True)
2376+
@dynamo_tensorrt_converter(torch.ops.aten.lt.Scalar, supports_dynamic_shapes=True)
23772377
@enforce_tensor_types(
23782378
{
23792379
0: (TRTTensor,),
@@ -2396,8 +2396,8 @@ def aten_ops_lt(
23962396
)
23972397

23982398

2399-
@dynamo_tensorrt_converter(torch.ops.aten.le.Tensor)
2400-
@dynamo_tensorrt_converter(torch.ops.aten.le.Scalar)
2399+
@dynamo_tensorrt_converter(torch.ops.aten.le.Tensor, supports_dynamic_shapes=True)
2400+
@dynamo_tensorrt_converter(torch.ops.aten.le.Scalar, supports_dynamic_shapes=True)
24012401
@enforce_tensor_types(
24022402
{
24032403
0: (TRTTensor,),

tests/py/dynamo/conversion/test_eq_aten.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import torch.nn as nn
33
from parameterized import parameterized
44
from torch.testing._internal.common_utils import run_tests
5+
from torch_tensorrt import Input
56

67
from .harness import DispatchTestCase
78

@@ -61,6 +62,90 @@ def forward(self, lhs_val):
6162
inputs,
6263
)
6364

65+
@parameterized.expand(
66+
[
67+
((1,), (3,), (5,)),
68+
((1, 20), (2, 20), (3, 20)),
69+
((2, 3, 4), (3, 4, 5), (4, 5, 6)),
70+
((2, 3, 4, 5), (3, 5, 5, 6), (4, 5, 6, 7)),
71+
]
72+
)
73+
def test_eq_tensor_dynamic_shape(self, min_shape, opt_shape, max_shape):
74+
class eq(nn.Module):
75+
def forward(self, lhs_val, rhs_val):
76+
return torch.ops.aten.eq.Tensor(lhs_val, rhs_val)
77+
78+
input_specs = [
79+
Input(
80+
dtype=torch.float32,
81+
min_shape=min_shape,
82+
opt_shape=opt_shape,
83+
max_shape=max_shape,
84+
),
85+
Input(
86+
dtype=torch.float32,
87+
min_shape=min_shape,
88+
opt_shape=opt_shape,
89+
max_shape=max_shape,
90+
),
91+
]
92+
self.run_test_with_dynamic_shape(
93+
eq(),
94+
input_specs,
95+
)
96+
97+
@parameterized.expand(
98+
[
99+
((1,), (3,), (5,)),
100+
((1, 20), (2, 20), (3, 20)),
101+
((2, 3, 4), (3, 4, 5), (4, 5, 6)),
102+
((2, 3, 4, 5), (3, 5, 5, 6), (4, 5, 6, 7)),
103+
]
104+
)
105+
def test_eq_tensor_scalar_dynamic_shape(self, min_shape, opt_shape, max_shape):
106+
class eq(nn.Module):
107+
def forward(self, lhs_val):
108+
return torch.ops.aten.eq.Tensor(lhs_val, torch.tensor(1))
109+
110+
input_specs = [
111+
Input(
112+
dtype=torch.int32,
113+
min_shape=min_shape,
114+
opt_shape=opt_shape,
115+
max_shape=max_shape,
116+
),
117+
]
118+
self.run_test_with_dynamic_shape(
119+
eq(),
120+
input_specs,
121+
)
122+
123+
@parameterized.expand(
124+
[
125+
((1,), (3,), (5,)),
126+
((1, 20), (2, 20), (3, 20)),
127+
((2, 3, 4), (3, 4, 5), (4, 5, 6)),
128+
((2, 3, 4, 5), (3, 5, 5, 6), (4, 5, 6, 7)),
129+
]
130+
)
131+
def test_eq_scalar_dynamic_shape(self, min_shape, opt_shape, max_shape):
132+
class eq(nn.Module):
133+
def forward(self, lhs_val):
134+
return torch.ops.aten.eq.Scalar(lhs_val, 1.0)
135+
136+
input_specs = [
137+
Input(
138+
dtype=torch.int32,
139+
min_shape=min_shape,
140+
opt_shape=opt_shape,
141+
max_shape=max_shape,
142+
),
143+
]
144+
self.run_test_with_dynamic_shape(
145+
eq(),
146+
input_specs,
147+
)
148+
64149

65150
if __name__ == "__main__":
66151
run_tests()

tests/py/dynamo/conversion/test_le_aten.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import torch.nn as nn
33
from parameterized import parameterized
44
from torch.testing._internal.common_utils import run_tests
5+
from torch_tensorrt import Input
56

67
from .harness import DispatchTestCase
78

@@ -61,6 +62,90 @@ def forward(self, lhs_val):
6162
inputs,
6263
)
6364

65+
@parameterized.expand(
66+
[
67+
((1,), (3,), (5,)),
68+
((1, 20), (2, 20), (3, 20)),
69+
((2, 3, 4), (3, 4, 5), (4, 5, 6)),
70+
((2, 3, 4, 5), (3, 5, 5, 6), (4, 5, 6, 7)),
71+
]
72+
)
73+
def test_le_tensor_dynamic_shape(self, min_shape, opt_shape, max_shape):
74+
class le(nn.Module):
75+
def forward(self, lhs_val, rhs_val):
76+
return torch.ops.aten.le.Tensor(lhs_val, rhs_val)
77+
78+
input_specs = [
79+
Input(
80+
dtype=torch.float32,
81+
min_shape=min_shape,
82+
opt_shape=opt_shape,
83+
max_shape=max_shape,
84+
),
85+
Input(
86+
dtype=torch.float32,
87+
min_shape=min_shape,
88+
opt_shape=opt_shape,
89+
max_shape=max_shape,
90+
),
91+
]
92+
self.run_test_with_dynamic_shape(
93+
le(),
94+
input_specs,
95+
)
96+
97+
@parameterized.expand(
98+
[
99+
((1,), (3,), (5,)),
100+
((1, 20), (2, 20), (3, 20)),
101+
((2, 3, 4), (3, 4, 5), (4, 5, 6)),
102+
((2, 3, 4, 5), (3, 5, 5, 6), (4, 5, 6, 7)),
103+
]
104+
)
105+
def test_le_tensor_scalar_dynamic_shape(self, min_shape, opt_shape, max_shape):
106+
class le(nn.Module):
107+
def forward(self, lhs_val):
108+
return torch.ops.aten.le.Tensor(lhs_val, torch.tensor(1))
109+
110+
input_specs = [
111+
Input(
112+
dtype=torch.int32,
113+
min_shape=min_shape,
114+
opt_shape=opt_shape,
115+
max_shape=max_shape,
116+
),
117+
]
118+
self.run_test_with_dynamic_shape(
119+
le(),
120+
input_specs,
121+
)
122+
123+
@parameterized.expand(
124+
[
125+
((1,), (3,), (5,)),
126+
((1, 20), (2, 20), (3, 20)),
127+
((2, 3, 4), (3, 4, 5), (4, 5, 6)),
128+
((2, 3, 4, 5), (3, 5, 5, 6), (4, 5, 6, 7)),
129+
]
130+
)
131+
def test_le_scalar_dynamic_shape(self, min_shape, opt_shape, max_shape):
132+
class le(nn.Module):
133+
def forward(self, lhs_val):
134+
return torch.ops.aten.le.Scalar(lhs_val, 1.0)
135+
136+
input_specs = [
137+
Input(
138+
dtype=torch.int32,
139+
min_shape=min_shape,
140+
opt_shape=opt_shape,
141+
max_shape=max_shape,
142+
),
143+
]
144+
self.run_test_with_dynamic_shape(
145+
le(),
146+
input_specs,
147+
)
148+
64149

65150
if __name__ == "__main__":
66151
run_tests()

tests/py/dynamo/conversion/test_lt_aten.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import torch.nn as nn
33
from parameterized import parameterized
44
from torch.testing._internal.common_utils import run_tests
5+
from torch_tensorrt import Input
56

67
from .harness import DispatchTestCase
78

@@ -58,6 +59,90 @@ def forward(self, lhs_val):
5859
inputs,
5960
)
6061

62+
@parameterized.expand(
63+
[
64+
((1,), (3,), (5,)),
65+
((1, 20), (2, 20), (3, 20)),
66+
((2, 3, 4), (3, 4, 5), (4, 5, 6)),
67+
((2, 3, 4, 5), (3, 5, 5, 6), (4, 5, 6, 7)),
68+
]
69+
)
70+
def test_lt_tensor_dynamic_shape(self, min_shape, opt_shape, max_shape):
71+
class lt(nn.Module):
72+
def forward(self, lhs_val, rhs_val):
73+
return torch.ops.aten.lt.Tensor(lhs_val, rhs_val)
74+
75+
input_specs = [
76+
Input(
77+
dtype=torch.float32,
78+
min_shape=min_shape,
79+
opt_shape=opt_shape,
80+
max_shape=max_shape,
81+
),
82+
Input(
83+
dtype=torch.float32,
84+
min_shape=min_shape,
85+
opt_shape=opt_shape,
86+
max_shape=max_shape,
87+
),
88+
]
89+
self.run_test_with_dynamic_shape(
90+
lt(),
91+
input_specs,
92+
)
93+
94+
@parameterized.expand(
95+
[
96+
((1,), (3,), (5,)),
97+
((1, 20), (2, 20), (3, 20)),
98+
((2, 3, 4), (3, 4, 5), (4, 5, 6)),
99+
((2, 3, 4, 5), (3, 5, 5, 6), (4, 5, 6, 7)),
100+
]
101+
)
102+
def test_lt_tensor_scalar_dynamic_shape(self, min_shape, opt_shape, max_shape):
103+
class lt(nn.Module):
104+
def forward(self, lhs_val):
105+
return torch.ops.aten.lt.Tensor(lhs_val, torch.tensor(1))
106+
107+
input_specs = [
108+
Input(
109+
dtype=torch.int32,
110+
min_shape=min_shape,
111+
opt_shape=opt_shape,
112+
max_shape=max_shape,
113+
),
114+
]
115+
self.run_test_with_dynamic_shape(
116+
lt(),
117+
input_specs,
118+
)
119+
120+
@parameterized.expand(
121+
[
122+
((1,), (3,), (5,)),
123+
((1, 20), (2, 20), (3, 20)),
124+
((2, 3, 4), (3, 4, 5), (4, 5, 6)),
125+
((2, 3, 4, 5), (3, 5, 5, 6), (4, 5, 6, 7)),
126+
]
127+
)
128+
def test_lt_scalar_dynamic_shape(self, min_shape, opt_shape, max_shape):
129+
class lt(nn.Module):
130+
def forward(self, lhs_val):
131+
return torch.ops.aten.lt.Scalar(lhs_val, 1.0)
132+
133+
input_specs = [
134+
Input(
135+
dtype=torch.int32,
136+
min_shape=min_shape,
137+
opt_shape=opt_shape,
138+
max_shape=max_shape,
139+
),
140+
]
141+
self.run_test_with_dynamic_shape(
142+
lt(),
143+
input_specs,
144+
)
145+
61146

62147
if __name__ == "__main__":
63148
run_tests()

0 commit comments

Comments
 (0)