Skip to content

Commit f5167a8

Browse files
authored
feat: dynamic shape support for atan/asinh/acosh/atanh/atan2/ceil (#2959)
1 parent feb4d84 commit f5167a8

File tree

8 files changed

+396
-26
lines changed

8 files changed

+396
-26
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1598,7 +1598,7 @@ def aten_ops_acos(
15981598
)
15991599

16001600

1601-
@dynamo_tensorrt_converter(torch.ops.aten.atan.default)
1601+
@dynamo_tensorrt_converter(torch.ops.aten.atan.default, supports_dynamic_shapes=True)
16021602
def aten_ops_atan(
16031603
ctx: ConversionContext,
16041604
target: Target,
@@ -1615,7 +1615,7 @@ def aten_ops_atan(
16151615
)
16161616

16171617

1618-
@dynamo_tensorrt_converter(torch.ops.aten.asinh.default)
1618+
@dynamo_tensorrt_converter(torch.ops.aten.asinh.default, supports_dynamic_shapes=True)
16191619
def aten_ops_asinh(
16201620
ctx: ConversionContext,
16211621
target: Target,
@@ -1632,7 +1632,7 @@ def aten_ops_asinh(
16321632
)
16331633

16341634

1635-
@dynamo_tensorrt_converter(torch.ops.aten.acosh.default)
1635+
@dynamo_tensorrt_converter(torch.ops.aten.acosh.default, supports_dynamic_shapes=True)
16361636
def aten_ops_acosh(
16371637
ctx: ConversionContext,
16381638
target: Target,
@@ -1649,7 +1649,7 @@ def aten_ops_acosh(
16491649
)
16501650

16511651

1652-
@dynamo_tensorrt_converter(torch.ops.aten.atanh.default)
1652+
@dynamo_tensorrt_converter(torch.ops.aten.atanh.default, supports_dynamic_shapes=True)
16531653
def aten_ops_atanh(
16541654
ctx: ConversionContext,
16551655
target: Target,
@@ -1666,7 +1666,7 @@ def aten_ops_atanh(
16661666
)
16671667

16681668

1669-
@dynamo_tensorrt_converter(torch.ops.aten.atan2.default)
1669+
@dynamo_tensorrt_converter(torch.ops.aten.atan2.default, supports_dynamic_shapes=True)
16701670
@enforce_tensor_types(
16711671
{
16721672
0: (TRTTensor,),
@@ -1690,7 +1690,7 @@ def aten_ops_atan2(
16901690
)
16911691

16921692

1693-
@dynamo_tensorrt_converter(torch.ops.aten.atan2.out)
1693+
@dynamo_tensorrt_converter(torch.ops.aten.atan2.out, supports_dynamic_shapes=True)
16941694
def aten_ops_atan2_out(
16951695
ctx: ConversionContext,
16961696
target: Target,
@@ -1706,7 +1706,7 @@ def aten_ops_atan2_out(
17061706
return out_return
17071707

17081708

1709-
@dynamo_tensorrt_converter(torch.ops.aten.ceil.default)
1709+
@dynamo_tensorrt_converter(torch.ops.aten.ceil.default, supports_dynamic_shapes=True)
17101710
def aten_ops_ceil(
17111711
ctx: ConversionContext,
17121712
target: Target,

py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py

Lines changed: 52 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import Optional, Union
22

33
import numpy as np
4+
from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape
45
import tensorrt as trt
56
import torch
67
import torch_tensorrt.dynamo.conversion.impl as impl
@@ -13,6 +14,7 @@
1314
cast_int_or_float_to_bool,
1415
cast_trt_tensor,
1516
get_trt_tensor,
17+
has_dynamic_shape,
1618
)
1719
from torch_tensorrt.dynamo.conversion.impl.elementwise.base import (
1820
convert_binary_elementwise,
@@ -332,25 +334,56 @@ def atan2(
332334
y_positive,
333335
)
334336

335-
# on x or y-axis
336-
pi_over_2_tensor = get_trt_tensor(
337-
ctx,
338-
(pi_value / 2) * np.ones(input.shape, dtype=np.float32),
339-
f"{name}_pi_over_2_tensor",
340-
dtype=trt.float32,
341-
)
342-
minus_pi_over_2_tensor = get_trt_tensor(
343-
ctx,
344-
(-pi_value / 2) * np.ones(input.shape, dtype=np.float32),
345-
f"{name}_minus_pi_over_2_tensor",
346-
dtype=trt.float32,
347-
)
348-
zero_tensor = get_trt_tensor(
349-
ctx,
350-
np.zeros(input.shape, dtype=np.float32),
351-
f"{name}_zero_tensor",
352-
dtype=trt.float32,
353-
)
337+
if has_dynamic_shape(input.shape):
338+
pi_over_2_tensor = convert_binary_elementwise(
339+
ctx,
340+
target,
341+
source_ir,
342+
f"{name}_pi_over_2_tensor",
343+
trt.ElementWiseOperation.PROD,
344+
(pi_value / 2),
345+
input,
346+
)
347+
348+
minus_pi_over_2_tensor = convert_binary_elementwise(
349+
ctx,
350+
target,
351+
source_ir,
352+
f"{name}_minus_pi_over_2_tensor",
353+
trt.ElementWiseOperation.PROD,
354+
(-pi_value / 2),
355+
input,
356+
)
357+
zero_tensor = convert_binary_elementwise(
358+
ctx,
359+
target,
360+
source_ir,
361+
f"{name}_zero_tensor",
362+
trt.ElementWiseOperation.PROD,
363+
0,
364+
input,
365+
)
366+
else:
367+
# on x or y-axis
368+
pi_over_2_tensor = get_trt_tensor(
369+
ctx,
370+
(pi_value / 2) * np.ones(input.shape, dtype=np.float32),
371+
f"{name}_pi_over_2_tensor",
372+
dtype=trt.float32,
373+
)
374+
375+
minus_pi_over_2_tensor = get_trt_tensor(
376+
ctx,
377+
(-pi_value / 2) * np.ones(input.shape, dtype=np.float32),
378+
f"{name}_minus_pi_over_2_tensor",
379+
dtype=trt.float32,
380+
)
381+
zero_tensor = get_trt_tensor(
382+
ctx,
383+
np.zeros(input.shape, dtype=np.float32),
384+
f"{name}_zero_tensor",
385+
dtype=trt.float32,
386+
)
354387

355388
# π/2 if x>0 and y=0,
356389
pi_over_2_output = impl.condition.select(

tests/py/dynamo/conversion/test_acosh_aten.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import torch.nn as nn
33
from parameterized import parameterized
44
from torch.testing._internal.common_utils import run_tests
5+
from torch_tensorrt import Input
56

67
from .harness import DispatchTestCase
78

@@ -44,6 +45,53 @@ def forward(self, input):
4445
inputs,
4546
)
4647

48+
@parameterized.expand(
49+
[
50+
(
51+
"2d_dim_dtype_half",
52+
(1, 1),
53+
(2, 2),
54+
(4, 4),
55+
torch.half,
56+
torch.half,
57+
),
58+
(
59+
"3d_dim_dtype_float",
60+
(1, 1, 1),
61+
(1, 2, 3),
62+
(3, 3, 3),
63+
torch.float,
64+
torch.float,
65+
),
66+
(
67+
"3d_dim_dtype_int32",
68+
(1, 1, 1),
69+
(1, 2, 4),
70+
(2, 3, 5),
71+
torch.int32,
72+
torch.float,
73+
),
74+
]
75+
)
76+
def test_dynamic_shape_acosh(
77+
self, _, min_shape, opt_shape, max_shape, type, output_type
78+
):
79+
class acosh(nn.Module):
80+
def forward(self, input):
81+
return torch.ops.aten.acosh.default(input)
82+
83+
input_specs = [
84+
Input(
85+
min_shape=min_shape,
86+
opt_shape=opt_shape,
87+
max_shape=max_shape,
88+
dtype=type,
89+
),
90+
]
91+
self.run_test_with_dynamic_shape(
92+
acosh(), input_specs, output_dtypes=[output_type]
93+
)
94+
4795

4896
if __name__ == "__main__":
4997
run_tests()

tests/py/dynamo/conversion/test_asinh_aten.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import torch.nn as nn
33
from parameterized import parameterized
44
from torch.testing._internal.common_utils import run_tests
5+
from torch_tensorrt import Input
56

67
from .harness import DispatchTestCase
78

@@ -44,6 +45,53 @@ def forward(self, input):
4445
inputs,
4546
)
4647

48+
@parameterized.expand(
49+
[
50+
(
51+
"2d_dim_dtype_half",
52+
(1, 1),
53+
(2, 2),
54+
(4, 4),
55+
torch.half,
56+
torch.half,
57+
),
58+
(
59+
"3d_dim_dtype_float",
60+
(1, 1, 1),
61+
(1, 2, 3),
62+
(3, 3, 3),
63+
torch.float,
64+
torch.float,
65+
),
66+
(
67+
"3d_dim_dtype_int32",
68+
(1, 1, 1),
69+
(1, 2, 4),
70+
(2, 3, 5),
71+
torch.int32,
72+
torch.float,
73+
),
74+
]
75+
)
76+
def test_dynamic_shape_asinh(
77+
self, _, min_shape, opt_shape, max_shape, type, output_type
78+
):
79+
class asinh(nn.Module):
80+
def forward(self, input):
81+
return torch.ops.aten.asinh.default(input)
82+
83+
input_specs = [
84+
Input(
85+
min_shape=min_shape,
86+
opt_shape=opt_shape,
87+
max_shape=max_shape,
88+
dtype=type,
89+
),
90+
]
91+
self.run_test_with_dynamic_shape(
92+
asinh(), input_specs, output_dtypes=[output_type]
93+
)
94+
4795

4896
if __name__ == "__main__":
4997
run_tests()

0 commit comments

Comments (0)