
Commit 9fd312f

keehyuna authored and cehongwang committed
chore: dynamic shape support for clamp/min/max/floor_div/logical_and (#2977)
1 parent 19678b4 commit 9fd312f

7 files changed: +203 -124 lines changed


py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 16 additions & 9 deletions
@@ -733,10 +733,10 @@ def aten_ops_where(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.clamp.default)
-@dynamo_tensorrt_converter(torch.ops.aten.clamp.Tensor)
-@dynamo_tensorrt_converter(torch.ops.aten.clip.default)
-@dynamo_tensorrt_converter(torch.ops.aten.clip.Tensor)
+@dynamo_tensorrt_converter(torch.ops.aten.clamp.default, supports_dynamic_shapes=True)
+@dynamo_tensorrt_converter(torch.ops.aten.clamp.Tensor, supports_dynamic_shapes=True)
+@dynamo_tensorrt_converter(torch.ops.aten.clip.default, supports_dynamic_shapes=True)
+@dynamo_tensorrt_converter(torch.ops.aten.clip.Tensor, supports_dynamic_shapes=True)
 def aten_ops_clamp(
     ctx: ConversionContext,
     target: Target,
@@ -1880,7 +1880,7 @@ def aten_ops_mul(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.maximum.default)
+@dynamo_tensorrt_converter(torch.ops.aten.maximum.default, supports_dynamic_shapes=True)
 def aten_ops_maximum(
     ctx: ConversionContext,
     target: Target,
@@ -1898,7 +1898,7 @@ def aten_ops_maximum(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.minimum.default)
+@dynamo_tensorrt_converter(torch.ops.aten.minimum.default, supports_dynamic_shapes=True)
 def aten_ops_minimum(
     ctx: ConversionContext,
     target: Target,
@@ -2019,8 +2019,13 @@ def aten_ops_pow(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.floor_divide.default)
-@dynamo_tensorrt_converter(torch.ops.aten.floor_divide.Scalar)
+@dynamo_tensorrt_converter(
+    torch.ops.aten.floor_divide.default, supports_dynamic_shapes=True
+)
+@dynamo_tensorrt_converter(
+    torch.ops.aten.floor_divide.Scalar, supports_dynamic_shapes=True
+)
+@dynamo_tensorrt_converter(operator.floordiv, supports_dynamic_shapes=True)
 def aten_ops_floor_div(
     ctx: ConversionContext,
     target: Target,
@@ -2038,7 +2043,9 @@ def aten_ops_floor_div(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.logical_and.default)
+@dynamo_tensorrt_converter(
+    torch.ops.aten.logical_and.default, supports_dynamic_shapes=True
+)
 def aten_ops_logical_and(
     ctx: ConversionContext,
     target: Target,
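
Usage sketch (not part of the diff): with supports_dynamic_shapes=True on the converters above, a module that hits these ops can be compiled once over a shape range instead of a fixed shape. This is a minimal, hedged example: the ClampMax module, the helper dyn_input, and the shape ranges are made up for illustration, and it assumes a working Torch-TensorRT dynamo frontend on a CUDA device, with torch_tensorrt.Input(min_shape/opt_shape/max_shape) and torch_tensorrt.compile(..., ir="dynamo", inputs=...) spelled as in the public API.

import torch
import torch_tensorrt  # assumes a TensorRT-enabled install


class ClampMax(torch.nn.Module):
    # Exercises two of the ops registered above: aten.maximum and aten.clamp.
    def forward(self, x, y):
        return torch.clamp(torch.maximum(x, y), min=-0.5, max=0.5)


def dyn_input():
    # Hypothetical helper: the same dynamic range for both operands,
    # with the batch dimension allowed to vary between 1 and 8.
    return torch_tensorrt.Input(
        min_shape=(1, 3, 3),
        opt_shape=(4, 3, 3),
        max_shape=(8, 3, 3),
        dtype=torch.float32,
    )


model = ClampMax().eval().cuda()
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=[dyn_input(), dyn_input()])

# Any batch size inside [1, 8] should now run against the same compiled engine.
out = trt_model(torch.randn(6, 3, 3).cuda(), torch.randn(6, 3, 3).cuda())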

tests/py/dynamo/conversion/test_clamp_aten.py

Lines changed: 4 additions & 4 deletions
@@ -54,12 +54,12 @@ def forward(self, x):

         input_specs = [
             Input(
-                shape=(-1, -1, 3, 3),
-                dtype=torch.float32,
-                shape_ranges=[((1, 1, 3, 3), (3, 3, 3, 3), (5, 5, 3, 3))],
+                min_shape=(1, 1, 3, 3),
+                opt_shape=(3, 3, 3, 3),
+                max_shape=(5, 5, 3, 3),
+                dtype=torch.float,
             ),
         ]
-
         self.run_test_with_dynamic_shape(TestModule(), input_specs)
         self.run_test_with_dynamic_shape(TestScalarModule(), input_specs)


tests/py/dynamo/conversion/test_floor_div_aten.py

Lines changed: 52 additions & 0 deletions
@@ -61,6 +61,58 @@ def forward(self, lhs_val):
             inputs,
         )

+    @parameterized.expand(
+        [
+            (
+                "2d_dim_dtype_half",
+                (1, 1),
+                (2, 2),
+                (4, 4),
+                torch.half,
+                torch.half,
+            ),
+            (
+                "3d_dim_dtype_float",
+                (1, 1, 1),
+                (1, 2, 3),
+                (3, 3, 3),
+                torch.float,
+                torch.float,
+            ),
+        ]
+    )
+    def test_floor_div_dynamic_shape(
+        self, _, min_shape, opt_shape, max_shape, type, output_type
+    ):
+        class floor_div(nn.Module):
+            def forward(self, lhs_val, rhs_val):
+                return torch.ops.aten.floor_divide.default(lhs_val, rhs_val)
+
+        class floor_div_operator(nn.Module):
+            def forward(self, lhs_val, rhs_val):
+                return lhs_val // rhs_val
+
+        input_specs = [
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+                dtype=type,
+            ),
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+                dtype=type,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            floor_div(), input_specs, output_dtypes=[output_type]
+        )
+        self.run_test_with_dynamic_shape(
+            floor_div_operator(), input_specs, output_dtypes=[output_type]
+        )
+

 if __name__ == "__main__":
     run_tests()

tests/py/dynamo/conversion/test_logical_and_aten.py

Lines changed: 39 additions & 0 deletions
@@ -25,6 +25,45 @@ def forward(self, lhs_val, rhs_val):
             inputs,
         )

+    @parameterized.expand(
+        [
+            (
+                "2d_dim_dtype_float",
+                (1, 1),
+                (2, 2),
+                (4, 4),
+                torch.float,
+            ),
+            (
+                "3d_dim_dtype_bool",
+                (1, 1, 1),
+                (1, 2, 3),
+                (3, 3, 3),
+                torch.bool,
+            ),
+        ]
+    )
+    def test_logical_and_dynamic_shape(self, _, min_shape, opt_shape, max_shape, type):
+        class logical_and(nn.Module):
+            def forward(self, lhs_val, rhs_val):
+                return torch.ops.aten.logical_and.default(lhs_val, rhs_val)
+
+        input_specs = [
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+                dtype=type,
+            ),
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+                dtype=type,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(logical_and(), input_specs)
+

 if __name__ == "__main__":
     run_tests()

tests/py/dynamo/conversion/test_maximum_aten.py

Lines changed: 46 additions & 1 deletion
@@ -17,7 +17,7 @@ class TestMaximumConverter(DispatchTestCase):
     def test_maximum(self, _, shape):
         class Maximum(nn.Module):
             def forward(self, lhs_val, rhs_val):
-                return torch.maximum(lhs_val, rhs_val)
+                return torch.ops.aten.maximum.default(lhs_val, rhs_val)

         inputs = [torch.randn(shape), torch.randn(shape)]
         self.run_test(
@@ -26,6 +26,51 @@ def forward(self, lhs_val, rhs_val):
             use_dynamo_tracer=True,
         )

+    @parameterized.expand(
+        [
+            (
+                "2d_dim_dtype_half",
+                (1, 1),
+                (2, 2),
+                (4, 4),
+                torch.half,
+                torch.half,
+            ),
+            (
+                "3d_dim_dtype_float",
+                (1, 1, 1),
+                (1, 2, 3),
+                (3, 3, 3),
+                torch.float,
+                torch.float,
+            ),
+        ]
+    )
+    def test_maximum_dynamic_shape(
+        self, _, min_shape, opt_shape, max_shape, type, output_type
+    ):
+        class Maximum(nn.Module):
+            def forward(self, lhs_val, rhs_val):
+                return torch.ops.aten.maximum.default(lhs_val, rhs_val)
+
+        input_specs = [
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+                dtype=type,
+            ),
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+                dtype=type,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            Maximum(), input_specs, output_dtypes=[output_type]
+        )
+

 if __name__ == "__main__":
     run_tests()

tests/py/dynamo/conversion/test_minimum_aten.py

Lines changed: 46 additions & 1 deletion
@@ -17,7 +17,7 @@ class TestMinimumConverter(DispatchTestCase):
     def test_minimum(self, _, shape):
         class Minimum(nn.Module):
             def forward(self, lhs_val, rhs_val):
-                return torch.minimum(lhs_val, rhs_val)
+                return torch.ops.aten.minimum.default(lhs_val, rhs_val)

         inputs = [torch.randn(shape), torch.randn(shape)]
         self.run_test(
@@ -26,6 +26,51 @@ def forward(self, lhs_val, rhs_val):
             use_dynamo_tracer=True,
         )

+    @parameterized.expand(
+        [
+            (
+                "2d_dim_dtype_half",
+                (1, 1),
+                (2, 2),
+                (4, 4),
+                torch.half,
+                torch.half,
+            ),
+            (
+                "3d_dim_dtype_float",
+                (1, 1, 1),
+                (1, 2, 3),
+                (3, 3, 3),
+                torch.float,
+                torch.float,
+            ),
+        ]
+    )
+    def test_minimum_dynamic_shape(
+        self, _, min_shape, opt_shape, max_shape, type, output_type
+    ):
+        class Minimum(nn.Module):
+            def forward(self, lhs_val, rhs_val):
+                return torch.ops.aten.minimum.default(lhs_val, rhs_val)
+
+        input_specs = [
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+                dtype=type,
+            ),
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+                dtype=type,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            Minimum(), input_specs, output_dtypes=[output_type]
+        )
+

 if __name__ == "__main__":
     run_tests()
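
Local check (not part of the diff): the new *_dynamic_shape cases can be run with pytest from the repository root, assuming a TensorRT-enabled environment, e.g.

python -m pytest tests/py/dynamo/conversion/test_clamp_aten.py tests/py/dynamo/conversion/test_floor_div_aten.py tests/py/dynamo/conversion/test_logical_and_aten.py tests/py/dynamo/conversion/test_maximum_aten.py tests/py/dynamo/conversion/test_minimum_aten.py -k dynamic_shape

where -k dynamic_shape simply filters by test-name substring to the cases added in this commit.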
