diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 3da1b09fba..513903b1f2 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -2502,8 +2502,8 @@ def aten_ops_convolution(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.linear.default)
-@dynamo_tensorrt_converter(torch.ops.aten.linear)
+@dynamo_tensorrt_converter(torch.ops.aten.linear.default, supports_dynamic_shapes=True)
+@dynamo_tensorrt_converter(torch.ops.aten.linear, supports_dynamic_shapes=True)
 def aten_ops_linear(
     ctx: ConversionContext,
     target: Target,
diff --git a/tests/py/dynamo/conversion/test_linear_aten.py b/tests/py/dynamo/conversion/test_linear_aten.py
index 615f40fb2f..f53eb98f33 100644
--- a/tests/py/dynamo/conversion/test_linear_aten.py
+++ b/tests/py/dynamo/conversion/test_linear_aten.py
@@ -1,6 +1,8 @@
 import torch
+import torch.nn as nn
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input
 
 from .harness import DispatchTestCase
 
@@ -8,25 +10,35 @@ class TestLinearConverter(DispatchTestCase):
     @parameterized.expand(
         [
-            ("default", [1, 512], True, torch.ops.aten.linear.default),
-            ("matrix", [5, 512], True, torch.ops.aten.linear.default),
-            ("no_bias", [1, 512], False, torch.ops.aten.linear.default),
+            (
+                "default",
+                [1, 512],
+                True,
+            ),
+            (
+                "matrix",
+                [5, 512],
+                True,
+            ),
+            (
+                "no_bias",
+                [1, 512],
+                False,
+            ),
             (
                 "multi_dim_matrix",
                 [4, 5, 512],
                 True,
-                torch.ops.aten.linear.default,
             ),
             (
                 "multi_dim_matrix",
                 [4, 5, 512],
                 False,
-                torch.ops.aten.linear.default,
             ),
         ]
     )
-    def test_linear(self, test_name, shape, bias, op):
-        class TestModule(torch.nn.Module):
+    def test_linear(self, test_name, shape, bias):
+        class linear(nn.Module):
             def __init__(self):
                 super().__init__()
                 self.weight = torch.randn((256, 512))
@@ -39,37 +51,80 @@ def forward(self, x):
                 return torch.ops.aten.linear.default(x, self.weight, self.bias)
 
         inputs = [torch.randn(shape)]
-        self.run_test(TestModule(), inputs)
+        self.run_test(linear(), inputs)
 
     # linear will be decomposed to P531484488 and view(reshape) can not handle reshape pattern
     # like (2, 3, n)->(6, n) in implicit mode which is similar to dynamic shape test below.
-    # Input is transposed through view [3,3,512]->[9,512]. Converter does not know dim=0 is dynamic now.
-
-    # def test_linear_with_dynamic_shape(self):
-    #     class TestModule(torch.nn.Module):
-    #         def __init__(self):
-    #             super().__init__()
-    #             self.linear = torch.nn.Linear(512, 256)
+    # # Input is transposed through view [3,3,512]->[9,512]. Converter does not know dim=0 is dynamic now.
+    @parameterized.expand(
+        [
+            (
+                "2d_dim",
+                (1, 512),
+                (2, 512),
+                (3, 512),
+                torch.float32,
+                (256, 512),
+                None,
+            ),
+            (
+                "3d_one_dynamic_dim",
+                (1, 3, 512),
+                (2, 3, 512),
+                (3, 3, 512),
+                torch.float32,
+                (256, 512),
+                (256,),
+            ),
+            (
+                "3d_two_dynamic_dim_bias",
+                (1, 1, 512),
+                (2, 2, 512),
+                (3, 3, 512),
+                torch.float32,
+                (256, 512),
+                (256,),
+            ),
+            (
+                "3d_two_dynamic_dim_no_bias",
+                (1, 1, 512),
+                (2, 2, 512),
+                (3, 3, 512),
+                torch.float32,
+                (256, 512),
+                None,
+            ),
+        ]
+    )
+    def test_linear_with_dynamic_shape(
+        self, _, min_shape, opt_shape, max_shape, type, weight_shape, bias_shape
+    ):
+        class linear(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.weight = torch.rand(weight_shape)
 
-    #         def forward(self, x):
-    #             return self.linear(x)
+                if bias_shape:
+                    self.bias = torch.randn(bias_shape)
+                else:
+                    self.bias = None
 
-    #     input_specs = [
-    #         Input(
-    #             shape=(-1, 3, 512),
-    #             dtype=torch.float32,
-    #             shape_ranges=[((1, 3, 512), (3, 3, 512), (4, 3, 512))],
-    #         ),
-    #     ]
-    #     self.run_test_with_dynamic_shape(
-    #         TestModule(),
-    #         input_specs,
-    #         expected_ops={torch.ops.aten.addmm.default},
-    #     )
+            def forward(self, x):
+                return torch.ops.aten.linear.default(x, self.weight, self.bias)
 
-    ## Testing with (-1, -1, 512) results into following error:
-    ## AssertionError: Currently we only support one dynamic dim for linear and it can't be the last dim.
+        input_specs = [
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+                dtype=type,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            linear(),
+            input_specs,
+        )
 
 
 if __name__ == "__main__":
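Not part of the diff above, but for reviewers who want to see the user-facing effect of registering the `aten.linear` converters with `supports_dynamic_shapes=True`: the sketch below compiles a plain `nn.Linear` with a dynamic batch dimension through the dynamo front end. It assumes the public `torch_tensorrt.Input(min_shape/opt_shape/max_shape)` and `torch_tensorrt.compile(..., ir="dynamo")` API and a CUDA-capable machine; whether `nn.Linear` actually reaches `aten.linear.default` (rather than a decomposed `addmm`) depends on the lowering passes in the installed build.

```python
# Illustrative only -- not part of this PR. Compiles a dynamic-batch nn.Linear
# through the dynamo path, which is what the converter change above enables.
import torch
import torch_tensorrt

model = torch.nn.Linear(512, 256).eval().cuda()

# Batch dimension may vary between 1 and 8 at runtime; the engine is tuned for 4.
dyn_input = torch_tensorrt.Input(
    min_shape=(1, 512),
    opt_shape=(4, 512),
    max_shape=(8, 512),
    dtype=torch.float32,
)

trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=[dyn_input])

# Any batch size inside the declared range runs through the same engine.
print(trt_model(torch.randn(2, 512).cuda()).shape)  # torch.Size([2, 256])
print(trt_model(torch.randn(7, 512).cuda()).shape)  # torch.Size([7, 256])
```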