Commit cb6ef31

dynamic shape argmax and argmin
1 parent 0d4af77 commit cb6ef31

5 files changed, +147 −7 lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 2 additions & 2 deletions

@@ -2824,7 +2824,7 @@ def aten_ops_resize(
 
 
 @enforce_tensor_types({0: (TRTTensor,)})
-@dynamo_tensorrt_converter(torch.ops.aten.argmax.default)
+@dynamo_tensorrt_converter(torch.ops.aten.argmax.default, supports_dynamic_shapes=True)
 def aten_ops_argmax(
     ctx: ConversionContext,
     target: Target,
@@ -2844,7 +2844,7 @@ def aten_ops_argmax(
 
 
 @enforce_tensor_types({0: (TRTTensor,)})
-@dynamo_tensorrt_converter(torch.ops.aten.argmin.default)
+@dynamo_tensorrt_converter(torch.ops.aten.argmin.default, supports_dynamic_shapes=True)
 def aten_ops_argmin(
     ctx: ConversionContext,
     target: Target,
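Registering the converters with supports_dynamic_shapes=True lets the dynamo partitioner keep aten.argmax/aten.argmin inside the TensorRT subgraph when inputs carry dynamic dimensions. A minimal sketch of exercising this end to end, assuming a CUDA-capable environment with torch_tensorrt installed (the module and shape range below are illustrative, not from this commit):

import torch
import torch_tensorrt

class ArgMax(torch.nn.Module):
    def forward(self, x):
        return torch.argmax(x, dim=1, keepdim=True)

# One engine is built over a shape range; any batch size in [1, 16]
# then runs without recompilation.
trt_mod = torch_tensorrt.compile(
    ArgMax().eval().cuda(),
    ir="dynamo",
    inputs=[
        torch_tensorrt.Input(
            min_shape=(1, 4), opt_shape=(8, 4), max_shape=(16, 4), dtype=torch.float32
        )
    ],
)
print(trt_mod(torch.randn(5, 4).cuda()))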

py/torch_tensorrt/dynamo/conversion/impl/topk.py

Lines changed: 60 additions & 5 deletions

@@ -10,8 +10,15 @@
     flatten_dims,
     get_axes_for_reduce_op,
     get_positive_dim,
+    get_trt_tensor,
+)
+from torch_tensorrt.dynamo.conversion.impl.elementwise import convert_binary_elementwise
+from torch_tensorrt.dynamo.conversion.impl.shape import shape as get_shape
+from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
+from torch_tensorrt.fx.converters.converter_utils import (
+    has_dynamic_shape,
+    set_layer_name,
 )
-from torch_tensorrt.fx.converters.converter_utils import set_layer_name
 from torch_tensorrt.fx.types import TRTTensor
 
 
@@ -33,12 +40,60 @@ def argmax_argmin(
     # 2. input rank == 1: TopK layer does not support 1 dimensional topk operation. Broadcast input to rank == 2
     # 3. normal cases, no additional handlings
     out = input
+    is_dynamic_present = has_dynamic_shape(input.shape)
 
     if dim is None:
-        new_shape = (*flatten_dims(input, 0, -1), 1)
-        out = impl.shuffle.reshape(
-            ctx, target, source_ir, f"{name}_flatten", input, new_shape
-        )
+        if is_dynamic_present and len(input.shape) != 1:
+            multiplier = get_trt_tensor(ctx, 1, name + "_shape")
+            for i in range(0, len(input.shape)):
+                if input.shape[i] != DYNAMIC_DIM:
+                    multiplier = convert_binary_elementwise(
+                        ctx,
+                        target,
+                        source_ir,
+                        name + f"_shape_{i}",
+                        trt.ElementWiseOperation.PROD,
+                        multiplier,
+                        input.shape[i],
+                    )
+                else:
+                    multiplier = convert_binary_elementwise(
+                        ctx,
+                        target,
+                        source_ir,
+                        name + f"_shape_{i}",
+                        trt.ElementWiseOperation.PROD,
+                        multiplier,
+                        get_shape(
+                            ctx,
+                            target,
+                            source_ir,
+                            name + f"_shape_dim_stop_{i}",
+                            input,
+                            i,
+                        ),
+                    )
+            # form shape tensor
+            new_shape_layer = ctx.net.add_concatenation(
+                [multiplier, get_trt_tensor(ctx, 1, name + "_one_shape" + "shape0")]
+            )
+            set_layer_name(
+                new_shape_layer, target, name + "_new_shape_concat", source_ir
+            )
+            concat_tensor = new_shape_layer.get_output(0)
+
+            reshape_dynamic_layer = ctx.net.add_shuffle(input)
+            reshape_dynamic_layer.set_input(1, concat_tensor)
+            set_layer_name(
+                reshape_dynamic_layer, target, name + "_reshape_layer", source_ir
+            )
+            out = reshape_dynamic_layer.get_output(0)
+
+        else:
+            new_shape = (*flatten_dims(input, 0, -1), 1)
+            out = impl.shuffle.reshape(
+                ctx, target, source_ir, f"{name}_flatten", input, new_shape
+            )
     elif len(input.shape) == 1:
         new_shape = (*input.shape, 1)
         out = impl.shuffle.reshape(
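This branch is the core of the change: with a dynamic dimension and dim=None, the element count cannot be folded into a static reshape, so the converter builds the flattened shape at runtime — multiplying static sizes in as constants, pulling dynamic ones from shape-layer outputs, concatenating a trailing 1, and feeding the result to a shuffle layer as a shape tensor. In eager terms the branch computes the following (a minimal PyTorch sketch of the intent, not the TensorRT layer construction itself):

import torch

def flatten_for_topk(x: torch.Tensor) -> torch.Tensor:
    # Running product over all dimensions, mirroring how the converter
    # multiplies static sizes with runtime shape-layer outputs.
    n = 1
    for i in range(x.dim()):
        n = n * x.size(i)
    # The converter concatenates [n, 1] into a shape tensor and sets it as
    # the shuffle layer's second input; eagerly that is just this reshape,
    # which yields the rank-2 input that the TopK layer requires.
    return x.reshape(n, 1)

assert flatten_for_topk(torch.randn(2, 3, 4)).shape == (24, 1)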

tests/py/dynamo/conversion/harness.py

Lines changed: 9 additions & 0 deletions

@@ -375,6 +375,7 @@ def run_test_with_dynamic_shape(
         use_example_tensors=True,
         pyt_inputs=None,
         propagate_shapes=False,
+        check_dtype=True,
     ):
         mod = self.generate_graph(
             mod,
@@ -388,6 +389,14 @@
         # We replicate this behavior here
         compilation_settings = CompilationSettings(truncate_double=True)
 
+        if check_dtype:
+            output_dtypes = infer_module_output_dtypes(
+                mod,
+                input_specs,
+                compilation_settings.device,
+                truncate_double=compilation_settings.truncate_double,
+            )
+
         interp = TRTInterpreter(
             mod,
             input_specs,

tests/py/dynamo/conversion/test_argmax_aten.py

Lines changed: 39 additions & 0 deletions

@@ -2,6 +2,7 @@
 import torch.nn as nn
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input
 
 from .harness import DispatchTestCase
 
@@ -36,6 +37,44 @@ def forward(self, input):
 
         self.run_test(ArgMax(), input)
 
+    @parameterized.expand(
+        [
+            # input dimension == 1
+            ("dim_1_keep_dim_true", (1,), (3,), (3,), 0, True),
+            ("dim_1_keep_dim_false", (1,), (3,), (3,), 0, False),
+            # dim == None
+            # ("dim_1_none_true", (1,), (3,), (3,), None, True),
+            ("dim_2_none_true", (1, 3), (3, 3), (3, 3), None, True),
+            ("dim_3_none_false", (1, 3, 3), (3, 3, 3), (3, 3, 3), None, False),
+            # common cases
+            ("dim_1_keep_dim_true", (3, 1), (3, 3), (3, 3), 1, True),
+            ("dim_1_keep_dim_false", (3, 1), (3, 3), (3, 3), 1, False),
+            ("dim_0_keep_dim_true", (1, 4, 4), (4, 4, 4), (4, 4, 4), 0, True),
+            ("dim_0_keep_dim_false", (1, 4, 4), (4, 4, 4), (4, 4, 4), 0, False),
+            ("dim_negative_keep_dim_true", (1, 4, 4), (4, 4, 4), (4, 4, 4), -3, True),
+        ]
+    )
+    def test_argmax_dynamic(self, _, min_shape, opt_shape, max_shape, dim, keep_dim):
+        class ArgMax(nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, input):
+                return torch.ops.aten.argmax.default(input, dim, keep_dim)
+
+        input_specs = [
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            ArgMax(),
+            input_specs,
+        )
+
 
 if __name__ == "__main__":
     run_tests()

tests/py/dynamo/conversion/test_argmin_aten.py

Lines changed: 37 additions & 0 deletions

@@ -36,6 +36,43 @@ def forward(self, input):
 
         self.run_test(ArgMin(), input)
 
+    @parameterized.expand(
+        [
+            # input dimension == 1
+            ("dim_1_keep_dim_true", (1,), (3,), (3,), 0, True),
+            ("dim_1_keep_dim_false", (1,), (3,), (3,), 0, False),
+            # dim == None
+            # ("dim_1_none_true", (1,), (3,), (3,), None, True),
+            ("dim_2_none_true", (1, 3), (3, 3), (3, 3), None, True),
+            ("dim_3_none_false", (1, 3, 3), (3, 3, 3), (3, 3, 3), None, False),
+            # common cases
+            ("dim_1_keep_dim_true", (3, 1), (3, 3), (3, 3), 1, True),
+            ("dim_1_keep_dim_false", (3, 1), (3, 3), (3, 3), 1, False),
+            ("dim_0_keep_dim_true", (1, 4, 4), (4, 4, 4), (4, 4, 4), 0, True),
+            ("dim_0_keep_dim_false", (1, 4, 4), (4, 4, 4), (4, 4, 4), 0, False),
+            ("dim_negative_keep_dim_true", (1, 4, 4), (4, 4, 4), (4, 4, 4), -3, True),
+        ]
+    )
+    def test_argmin_dynamic(self, _, min_shape, opt_shape, max_shape, dim, keep_dim):
+        class ArgMin(nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, input):
+                return torch.ops.aten.argmin.default(input, dim, keep_dim)
+
+        input_specs = [
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            ArgMin(),
+            input_specs,
+        )
+
 
 if __name__ == "__main__":
     run_tests()
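For reference, the eager semantics these dynamic-shape tests are checked against: with dim=None the input is flattened before the reduction (which is why the converter reshapes to (N, 1) before the TopK layer), and per the PyTorch docs keepdim takes effect only when an explicit dim is given. A quick eager sanity check of the aten ops the test modules call:

import torch

x = torch.tensor([[4.0, 1.0], [0.5, 9.0]])

# dim=None flattens the input first: the flat view is [4.0, 1.0, 0.5, 9.0].
print(torch.ops.aten.argmin.default(x))  # tensor(2) -> 0.5 sits at flat index 2
print(torch.ops.aten.argmax.default(x))  # tensor(3) -> 9.0 sits at flat index 3
# With an explicit dim, keepdim=True preserves the reduced axis as size 1.
print(torch.ops.aten.argmax.default(x, 1, True))  # tensor([[0], [1]])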
