diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index ec9d6b8c78..f6bafe24c1 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -622,8 +622,8 @@ def aten_ops_quantize_fp8(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.squeeze.dim)
-@dynamo_tensorrt_converter(torch.ops.aten.squeeze.dims)
+@dynamo_tensorrt_converter(torch.ops.aten.squeeze.dim, supports_dynamic_shapes=True)
+@dynamo_tensorrt_converter(torch.ops.aten.squeeze.dims, supports_dynamic_shapes=True)
 def aten_ops_squeeze(
     ctx: ConversionContext,
     target: Target,
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/squeeze.py b/py/torch_tensorrt/dynamo/conversion/impl/squeeze.py
index 45bdefcd80..dd6a2b9863 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/squeeze.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/squeeze.py
@@ -3,10 +3,11 @@
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
-from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim
-from torch_tensorrt.fx.converters.converter_utils import set_layer_name
+from torch_tensorrt.dynamo.conversion.converter_utils import (
+    get_positive_dim,
+    set_layer_name,
+)
 from torch_tensorrt.fx.types import TRTTensor
-from torch_tensorrt.fx.utils import get_dynamic_dims
 
 
 def squeeze(
@@ -25,8 +26,8 @@ def squeeze(
     if isinstance(dim, int):
         dims.append(dim)
     else:
-        for dim in dim:
-            dims.append(dim)
+        for d in dim:
+            dims.append(d)
 
     new_dims = []
     for dim in dims:
@@ -36,17 +37,22 @@ def squeeze(
         )
         assert input.shape[dim] != -1, "We don't support squeeze dynamic dim."
-        assert (
-            len(get_dynamic_dims(input.shape)) <= 1
-        ), "Currently more than one dynamic dim for input to squeeze is not supported."
         new_dims.append(dim)
 
-    output_shape = []
+    dim_to_remove = []
+    new_permutation = []
     for i, s in enumerate(input.shape):
         if (i in new_dims) and s == 1:
-            continue
-        output_shape.append(s)
+            dim_to_remove.append(i)
+        else:
+            new_permutation.append(i)
+    # If number of reshape dimensions is less than input, 0s are resolved by aligning
+    # the most significant dimensions of input
+    output_shape = tuple([0] * len(new_permutation))
+    new_permutation += dim_to_remove
+
     layer = ctx.net.add_shuffle(input)
-    layer.reshape_dims = tuple(output_shape)
+    layer.first_transpose = new_permutation
+    layer.reshape_dims = output_shape
     set_layer_name(layer, target, name, source_ir)
     return layer.get_output(0)
diff --git a/tests/py/dynamo/conversion/test_squeeze_aten.py b/tests/py/dynamo/conversion/test_squeeze_aten.py
index 88483072ae..9dc786525b 100644
--- a/tests/py/dynamo/conversion/test_squeeze_aten.py
+++ b/tests/py/dynamo/conversion/test_squeeze_aten.py
@@ -43,23 +43,50 @@ def forward(self, x):
         )
 
 
-class TestSqueezeConverter(DispatchTestCase):
+class TestSqueezeConverterDynamic(DispatchTestCase):
     @parameterized.expand(
         [
-            ("2d_dim", (1), (-1, 1), [((1, 1), (1, 1), (3, 1))]),
-            ("3d_one_dim", (1), (-1, 2, 1), [((1, 2, 1), (1, 2, 1), (3, 2, 1))]),
+            (
+                "5d_two_dynamic_shape_-1",
+                (0,),
+                (1, 1, 1, 1, 1),
+                (1, 2, 1, 2, 1),
+                (1, 4, 1, 3, 1),
+            ),
+            (
+                "5d_two_dynamic_shape_-2",
+                (0, 2),
+                (1, 1, 1, 1, 1),
+                (1, 2, 1, 2, 1),
+                (1, 4, 1, 3, 1),
+            ),
+            (
+                "5d_three_dynamic_shape_-2",
+                (0, 4),
+                (1, 1, 1, 1, 1),
+                (1, 2, 4, 2, 1),
+                (1, 4, 4, 3, 1),
+            ),
+            (
+                "4d_two_dynamic_shape_-2",
+                (0, 2),
+                (1, 1, 2, 1),
+                (1, 2, 2, 2),
+                (1, 4, 2, 3),
+            ),
         ]
     )
-    def test_squeeze(self, _, dim, init_size, shape_range):
+    def test_squeeze(self, _, dim, min_shape, opt_shape, max_shape):
         class Squeeze(nn.Module):
             def forward(self, x):
-                return torch.ops.aten.squeeze.dim(x, dim)
+                return torch.ops.aten.squeeze.dims(x, dim)
 
         input_specs = [
             Input(
-                shape=init_size,
-                dtype=torch.float32,
-                shape_ranges=shape_range,
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+                dtype=torch.float,
             ),
         ]
         self.run_test_with_dynamic_shape(
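
Note (reviewer sketch, not part of the patch): the rewritten converter avoids computing an explicit output shape, which is what previously forced the "at most one dynamic dim" restriction. It instead permutes the static size-1 dimensions being squeezed to the end of the tensor via the shuffle layer's first_transpose, then reshapes with zero-filled reshape_dims; TensorRT's IShuffleLayer resolves a 0 in reshape_dims by copying the corresponding (post-transpose) input dimension, so any number of the kept dimensions may be dynamic. The helper below, squeeze_plan, is a hypothetical pure-Python rendering of that plan, for illustration only.

# Standalone sketch of the permutation/reshape plan the converter builds.
# `squeeze_plan` is hypothetical; shapes are plain tuples here, with -1
# standing in for dynamic dimensions as in a TRTTensor shape.
def squeeze_plan(shape, dims):
    dims = [d % len(shape) for d in dims]  # normalize negative dims
    dim_to_remove = []
    new_permutation = []
    for i, s in enumerate(shape):
        if i in dims and s == 1:
            dim_to_remove.append(i)  # static size-1 dims to drop
        else:
            new_permutation.append(i)  # dims to keep, in original order
    # A 0 in reshape_dims tells TensorRT to copy the corresponding
    # (post-transpose) input dimension, so dynamic dims pass through.
    reshape_dims = tuple([0] * len(new_permutation))
    return tuple(new_permutation + dim_to_remove), reshape_dims

# Squeezing dims (0, 2) of a (1, -1, 1, -1, 1) tensor: the kept dims
# (1, 3, 4) move to the front, the dropped 1s trail and are cut off.
print(squeeze_plan((1, -1, 1, -1, 1), (0, 2)))
# -> ((1, 3, 4, 0, 2), (0, 0, 0))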