
Commit baa2eb1

dynamic shape argmax and argmin (#3009)
1 parent fdaba9a

File tree

5 files changed: +143 -7 lines changed


py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 2 additions & 2 deletions

@@ -2876,7 +2876,7 @@ def aten_ops_resize(


 @enforce_tensor_types({0: (TRTTensor,)})
-@dynamo_tensorrt_converter(torch.ops.aten.argmax.default)
+@dynamo_tensorrt_converter(torch.ops.aten.argmax.default, supports_dynamic_shapes=True)
 def aten_ops_argmax(
     ctx: ConversionContext,
     target: Target,
@@ -2896,7 +2896,7 @@ def aten_ops_argmax(


 @enforce_tensor_types({0: (TRTTensor,)})
-@dynamo_tensorrt_converter(torch.ops.aten.argmin.default)
+@dynamo_tensorrt_converter(torch.ops.aten.argmin.default, supports_dynamic_shapes=True)
 def aten_ops_argmin(
     ctx: ConversionContext,
     target: Target,

py/torch_tensorrt/dynamo/conversion/impl/topk.py

Lines changed: 57 additions & 5 deletions

@@ -11,9 +11,13 @@
     get_axes_for_reduce_op,
     get_positive_dim,
     set_layer_name,
+    get_trt_tensor,
+    has_dynamic_shape,
 )
-from torch_tensorrt.dynamo.types import TRTTensor
+from torch_tensorrt.dynamo.conversion.impl.elementwise import convert_binary_elementwise
+from torch_tensorrt.dynamo.conversion.impl.shape import shape as get_shape
 from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
+from torch_tensorrt.dynamo.types import TRTTensor


 def argmax_argmin(
@@ -34,12 +38,60 @@ def argmax_argmin(
     # 2. input rank == 1: TopK layer does not support 1 dimensional topk operation. Broadcast input to rank == 2
     # 3. normal cases, no additional handlings
     out = input
+    is_dynamic_present = has_dynamic_shape(input.shape)

     if dim is None:
-        new_shape = (*flatten_dims(input, 0, -1), 1)
-        out = impl.shuffle.reshape(
-            ctx, target, source_ir, f"{name}_flatten", input, new_shape
-        )
+        if is_dynamic_present and len(input.shape) != 1:
+            multiplier = get_trt_tensor(ctx, 1, name + "_shape")
+            for i in range(0, len(input.shape)):
+                if input.shape[i] != DYNAMIC_DIM:
+                    multiplier = convert_binary_elementwise(
+                        ctx,
+                        target,
+                        source_ir,
+                        name + f"_shape_{i}",
+                        trt.ElementWiseOperation.PROD,
+                        multiplier,
+                        input.shape[i],
+                    )
+                else:
+                    multiplier = convert_binary_elementwise(
+                        ctx,
+                        target,
+                        source_ir,
+                        name + f"_shape_{i}",
+                        trt.ElementWiseOperation.PROD,
+                        multiplier,
+                        get_shape(
+                            ctx,
+                            target,
+                            source_ir,
+                            name + f"_shape_dim_stop_{i}",
+                            input,
+                            i,
+                        ),
+                    )
+            # form shape tensor
+            new_shape_layer = ctx.net.add_concatenation(
+                [multiplier, get_trt_tensor(ctx, 1, name + "_one_shape")]
+            )
+            set_layer_name(
+                new_shape_layer, target, name + "_new_shape_concat", source_ir
+            )
+            concat_tensor = new_shape_layer.get_output(0)
+
+            reshape_dynamic_layer = ctx.net.add_shuffle(input)
+            reshape_dynamic_layer.set_input(1, concat_tensor)
+            set_layer_name(
+                reshape_dynamic_layer, target, name + "_reshape_layer", source_ir
+            )
+            out = reshape_dynamic_layer.get_output(0)
+
+        else:
+            new_shape = (*flatten_dims(input, 0, -1), 1)
+            out = impl.shuffle.reshape(
+                ctx, target, source_ir, f"{name}_flatten", input, new_shape
+            )
     elif len(input.shape) == 1:
         new_shape = (*input.shape, 1)
         out = impl.shuffle.reshape(
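
With dim=None, argmax/argmin reduce over the flattened input, and TensorRT's TopK layer needs an explicit axis, so the converter reshapes the input to (numel, 1). When any dimension is dynamic that target shape cannot be fixed at build time, so the new branch assembles it as a runtime shape tensor: static dimensions are multiplied in as constants, dynamic ones are read with a Shape layer, a trailing 1 is concatenated, and the result is fed to a Shuffle layer through set_input(1, ...). A minimal eager-mode sketch of the same computation follows; it uses plain PyTorch rather than TensorRT layers, and flatten_for_topk is an illustrative name.

# Eager-mode sketch of the dynamic-shape branch: accumulate the product of all
# dimension sizes, append a trailing 1, and reshape. In the converter the
# product is built from ElementWise PROD layers (constants for static dims,
# Shape-layer outputs for dynamic dims) and the reshape is a Shuffle layer
# whose target shape is supplied as a runtime tensor.
import torch

def flatten_for_topk(x: torch.Tensor) -> torch.Tensor:
    multiplier = torch.tensor(1, dtype=torch.int64)
    for i in range(x.dim()):
        multiplier = multiplier * x.shape[i]
    # trailing 1 gives TopK a reduction axis of length numel
    new_shape = torch.stack([multiplier, torch.tensor(1)])
    return x.reshape(*new_shape.tolist())

x = torch.randn(2, 3, 4)
flat = flatten_for_topk(x)
print(flat.shape)  # torch.Size([24, 1])
print(torch.argmax(x) == flat.argmax(dim=0).squeeze())  # tensor(True)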

tests/py/dynamo/conversion/harness.py

Lines changed: 9 additions & 0 deletions

@@ -382,6 +382,7 @@ def run_test_with_dynamic_shape(
         use_example_tensors=True,
         pyt_inputs=None,
         propagate_shapes=False,
+        check_dtype=True,
     ):
         mod = self.generate_graph(
             mod,
@@ -395,6 +396,14 @@ def run_test_with_dynamic_shape(
         # We replicate this behavior here
         compilation_settings = CompilationSettings(truncate_double=True)

+        if check_dtype:
+            output_dtypes = infer_module_output_dtypes(
+                mod,
+                input_specs,
+                compilation_settings.device,
+                truncate_double=compilation_settings.truncate_double,
+            )
+
         interp = TRTInterpreter(
             mod,
             input_specs,
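
The test harness gains a check_dtype switch for dynamic-shape runs: with the default True, the module's output dtypes are inferred via infer_module_output_dtypes (how they are consumed lies below the visible hunk); passing False skips that inference. A hedged usage sketch follows; the test class and module are illustrative and not part of this commit, and only the check_dtype argument comes from the diff above.

# Hedged sketch: a converter test that opts out of dtype checking for a
# dynamic-shape run. Everything except check_dtype=False mirrors the pattern
# used in the argmax/argmin tests below.
import torch
import torch.nn as nn
from torch_tensorrt import Input

from .harness import DispatchTestCase


class TestAbsDynamic(DispatchTestCase):
    def test_abs_dynamic_no_dtype_check(self):
        class Abs(nn.Module):
            def forward(self, input):
                return torch.ops.aten.abs.default(input)

        input_specs = [
            Input(min_shape=(1, 3), opt_shape=(4, 3), max_shape=(8, 3)),
        ]
        self.run_test_with_dynamic_shape(Abs(), input_specs, check_dtype=False)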

tests/py/dynamo/conversion/test_argmax_aten.py

Lines changed: 38 additions & 0 deletions

@@ -2,6 +2,7 @@
 import torch.nn as nn
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input

 from .harness import DispatchTestCase

@@ -36,6 +37,43 @@ def forward(self, input):

         self.run_test(ArgMax(), input)

+    @parameterized.expand(
+        [
+            # input dimension == 1
+            ("dim_1_keep_dim_true", (1,), (3,), (3,), 0, True),
+            ("dim_1_keep_dim_false", (1,), (3,), (3,), 0, False),
+            # dim == None
+            ("dim_1_none_true", (1,), (3,), (3,), None, True),
+            ("dim_2_none_true", (1, 3), (3, 3), (3, 3), None, True),
+            ("dim_3_none_false", (1, 3, 3), (3, 3, 3), (3, 3, 3), None, False),
+            # common cases
+            ("dim_1_keep_dim_true", (3, 1), (3, 3), (3, 3), 1, True),
+            ("dim_1_keep_dim_false", (3, 1), (3, 3), (3, 3), 1, False),
+            ("dim_0_keep_dim_true", (1, 4, 4), (4, 4, 4), (4, 4, 4), 0, True),
+            ("dim_0_keep_dim_false", (1, 4, 4), (4, 4, 4), (4, 4, 4), 0, False),
+            ("dim_negative_keep_dim_true", (1, 4, 4), (4, 4, 4), (4, 4, 4), -3, True),
+        ]
+    )
+    def test_argmax_dynamic(self, _, min_shape, opt_shape, max_shape, dim, keep_dim):
+        class ArgMax(nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, input):
+                return torch.ops.aten.argmax.default(input, dim, keep_dim)
+
+        input_specs = [
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            ArgMax(),
+            input_specs,
+        )
+

 if __name__ == "__main__":
     run_tests()
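
Taken together with the converter change, these tests cover what the commit enables end to end: an argmax (or argmin) over inputs whose dimensions vary within a min/opt/max range. A hedged sketch of the user-facing workflow follows; it assumes a CUDA device and a Torch-TensorRT build with the dynamo frontend, and the module and shape ranges are illustrative rather than taken from the tests above.

# Hedged end-to-end sketch: compile a module containing argmax with a dynamic
# batch dimension. Module and shape ranges are illustrative.
import torch
import torch_tensorrt


class ArgMaxModule(torch.nn.Module):
    def forward(self, x):
        return torch.argmax(x, dim=1, keepdim=True)


model = ArgMaxModule().cuda().eval()
inputs = [
    torch_tensorrt.Input(
        min_shape=(1, 8, 8),
        opt_shape=(4, 8, 8),
        max_shape=(8, 8, 8),
        dtype=torch.float32,
    )
]
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)

x = torch.randn(6, 8, 8, device="cuda")  # batch size anywhere in [1, 8]
out_trt = trt_model(x)
out_pt = torch.argmax(x, dim=1, keepdim=True)
# dtype/device of the TRT output may differ (e.g. int32 vs int64), so normalize
print(bool((out_trt.to(out_pt.device, out_pt.dtype) == out_pt).all()))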

tests/py/dynamo/conversion/test_argmin_aten.py

Lines changed: 37 additions & 0 deletions

@@ -36,6 +36,43 @@ def forward(self, input):

         self.run_test(ArgMin(), input)

+    @parameterized.expand(
+        [
+            # input dimension == 1
+            ("dim_1_keep_dim_true", (1,), (3,), (3,), 0, True),
+            ("dim_1_keep_dim_false", (1,), (3,), (3,), 0, False),
+            # dim == None
+            ("dim_1_none_true", (1,), (3,), (3,), None, True),
+            ("dim_2_none_true", (1, 3), (3, 3), (3, 3), None, True),
+            ("dim_3_none_false", (1, 3, 3), (3, 3, 3), (3, 3, 3), None, False),
+            # common cases
+            ("dim_1_keep_dim_true", (3, 1), (3, 3), (3, 3), 1, True),
+            ("dim_1_keep_dim_false", (3, 1), (3, 3), (3, 3), 1, False),
+            ("dim_0_keep_dim_true", (1, 4, 4), (4, 4, 4), (4, 4, 4), 0, True),
+            ("dim_0_keep_dim_false", (1, 4, 4), (4, 4, 4), (4, 4, 4), 0, False),
+            ("dim_negative_keep_dim_true", (1, 4, 4), (4, 4, 4), (4, 4, 4), -3, True),
+        ]
+    )
+    def test_argmin_dynamic(self, _, min_shape, opt_shape, max_shape, dim, keep_dim):
+        class ArgMin(nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, input):
+                return torch.ops.aten.argmin.default(input, dim, keep_dim)
+
+        input_specs = [
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            ArgMin(),
+            input_specs,
+        )
+

 if __name__ == "__main__":
     run_tests()
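
To run just the new dynamic-shape cases locally, the tests can be selected by name; a hedged sketch using pytest from Python, with paths assuming the repository root as the working directory:

# Hedged sketch: select only the *_dynamic tests added by this commit.
import pytest

pytest.main(
    [
        "tests/py/dynamo/conversion/test_argmax_aten.py",
        "tests/py/dynamo/conversion/test_argmin_aten.py",
        "-k", "dynamic",
        "-v",
    ]
)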
