diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 721a0a546c..715001f56e 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -1779,6 +1779,7 @@ def aten_ops_add(
     )
 
 
+@dynamo_tensorrt_converter(operator.mul, supports_dynamic_shapes=True)
 @dynamo_tensorrt_converter(torch.ops.aten.mul.Tensor, supports_dynamic_shapes=True)
 @dynamo_tensorrt_converter(torch.ops.aten.mul.Scalar, supports_dynamic_shapes=True)
 def aten_ops_mul(
diff --git a/py/torch_tensorrt/dynamo/partitioning/common.py b/py/torch_tensorrt/dynamo/partitioning/common.py
index 9ac677484f..fdc55126ee 100644
--- a/py/torch_tensorrt/dynamo/partitioning/common.py
+++ b/py/torch_tensorrt/dynamo/partitioning/common.py
@@ -35,8 +35,16 @@ def construct_dynamic_input(
         node = dim.node
         expr = node.expr
         shape_env = node.shape_env
-        var_range = shape_env.var_to_range.get(expr, None)
-        var_val = shape_env.var_to_val.get(expr, None)
+        # An expr can be an independent SymInt node (eg: s0 or s1) or a composition of them (eg: 48*s0 or s0*s1).
+        # For an expr that involves symbolic computation, bound_sympy evaluates its value range.
+        # https://pytorch.org/docs/stable/generated/torch.fx.experimental.symbolic_shapes.ShapeEnv.html#torch.fx.experimental.symbolic_shapes.ShapeEnv.bound_sympy
+        # expr.xreplace replaces the symbolic variables with their current values and evaluates the expression.
+        var_range = shape_env.var_to_range.get(expr, None) or shape_env.bound_sympy(
+            expr
+        )
+        var_val = shape_env.var_to_val.get(expr, None) or expr.xreplace(
+            shape_env.var_to_val
+        )
         assert var_range, var_val
         # Torchdynamo 0/1 specialization outlier
         if var_range.lower == 2:
diff --git a/tests/py/dynamo/models/test_dyn_models.py b/tests/py/dynamo/models/test_dyn_models.py
index 4c6b98e555..3fd34de2ea 100644
--- a/tests/py/dynamo/models/test_dyn_models.py
+++ b/tests/py/dynamo/models/test_dyn_models.py
@@ -64,11 +64,6 @@ def forward(self, x):
         cos_sim > COSINE_THRESHOLD,
         msg=f"test_dyn_full_compile model TRT outputs don't match with the pytorch model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
     )
-    # Clean up model env
-    torch._dynamo.reset()
-
-    with torch.no_grad():
-        torch.cuda.empty_cache()
 
 
 @unittest.skip(
@@ -128,12 +123,6 @@ def forward(self, x):
         msg=f"test_base_dynamic_fallback model TRT outputs don't match with the pytorch model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
     )
 
-    # Clean up model env
-    torch._dynamo.reset()
-
-    with torch.no_grad():
-        torch.cuda.empty_cache()
-
 
 @pytest.mark.unit
 def test_view(ir):
@@ -185,12 +174,6 @@ def forward(self, x):
         msg=f"test_view model TRT outputs don't match with the pytorch model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
     )
 
-    # Clean up model env
-    torch._dynamo.reset()
-
-    with torch.no_grad():
-        torch.cuda.empty_cache()
-
 
 @pytest.mark.unit
 def test_resnet_dynamic(ir):
@@ -234,12 +217,6 @@ def test_resnet_dynamic(ir):
         msg=f"test_resnet_dynamic model TRT outputs don't match with the pytorch model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
     )
 
-    # Clean up model env
-    torch._dynamo.reset()
-
-    with torch.no_grad():
-        torch.cuda.empty_cache()
-
 
 @pytest.mark.unit
 def test_view(ir):
@@ -284,8 +261,52 @@ def forward(self, x):
         msg=f"test_base_dynamic model TRT outputs don't match with the pytorch model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
     )
-    # Clean up model env
-    torch._dynamo.reset()
 
-    with torch.no_grad():
-        torch.cuda.empty_cache()
+
+@pytest.mark.unit
+def test_linear(ir):
+    """
+    Tests a model with a linear op and operator.mul (inserted internally by
+    PyTorch) under dynamic shapes.
+    """
+
+    class MyModule(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.linear1 = torch.nn.Linear(10, 10)
+
+        def forward(self, x):
+            return self.linear1(x)
+
+    model = MyModule().eval().cuda()
+
+    compile_spec = {
+        "device": torchtrt.Device("cuda:0"),
+        "enabled_precisions": {torch.float},
+        "ir": ir,
+        "min_block_size": 1,
+    }
+    inputs_bs2 = torch.randn(2, 2, 10).to("cuda")
+    if ir == "torch_compile":
+        torch._dynamo.mark_dynamic(inputs_bs2, 0, min=1, max=10)
+        torch._dynamo.mark_dynamic(inputs_bs2, 1, min=1, max=10)
+        # Compile the model
+        trt_model = torch.compile(model, backend="tensorrt", options=compile_spec)
+        trt_model(inputs_bs2)
+    elif ir == "dynamo":
+        dynamic_shapes = (
+            {
+                0: torch.export.Dim("batch_size", min=1, max=10),
+                1: torch.export.Dim("seq_len", max=10),
+            },
+        )
+        exp_program = torch.export.export(
+            model, (inputs_bs2,), dynamic_shapes=dynamic_shapes
+        )
+        trt_model = torchtrt.dynamo.compile(exp_program, [inputs_bs2], **compile_spec)
+
+    input_bs6_s3 = torch.randn((6, 3, 10)).to("cuda")
+    cos_sim = cosine_similarity(model(input_bs6_s3), trt_model(input_bs6_s3))
+    assertions.assertTrue(
+        cos_sim > COSINE_THRESHOLD,
+        msg=f"test_linear model TRT outputs don't match with the pytorch model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+    )
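
Aside on the partitioning/common.py change above: the sketch below uses plain sympy to show why the lookup needs a fallback for composed symbolic dims. The symbol s0, the expression 48*s0, and the toy var_to_val dict are illustrative assumptions, not values taken from this PR; shape_env.bound_sympy plays the analogous role for var_to_range.

import sympy

# Toy stand-in for shape_env.var_to_val: one free symbol with a concrete
# example value (illustrative only, not taken from this PR).
s0 = sympy.Symbol("s0", integer=True, positive=True)
var_to_val = {s0: sympy.Integer(4)}

# A composed dim expression, e.g. what reshaping over a dynamic batch
# dimension can produce.
expr = 48 * s0

# The old lookup misses: var_to_val is keyed by free symbols, so a
# composite expression like 48*s0 is never one of its keys.
assert var_to_val.get(expr, None) is None

# The new fallback: xreplace substitutes the current symbol values into
# the expression, which then evaluates to the concrete dim value.
var_val = expr.xreplace(var_to_val)
assert var_val == 192  # 48 * 4

# shape_env.bound_sympy(expr) does the analogous interval propagation for
# var_to_range, pushing each symbol's [min, max] through the expression.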