dynamic shape argmax and argmin #3009

Merged 1 commit on Aug 8, 2024

4 changes: 2 additions & 2 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -2871,7 +2871,7 @@ def aten_ops_resize(


@enforce_tensor_types({0: (TRTTensor,)})
@dynamo_tensorrt_converter(torch.ops.aten.argmax.default)
@dynamo_tensorrt_converter(torch.ops.aten.argmax.default, supports_dynamic_shapes=True)
def aten_ops_argmax(
    ctx: ConversionContext,
    target: Target,
@@ -2891,7 +2891,7 @@ def aten_ops_argmax(


@enforce_tensor_types({0: (TRTTensor,)})
@dynamo_tensorrt_converter(torch.ops.aten.argmin.default)
@dynamo_tensorrt_converter(torch.ops.aten.argmin.default, supports_dynamic_shapes=True)
def aten_ops_argmin(
    ctx: ConversionContext,
    target: Target,
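Registering the argmax and argmin converters with `supports_dynamic_shapes=True` tells the dynamo partitioner that these converters can handle inputs whose dimensions are only known at runtime, so `aten.argmax`/`aten.argmin` nodes no longer need to fall back to PyTorch when the module is compiled with dynamic dims. Below is a minimal sketch of what this enables from the user side, assuming the standard dynamo frontend; the module and shapes are illustrative and not part of this PR.

```python
# Hypothetical end-to-end usage (not from this PR): compile an argmax module
# with a dynamic batch dimension. With supports_dynamic_shapes=True on the
# converter, the argmax node can be lowered into the TensorRT engine instead
# of falling back to eager PyTorch.
import torch
import torch_tensorrt


class ArgMaxModule(torch.nn.Module):
    def forward(self, x):
        return torch.argmax(x, dim=1, keepdim=True)


inputs = [
    torch_tensorrt.Input(
        min_shape=(1, 8),   # smallest batch the engine must support
        opt_shape=(4, 8),   # shape TensorRT optimizes for
        max_shape=(16, 8),  # largest batch the engine must support
        dtype=torch.float32,
    )
]

trt_mod = torch_tensorrt.compile(ArgMaxModule().eval().cuda(), ir="dynamo", inputs=inputs)
print(trt_mod(torch.randn(4, 8).cuda()))
```
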
62 changes: 57 additions & 5 deletions py/torch_tensorrt/dynamo/conversion/impl/topk.py
@@ -11,9 +11,13 @@
    get_axes_for_reduce_op,
    get_positive_dim,
    set_layer_name,
    get_trt_tensor,
    has_dynamic_shape,
)
from torch_tensorrt.dynamo.types import TRTTensor
from torch_tensorrt.dynamo.conversion.impl.elementwise import convert_binary_elementwise
from torch_tensorrt.dynamo.conversion.impl.shape import shape as get_shape
from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
from torch_tensorrt.dynamo.types import TRTTensor


def argmax_argmin(
@@ -34,12 +38,60 @@ def argmax_argmin(
    # 2. input rank == 1: TopK layer does not support 1 dimensional topk operation. Broadcast input to rank == 2
    # 3. normal cases, no additional handlings
    out = input
    is_dynamic_present = has_dynamic_shape(input.shape)

    if dim is None:
        new_shape = (*flatten_dims(input, 0, -1), 1)
        out = impl.shuffle.reshape(
            ctx, target, source_ir, f"{name}_flatten", input, new_shape
        )
        if is_dynamic_present and len(input.shape) != 1:
            multiplier = get_trt_tensor(ctx, 1, name + "_shape")
            for i in range(0, len(input.shape)):
                if input.shape[i] != DYNAMIC_DIM:
                    multiplier = convert_binary_elementwise(
                        ctx,
                        target,
                        source_ir,
                        name + f"_shape_{i}",
                        trt.ElementWiseOperation.PROD,
                        multiplier,
                        input.shape[i],
                    )
                else:
                    multiplier = convert_binary_elementwise(
                        ctx,
                        target,
                        source_ir,
                        name + f"_shape_{i}",
                        trt.ElementWiseOperation.PROD,
                        multiplier,
                        get_shape(
                            ctx,
                            target,
                            source_ir,
                            name + f"_shape_dim_stop_{i}",
                            input,
                            i,
                        ),
                    )
            # form shape tensor
            new_shape_layer = ctx.net.add_concatenation(
                [multiplier, get_trt_tensor(ctx, 1, name + "_one_shape")]
            )
            set_layer_name(
                new_shape_layer, target, name + "_new_shape_concat", source_ir
            )
            concat_tensor = new_shape_layer.get_output(0)

            reshape_dynamic_layer = ctx.net.add_shuffle(input)
            reshape_dynamic_layer.set_input(1, concat_tensor)
            set_layer_name(
                reshape_dynamic_layer, target, name + "_reshape_layer", source_ir
            )
            out = reshape_dynamic_layer.get_output(0)

        else:
            new_shape = (*flatten_dims(input, 0, -1), 1)
            out = impl.shuffle.reshape(
                ctx, target, source_ir, f"{name}_flatten", input, new_shape
            )
    elif len(input.shape) == 1:
        new_shape = (*input.shape, 1)
        out = impl.shuffle.reshape(
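The new branch handles `dim=None` (global argmax/argmin) when the input has dynamic dimensions: the flatten target `(numel, 1)` cannot be computed at conversion time, so the converter multiplies the per-dimension sizes into a runtime shape tensor (static dims are folded in as constants, dynamic dims are read with a shape layer), concatenates a trailing 1, and feeds the result to a shuffle layer via `set_input(1, ...)`. The sketch below shows the underlying equivalence in plain PyTorch; it is an illustration of the reshape trick, not code from this PR.

```python
# Illustration only (not from this PR): a global argmax over any tensor equals
# an argmax along dim 0 after reshaping to a (numel, 1) column, which is a
# rank-2 shape that TensorRT's TopK layer accepts. With dynamic dims, numel is
# computed at runtime, which is what the shape-tensor PROD chain above builds.
import torch

x = torch.randn(2, 3, 4)

numel = 1
for d in x.shape:           # the converter does this with ElementWiseOperation.PROD
    numel *= d

flat = x.reshape(numel, 1)  # the converter does this with a shuffle layer + set_input

assert torch.argmax(x) == torch.argmax(flat, dim=0).squeeze()
```
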
9 changes: 9 additions & 0 deletions tests/py/dynamo/conversion/harness.py
@@ -382,6 +382,7 @@ def run_test_with_dynamic_shape(
        use_example_tensors=True,
        pyt_inputs=None,
        propagate_shapes=False,
        check_dtype=True,
    ):
        mod = self.generate_graph(
            mod,
@@ -395,6 +396,14 @@
        # We replicate this behavior here
        compilation_settings = CompilationSettings(truncate_double=True)

        if check_dtype:
            output_dtypes = infer_module_output_dtypes(
                mod,
                input_specs,
                compilation_settings.device,
                truncate_double=compilation_settings.truncate_double,
            )

        interp = TRTInterpreter(
            mod,
            input_specs,
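The harness change matters for these converters because argmax and argmin always return int64 indices regardless of the input dtype; with `check_dtype=True`, the dynamic-shape path now infers the module's output dtypes (as the static-shape path already does) so the TensorRT output is compared against PyTorch with the correct type. A small reminder of that dtype behavior, for illustration only:

```python
# Illustration only: index-returning ops produce integer outputs even for
# floating-point inputs, so output dtypes must be inferred rather than assumed
# to match the input dtype.
import torch

x = torch.randn(4, 4)
assert torch.argmax(x, dim=1).dtype == torch.int64
assert torch.argmin(x, dim=1).dtype == torch.int64
```
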
38 changes: 38 additions & 0 deletions tests/py/dynamo/conversion/test_argmax_aten.py
@@ -2,6 +2,7 @@
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input

from .harness import DispatchTestCase

@@ -36,6 +37,43 @@ def forward(self, input):

        self.run_test(ArgMax(), input)

    @parameterized.expand(
        [
            # input dimension == 1
            ("dim_1_keep_dim_true", (1,), (3,), (3,), 0, True),
            ("dim_1_keep_dim_false", (1,), (3,), (3,), 0, False),
            # dim == None
            ("dim_1_none_true", (1,), (3,), (3,), None, True),
            ("dim_2_none_true", (1, 3), (3, 3), (3, 3), None, True),
            ("dim_3_none_false", (1, 3, 3), (3, 3, 3), (3, 3, 3), None, False),
            # common cases
            ("dim_1_keep_dim_true", (3, 1), (3, 3), (3, 3), 1, True),
            ("dim_1_keep_dim_false", (3, 1), (3, 3), (3, 3), 1, False),
            ("dim_0_keep_dim_true", (1, 4, 4), (4, 4, 4), (4, 4, 4), 0, True),
            ("dim_0_keep_dim_false", (1, 4, 4), (4, 4, 4), (4, 4, 4), 0, False),
            ("dim_negative_keep_dim_true", (1, 4, 4), (4, 4, 4), (4, 4, 4), -3, True),
        ]
    )
    def test_argmax_dynamic(self, _, min_shape, opt_shape, max_shape, dim, keep_dim):
        class ArgMax(nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, input):
                return torch.ops.aten.argmax.default(input, dim, keep_dim)

        input_specs = [
            Input(
                min_shape=min_shape,
                opt_shape=opt_shape,
                max_shape=max_shape,
            ),
        ]
        self.run_test_with_dynamic_shape(
            ArgMax(),
            input_specs,
        )


if __name__ == "__main__":
    run_tests()
37 changes: 37 additions & 0 deletions tests/py/dynamo/conversion/test_argmin_aten.py
@@ -36,6 +36,43 @@ def forward(self, input):

        self.run_test(ArgMin(), input)

    @parameterized.expand(
        [
            # input dimension == 1
            ("dim_1_keep_dim_true", (1,), (3,), (3,), 0, True),
            ("dim_1_keep_dim_false", (1,), (3,), (3,), 0, False),
            # dim == None
            ("dim_1_none_true", (1,), (3,), (3,), None, True),
            ("dim_2_none_true", (1, 3), (3, 3), (3, 3), None, True),
            ("dim_3_none_false", (1, 3, 3), (3, 3, 3), (3, 3, 3), None, False),
            # common cases
            ("dim_1_keep_dim_true", (3, 1), (3, 3), (3, 3), 1, True),
            ("dim_1_keep_dim_false", (3, 1), (3, 3), (3, 3), 1, False),
            ("dim_0_keep_dim_true", (1, 4, 4), (4, 4, 4), (4, 4, 4), 0, True),
            ("dim_0_keep_dim_false", (1, 4, 4), (4, 4, 4), (4, 4, 4), 0, False),
            ("dim_negative_keep_dim_true", (1, 4, 4), (4, 4, 4), (4, 4, 4), -3, True),
        ]
    )
    def test_argmin_dynamic(self, _, min_shape, opt_shape, max_shape, dim, keep_dim):
        class ArgMin(nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, input):
                return torch.ops.aten.argmin.default(input, dim, keep_dim)

        input_specs = [
            Input(
                min_shape=min_shape,
                opt_shape=opt_shape,
                max_shape=max_shape,
            ),
        ]
        self.run_test_with_dynamic_shape(
            ArgMin(),
            input_specs,
        )


if __name__ == "__main__":
    run_tests()