diff --git a/docs/_sources/tutorials/ptq.rst.txt b/docs/_sources/tutorials/ptq.rst.txt index b62457109f..75b27c8409 100644 --- a/docs/_sources/tutorials/ptq.rst.txt +++ b/docs/_sources/tutorials/ptq.rst.txt @@ -136,7 +136,7 @@ Then all thats required to setup the module for INT8 calibration is to set the f If you have an existing Calibrator implementation for TensorRT you may directly set the ``ptq_calibrator`` field with a pointer to your calibrator and it will work as well. From here not much changes in terms of how to execution works. You are still able to fully use LibTorch as the sole interface for inference. Data should remain in FP32 precision when it's passed into `trt_mod.forward`. There exists an example application in the Torch-TensorRT demo that takes you from training a VGG16 network on -CIFAR10 to deploying in INT8 with Torch-TensorRT here: https://github.com/pytorch/TensorRT/tree/master/cpp/ptq +CIFAR10 to deploying in INT8 with Torch-TensorRT here: https://github.com/pytorch/TensorRT/tree/master/examples/int8/ptq .. _writing_ptq_python: @@ -194,8 +194,8 @@ to use ``CacheCalibrator`` to use in INT8 mode. calibrator=calibrator) If you already have an existing calibrator class (implemented directly using TensorRT API), you can directly set the calibrator field to your class which can be very convenient. -For a demo on how PTQ can be performed on a VGG network using Torch-TensorRT API, you can refer to https://github.com/pytorch/TensorRT/blob/master/tests/py/test_ptq_dataloader_calibrator.py -and https://github.com/pytorch/TensorRT/blob/master/tests/py/test_ptq_trt_calibrator.py +For a demo on how PTQ can be performed on a VGG network using Torch-TensorRT API, you can refer to https://github.com/pytorch/TensorRT/blob/master/tests/py/ptq/test_ptq_dataloader_calibrator.py +and https://github.com/pytorch/TensorRT/blob/master/tests/py/ptq/test_ptq_trt_calibrator.py Citations ^^^^^^^^^^^ diff --git a/examples/fx/hugging_face_torchdynamo_example.py b/examples/fx/hugging_face_torchdynamo_example.py index 3d4d91d3f8..388ccf2e47 100644 --- a/examples/fx/hugging_face_torchdynamo_example.py +++ b/examples/fx/hugging_face_torchdynamo_example.py @@ -15,12 +15,12 @@ ) from transformers import BertConfig, ReformerConfig, XLNetModel, XLNetConfig -import torchdynamo -from torchdynamo.optimizations import backends -from torchdynamo.optimizations.training import aot_autograd_debug_strategy1 -from torchdynamo.optimizations.training import aot_autograd_speedup_strategy -from torchdynamo.testing import collect_results -from torchdynamo.testing import same +import torch._dynamo as torchdynamo +from torch._dynamo.optimizations import backends +from torch._dynamo.optimizations.training import aot_autograd_debug_strategy1 +from torch._dynamo.optimizations.training import aot_autograd_speedup_strategy +from torch._dynamo.testing import collect_results +from torch._dynamo.testing import same torch.backends.cuda.matmul.allow_tf32 = True diff --git a/examples/fx/torchdynamo_example.py b/examples/fx/torchdynamo_example.py index a2e7627800..0d640de68c 100644 --- a/examples/fx/torchdynamo_example.py +++ b/examples/fx/torchdynamo_example.py @@ -3,11 +3,11 @@ from dataclasses import dataclass, field, replace import torch -import torchdynamo +import torch._dynamo as torchdynamo import torchvision from torch_tensorrt.fx.lower import compile from torch_tensorrt.fx.utils import LowerPrecision -from torchdynamo.optimizations import backends +from torch._dynamo.optimizations import backends """ The purpose of this example is to demostrate the lowering flow to TRT and Torchdynamo diff --git a/py/torch_tensorrt/fx/converters/acc_ops_converters.py b/py/torch_tensorrt/fx/converters/acc_ops_converters.py index f5278b1d07..f4a0b49a93 100644 --- a/py/torch_tensorrt/fx/converters/acc_ops_converters.py +++ b/py/torch_tensorrt/fx/converters/acc_ops_converters.py @@ -25,6 +25,7 @@ trt_transposed_linear, trt_transposed_matmul, ) +from torch_tensorrt.fx.tracer.acc_tracer.acc_ops import contiguous _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -3371,6 +3372,9 @@ def acc_ops_gelu( name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: input_val = kwargs["input"] + approximate = kwargs["approximate"] + if approximate is not "none": + raise RuntimeError("GeLU converter currently doesn't support fast gelu compute") if not isinstance(input_val, TRTTensor): raise RuntimeError( f"GELU received input {input_val} that is not part " diff --git a/py/torch_tensorrt/fx/lower.py b/py/torch_tensorrt/fx/lower.py index 3d3eb1ae92..61bd232421 100644 --- a/py/torch_tensorrt/fx/lower.py +++ b/py/torch_tensorrt/fx/lower.py @@ -31,6 +31,7 @@ def compile( module: nn.Module, input, + min_acc_module_size: int = 10, max_batch_size: int = 2048, max_workspace_size=1 << 25, explicit_batch_dimension=False, @@ -51,6 +52,7 @@ def compile( module: Original module for lowering. input: Input for module. max_batch_size: Maximum batch size (must be >= 1 to be set, 0 means not set) + min_acc_module_size: Minimal number of nodes for an accelerated submodule max_workspace_size: Maximum size of workspace given to TensorRT. explicit_batch_dimension: Use explicit batch dimension in TensorRT if set True, otherwise use implicit batch dimension. lower_precision: lower_precision config given to TRTModule. @@ -70,6 +72,7 @@ def compile( lower_setting = LowerSetting( max_batch_size=max_batch_size, + min_acc_module_size=min_acc_module_size, max_workspace_size=max_workspace_size, explicit_batch_dimension=explicit_batch_dimension, lower_precision=lower_precision, @@ -268,6 +271,7 @@ def __call__( module: nn.Module, inputs: Input, additional_inputs: Optional[Input] = None, + fp16_conversion_fn: Optional[Callable[[Input], Input]] = None, ) -> nn.Module: lower_setting = self.lower_pass_manager_builder.lower_setting atol = lower_setting.correctness_atol @@ -284,9 +288,26 @@ def do_lower(module: nn.Module, inputs: Input) -> nn.Module: == LowerPrecision.FP16 ): module.half() - inputs = tuple( - x.half() if x is not None and x.dtype == torch.float32 else x - for x in inputs + # A custom conversion function can be passed to the lowerer to + # handle inputs with custom types. By default, just handle + # tensors and NoneType. + if fp16_conversion_fn is None: + conversion_fn = ( + lambda x: x.half() + if x is not None and x.dtype == torch.float32 + else x + ) + else: + conversion_fn = fp16_conversion_fn + + inputs = tuple(conversion_fn(x) for x in inputs) + if lower_setting.is_aten: + pm = self.lower_pass_manager_builder.build_aten2trt_lower_pipeline( + inputs, additional_inputs + ) + else: + pm = self.lower_pass_manager_builder.build_trt_lower_pipeline( + inputs, additional_inputs ) if lower_setting.is_aten: pm = self.lower_pass_manager_builder.build_aten2trt_lower_pipeline( diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_cat.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_cat.py index 97360c75e5..9003bbf98a 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_cat.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_cat.py @@ -1,32 +1,51 @@ import torch import torch.nn as nn import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops +from parameterized import param, parameterized from torch.testing._internal.common_utils import run_tests from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec class TestCatConverter(AccTestCase): - def test_cat(self): + @parameterized.expand( + [ + param("cat", torch.cat), + param("concat", torch.concat), + ] + ) + def test_cat(self, _, op): class Cat(nn.Module): def forward(self, x, y, z): - return torch.cat((x, y, z), 1) + return op((x, y, z), 1) inputs = [torch.randn(1, 2, 3), torch.randn(1, 1, 3), torch.randn(1, 3, 3)] self.run_test(Cat(), inputs, expected_ops={acc_ops.cat}) - def test_cat_neg(self): + @parameterized.expand( + [ + param("cat", torch.cat), + param("concat", torch.concat), + ] + ) + def test_cat_neg(self, _, op): class Cat(nn.Module): def forward(self, x, y, z): - return torch.cat((x, y, z), -1) + return op((x, y, z), -1) inputs = [torch.randn(1, 2, 3), torch.randn(1, 2, 3), torch.randn(1, 2, 2)] self.run_test(Cat(), inputs, expected_ops={acc_ops.cat}) - def test_cat_with_dynamic_shape(self): + @parameterized.expand( + [ + param("cat", torch.cat), + param("concat", torch.concat), + ] + ) + def test_cat_with_dynamic_shape(self, _, op): class Cat(nn.Module): def forward(self, x, y): x = x + y - return torch.cat((x, y), 0) + return op((x, y), 0) input_specs = [ InputTensorSpec( @@ -42,11 +61,17 @@ def forward(self, x, y): ] self.run_test_with_dynamic_shape(Cat(), input_specs, expected_ops={acc_ops.cat}) - def test_cat_with_dynamic_shape_four_dimensions(self): + @parameterized.expand( + [ + param("cat", torch.cat), + param("concat", torch.concat), + ] + ) + def test_cat_with_dynamic_shape_four_dimensions(self, _, op): class Cat(nn.Module): def forward(self, x, y): x = x + y - return torch.cat((x, y), 0) + return op((x, y), 0) input_specs = [ InputTensorSpec( @@ -63,6 +88,14 @@ def forward(self, x, y): self.run_test_with_dynamic_shape(Cat(), input_specs, expected_ops={acc_ops.cat}) + def test_concat(self): + class Cat(nn.Module): + def forward(self, x, y, z): + return torch.concat((x, y, z), 1) + + inputs = [torch.randn(1, 2, 3), torch.randn(1, 1, 3), torch.randn(1, 3, 3)] + self.run_test(Cat(), inputs, expected_ops={acc_ops.cat}) + if __name__ == "__main__": run_tests() diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_gelu.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_gelu.py index c33088a498..28ce1551fd 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_gelu.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_gelu.py @@ -57,6 +57,39 @@ def forward(self, x): TestModule(), input_specs, expected_ops={acc_ops.gelu} ) + def test_gelu_module(self): + class TestModule(nn.Module): + def __init__(self): + super().__init__() + self.gelu = torch.nn.GELU() + + def forward(self, x): + return self.gelu(x) + + inputs = [torch.randn(3, 10, 20)] + self.run_test( + TestModule(), + inputs, + expected_ops={acc_ops.gelu}, + test_implicit_batch_dim=False, + ) + + def test_gelu_module_throw(self): + class TestModule(nn.Module): + def __init__(self): + super().__init__() + self.gelu = torch.nn.GELU(approximate="tanh") + + def forward(self, x): + return self.gelu(x) + + inputs = [torch.randn(3, 10, 20)] + self.run_test_with_assert_error( + TestModule(), + inputs, + expect_error=RuntimeError, + ) + if __name__ == "__main__": run_tests() diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_new_ones.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_new_ones.py index 206d088a55..bfd93d2870 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_new_ones.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_new_ones.py @@ -6,19 +6,6 @@ class TestNewOnesConverter(AccTestCase): - def test_newone(self): - class TestModule(nn.Module): - def forward(self, x): - return x.new_ones((3, 5), dtype=torch.float16) - - inputs = [torch.randn(1, 10)] - self.run_test( - TestModule(), - inputs, - expected_ops={acc_ops.new_ones}, - test_implicit_batch_dim=False, - ) - def test_newone_no_dtype(self): class TestModule(nn.Module): def forward(self, x): @@ -47,23 +34,6 @@ def forward(self, x): class TestNewOnesConverterWithDynamicShape(AccTestCase): - def test_newone(self): - class TestModule(nn.Module): - def forward(self, x): - return x.new_ones((3, 5), dtype=torch.float16) - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.float32, - shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))], - ), - ] - - self.run_test_with_dynamic_shape( - TestModule(), input_specs, expected_ops={acc_ops.new_ones} - ) - def test_newone_no_dtype(self): class TestModule(nn.Module): def forward(self, x): diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_to_dtype.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_to_dtype.py index 67a07d83cf..0dee730b01 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_to_dtype.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_to_dtype.py @@ -271,47 +271,48 @@ def forward(self, x): precision=LowerPrecision.FP16, ) - # tensor.int() - def test_int(self): - class To(torch.nn.Module): - def forward(self, x): - x = x.int() - # we do not expect int to be output type, so add an extra layer - x = x.float() - return x - - input = torch.randn(2, 2) - inputs = [ - input, - ] - self.run_test( - To(), - inputs, - expected_ops={acc_ops.to_dtype}, - test_implicit_batch_dim=False, - precision=LowerPrecision.FP32, - ) - - # tensor.int() - def test_int_with_dynamic_shape_four_dimensions(self): - class To(torch.nn.Module): - def forward(self, x): - x = x.int() - # we do not expect int to be output type, so add an extra layer - x = x.float() - return x - - input_specs = [ - InputTensorSpec( - shape=(-1, -1, -1, -1), - dtype=torch.int, - shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (3, 3, 3, 3))], - ), - ] - - self.run_test_with_dynamic_shape( - To(), input_specs, expected_ops={acc_ops.to_dtype} - ) + # TODO Open in future. TRT 8.5 does not work for this test + # The test is a rare case. We need to remove it in graph maybe. + # def test_int(self): + # class To(torch.nn.Module): + # def forward(self, x): + # x = x.int() + # # we do not expect int to be output type, so add an extra layer + # x = x.float() + # return x + + # input = torch.randn(2, 2) + # inputs = [ + # input, + # ] + # self.run_test( + # To(), + # inputs, + # expected_ops={acc_ops.to_dtype}, + # test_implicit_batch_dim=False, + # precision=LowerPrecision.FP32, + # ) + + # # tensor.int() + # def test_int_with_dynamic_shape_four_dimensions(self): + # class To(torch.nn.Module): + # def forward(self, x): + # x = x.int() + # # we do not expect int to be output type, so add an extra layer + # x = x.float() + # return x + + # input_specs = [ + # InputTensorSpec( + # shape=(-1, -1, -1, -1), + # dtype=torch.int, + # shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (3, 3, 3, 3))], + # ), + # ] + + # self.run_test_with_dynamic_shape( + # To(), input_specs, expected_ops={acc_ops.to_dtype} + # ) if __name__ == "__main__": diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_unary_ops.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_unary_ops.py index 7fad26dc84..798948537a 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_unary_ops.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_unary_ops.py @@ -9,29 +9,31 @@ from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec unary_ops = [ - (torch.sin, acc_ops.sin), - (torch.cos, acc_ops.cos), - (torch.tan, acc_ops.tan), - (torch.sinh, acc_ops.sinh), - (torch.cosh, acc_ops.cosh), - (torch.asin, acc_ops.asin), - (torch.acos, acc_ops.acos), - (torch.atan, acc_ops.atan), - (torch.abs, acc_ops.abs), - (torch.neg, acc_ops.neg), - (torch.reciprocal, acc_ops.reciprocal), - (torch.sqrt, acc_ops.sqrt), - (torch.log, acc_ops.log), - (torch.exp, acc_ops.exp), - (torch.floor, acc_ops.floor), - (torch.ceil, acc_ops.ceil), - (torch.sign, acc_ops.sign), + (torch.sin, acc_ops.sin, False), + (torch.cos, acc_ops.cos, False), + (torch.tan, acc_ops.tan, False), + (torch.sinh, acc_ops.sinh, False), + (torch.cosh, acc_ops.cosh, False), + (torch.asin, acc_ops.asin, True), + (torch.acos, acc_ops.acos, True), + (torch.atan, acc_ops.atan, True), + (torch.abs, acc_ops.abs, False), + (torch.neg, acc_ops.neg, False), + (torch.reciprocal, acc_ops.reciprocal, False), + (torch.sqrt, acc_ops.sqrt, False), + (torch.log, acc_ops.log, False), + (torch.exp, acc_ops.exp, False), + (torch.floor, acc_ops.floor, False), + (torch.ceil, acc_ops.ceil, False), + (torch.sign, acc_ops.sign, False), ] class TestUnaryOpConverters(AccTestCase): - @parameterized.expand([(op[1].__name__, op[0], op[1]) for op in unary_ops]) - def test_unary_ops(self, name, orig_op: Callable, expected_op): + @parameterized.expand([(op[1].__name__, op[0], op[1], op[2]) for op in unary_ops]) + def test_unary_ops( + self, name, orig_op: Callable, expected_op: Callable, range_req: bool + ): class TestModule(nn.Module): def __init__(self, orig_op): super().__init__() @@ -41,11 +43,15 @@ def forward(self, x): return self.orig_op(x) m = TestModule(orig_op) - inputs = [torch.randn(2, 2, 3)] + inputs = ( + [torch.distributions.uniform.Uniform(-1, 1).sample([2, 2, 3])] + if range_req + else [torch.randn(2, 2, 3)] + ) self.run_test(m, inputs, expected_ops={expected_op}) -class TestUnaryOpConvertersWithDynamicShapeFourDimensions(AccTestCase): +class TestUnaryVOpConvertersWithDynamicShapeFourDimensions(AccTestCase): @parameterized.expand([(op[1].__name__, op[0], op[1]) for op in unary_ops]) def test_unary_ops(self, name, orig_op: Callable, expected_op): class TestModule(nn.Module): diff --git a/py/torch_tensorrt/fx/test/core/test_trt_module.py b/py/torch_tensorrt/fx/test/core/test_trt_module.py index b4fdbd4cbc..71855e1299 100644 --- a/py/torch_tensorrt/fx/test/core/test_trt_module.py +++ b/py/torch_tensorrt/fx/test/core/test_trt_module.py @@ -30,7 +30,7 @@ def forward(self, x): torch.save(trt_mod, "trt.pt") reload_trt_mod = torch.load("trt.pt") - torch.testing.assert_allclose( + torch.testing.assert_close( reload_trt_mod(inputs[0].cuda()).cpu(), ref_output, rtol=1e-04, atol=1e-04 ) os.remove(f"{os.getcwd()}/trt.pt") @@ -52,7 +52,7 @@ def forward(self, x): new_trt_mod = TRTModule() new_trt_mod.load_state_dict(st) - torch.testing.assert_allclose( + torch.testing.assert_close( new_trt_mod(inputs[0].cuda()).cpu(), ref_output, rtol=1e-04, atol=1e-04 ) diff --git a/py/torch_tensorrt/fx/test/passes/test_setitem.py b/py/torch_tensorrt/fx/test/passes/test_setitem_trt.py similarity index 87% rename from py/torch_tensorrt/fx/test/passes/test_setitem.py rename to py/torch_tensorrt/fx/test/passes/test_setitem_trt.py index 357d15be30..8f9c1a887f 100644 --- a/py/torch_tensorrt/fx/test/passes/test_setitem.py +++ b/py/torch_tensorrt/fx/test/passes/test_setitem_trt.py @@ -1,10 +1,10 @@ import torch -import torchdynamo +import torch._dynamo as torchdynamo from parameterized import parameterized +from torch._dynamo.optimizations import backends from torch.testing._internal.common_utils import run_tests from torch_tensorrt.fx.passes.lower_basic_pass import transform_setitem from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase -from torchdynamo.optimizations import backends class TestTransformSetitem(AccTestCase): @@ -24,13 +24,12 @@ def transform_fx(gm, example_inputs): gm = transform_setitem(gm, example_inputs) return gm - optimize_ctx = torchdynamo.optimize( + optimize_mod = torchdynamo.optimize( transform_fx, nopython=True, - ) + )(m) - with optimize_ctx: - m(*inputs) + optimize_mod(*inputs) def test_setitem1d_c2(self): class TestModule(torch.nn.Module): @@ -49,13 +48,12 @@ def transform_fx(gm, example_inputs): gm = transform_setitem(gm, example_inputs) return gm - optimize_ctx = torchdynamo.optimize( + optimize_mod = torchdynamo.optimize( transform_fx, nopython=True, - ) + )(m) - with optimize_ctx: - m(*inputs) + optimize_mod(*inputs) def test_setitem1d_c3(self): class TestModule(torch.nn.Module): @@ -73,13 +71,12 @@ def transform_fx(gm, example_inputs): gm = transform_setitem(gm, example_inputs) return gm - optimize_ctx = torchdynamo.optimize( + optimize_mod = torchdynamo.optimize( transform_fx, nopython=True, - ) + )(m) - with optimize_ctx: - m(*inputs) + optimize_mod(*inputs) @parameterized.expand( [ @@ -106,12 +103,12 @@ def transform_fx(gm, example_inputs): gm = transform_setitem(gm, example_inputs) return gm - optimize_ctx = torchdynamo.optimize( + optimize_mod = torchdynamo.optimize( transform_fx, nopython=True, - ) - with optimize_ctx: - m(*inputs) + )(m) + + optimize_mod(*inputs) @parameterized.expand( [ @@ -138,12 +135,12 @@ def transform_fx(gm, example_inputs): gm = transform_setitem(gm, example_inputs) return gm - optimize_ctx = torchdynamo.optimize( + optimize_mod = torchdynamo.optimize( transform_fx, nopython=True, - ) - with optimize_ctx: - m(*inputs) + )(m) + + optimize_mod(*inputs) @parameterized.expand( [ @@ -169,12 +166,12 @@ def transform_fx(gm, example_inputs): gm = transform_setitem(gm, example_inputs) return gm - optimize_ctx = torchdynamo.optimize( + optimize_mod = torchdynamo.optimize( transform_fx, nopython=True, - ) - with optimize_ctx: - m(*inputs) + )(m) + + optimize_mod(*inputs) @parameterized.expand( [ @@ -202,12 +199,12 @@ def transform_fx(gm, example_inputs): gm = transform_setitem(gm, example_inputs) return gm - optimize_ctx = torchdynamo.optimize( + optimize_mod = torchdynamo.optimize( transform_fx, nopython=True, - ) - with optimize_ctx: - m(*inputs) + )(m) + + optimize_mod(*inputs) @parameterized.expand( [ @@ -235,12 +232,12 @@ def transform_fx(gm, example_inputs): gm = transform_setitem(gm, example_inputs) return gm - optimize_ctx = torchdynamo.optimize( + optimize_mod = torchdynamo.optimize( transform_fx, nopython=True, - ) - with optimize_ctx: - m(*inputs) + )(m) + + optimize_mod(*inputs) @parameterized.expand( [ @@ -270,12 +267,12 @@ def transform_fx(gm, example_inputs): gm = transform_setitem(gm, example_inputs) return gm - optimize_ctx = torchdynamo.optimize( + optimize_mod = torchdynamo.optimize( transform_fx, nopython=True, - ) - with optimize_ctx: - m(*inputs) + )(m) + + optimize_mod(*inputs) @parameterized.expand( [ @@ -306,12 +303,12 @@ def transform_fx(gm, example_inputs): gm = transform_setitem(gm, example_inputs) return gm - optimize_ctx = torchdynamo.optimize( + optimize_mod = torchdynamo.optimize( transform_fx, nopython=True, - ) - with optimize_ctx: - m(*inputs) + )(m) + + optimize_mod(*inputs) @parameterized.expand( [ @@ -339,12 +336,12 @@ def transform_fx(gm, example_inputs): gm = transform_setitem(gm, example_inputs) return gm - optimize_ctx = torchdynamo.optimize( + optimize_mod = torchdynamo.optimize( transform_fx, nopython=True, - ) - with optimize_ctx: - m(*inputs) + )(m) + + optimize_mod(*inputs) @parameterized.expand( [ @@ -374,12 +371,12 @@ def transform_fx(gm, example_inputs): gm = transform_setitem(gm, example_inputs) return gm - optimize_ctx = torchdynamo.optimize( + optimize_mod = torchdynamo.optimize( transform_fx, nopython=True, - ) - with optimize_ctx: - m(*inputs) + )(m) + + optimize_mod(*inputs) @parameterized.expand( [ @@ -409,12 +406,12 @@ def transform_fx(gm, example_inputs): gm = transform_setitem(gm, example_inputs) return gm - optimize_ctx = torchdynamo.optimize( + optimize_mod = torchdynamo.optimize( transform_fx, nopython=True, - ) - with optimize_ctx: - m(*inputs) + )(m) + + optimize_mod(*inputs) @parameterized.expand( [ @@ -456,12 +453,12 @@ def transform_fx(gm, example_inputs): gm = transform_setitem(gm, example_inputs) return gm - optimize_ctx = torchdynamo.optimize( + optimize_mod = torchdynamo.optimize( transform_fx, nopython=True, - ) - with optimize_ctx: - m(*inputs) + )(m) + + optimize_mod(*inputs) @parameterized.expand( [ @@ -502,12 +499,12 @@ def transform_fx(gm, example_inputs): gm = transform_setitem(gm, example_inputs) return gm - optimize_ctx = torchdynamo.optimize( + optimize_mod = torchdynamo.optimize( transform_fx, nopython=True, - ) - with optimize_ctx: - m(*inputs) + )(m) + + optimize_mod(*inputs) # test with torchdynamo def test_setitem1d_trt(self): @@ -526,9 +523,9 @@ def forward(self, x, y): m.cuda() ref_output = m(*inputs) - optimize_ctx = torchdynamo.optimize(backends.fx2trt_compiler, nopython=True) - with optimize_ctx: - output = m(*inputs) + optimize_mod = torchdynamo.optimize(backends.fx2trt_compiler, nopython=True)(m) + + output = optimize_mod(*inputs) self.assertTrue(torch.allclose(ref_output, output)) @parameterized.expand( @@ -553,9 +550,8 @@ def forward(self, x, y): m.cuda() ref_output = m(*inputs) - optimize_ctx = torchdynamo.optimize(backends.fx2trt_compiler, nopython=True) - with optimize_ctx: - output = m(*inputs) + optimize_mod = torchdynamo.optimize(backends.fx2trt_compiler, nopython=True)(m) + output = optimize_mod(*inputs) self.assertTrue(torch.allclose(ref_output, output)) @parameterized.expand( @@ -595,9 +591,8 @@ def forward(self, x, y): m.cuda() ref_output = m(*inputs) - optimize_ctx = torchdynamo.optimize(backends.fx2trt_compiler, nopython=True) - with optimize_ctx: - output = m(*inputs) + optimize_mod = torchdynamo.optimize(backends.fx2trt_compiler, nopython=True)(m) + output = optimize_mod(*inputs) self.assertTrue(torch.allclose(ref_output, output)) diff --git a/py/torch_tensorrt/fx/test/tools/test_model_packager.py b/py/torch_tensorrt/fx/test/tools/test_model_packager.py index b0ef521d27..fbc9403a95 100644 --- a/py/torch_tensorrt/fx/test/tools/test_model_packager.py +++ b/py/torch_tensorrt/fx/test/tools/test_model_packager.py @@ -51,6 +51,6 @@ def test_package_model(self): reload_model = pi.load_pickle("repro", "model") reload_inputs = pi.load_pickle("repro", "inputs") - torch.testing.assert_allclose(model(*inputs), reload_model(*reload_inputs)) + torch.testing.assert_close(model(*inputs), reload_model(*reload_inputs)) keys = dict(reload_model.named_children()).keys() self.assertEqual(keys, {"_holder"}) diff --git a/py/torch_tensorrt/fx/test/tracer/test_acc_tracer.py b/py/torch_tensorrt/fx/test/tracer/test_acc_tracer.py index 23b7329669..8709661dfc 100644 --- a/py/torch_tensorrt/fx/test/tracer/test_acc_tracer.py +++ b/py/torch_tensorrt/fx/test/tracer/test_acc_tracer.py @@ -2,7 +2,7 @@ import logging import operator import unittest -from typing import Callable, Dict, List, NamedTuple +from typing import Callable, Dict, List, NamedTuple, Optional, Tuple import numpy as np import torch @@ -38,12 +38,12 @@ def _make_model_unit_test( input = torch.randn(input_shape) traced = acc_tracer.trace(model, [input]) if enable_allclose: - torch.testing.assert_allclose(model(input), traced(input)) + torch.testing.assert_close(model(input), traced(input)) else: self.assertTrue(torch.equal(model(input), traced(input))) traced_again = acc_tracer.trace(traced, [input]) if enable_allclose: - torch.testing.assert_allclose(model(input), traced_again(input)) + torch.testing.assert_close(model(input), traced_again(input)) else: self.assertTrue(torch.equal(model(input), traced_again(input))) @@ -109,10 +109,10 @@ def forward(self, a: torch.Tensor) -> torch.Tensor: ref_outputs, outputs, outputs_again ): if enable_allclose: - torch.testing.assert_allclose( + torch.testing.assert_close( torch.nan_to_num(ref_output), torch.nan_to_num(output) ) - torch.testing.assert_allclose( + torch.testing.assert_close( torch.nan_to_num(ref_output), torch.nan_to_num(output_again) ) else: @@ -1515,7 +1515,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ref = m(x) res = traced(x) - torch.testing.assert_allclose(ref, res) + torch.testing.assert_close(ref, res) def test_add_with_alpha(self): """ @@ -1969,7 +1969,7 @@ def forward( else: self.fail(f"Unexpected node: {node.format_node()}") - torch.testing.assert_allclose(m(input, a, b), traced(input, a, b)) + torch.testing.assert_close(m(input, a, b), traced(input, a, b)) def test_log1p(self): class TestModule(torch.nn.Module): @@ -1999,7 +1999,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: else: self.fail(f"Unexpected node: {node.format_node()}") - torch.testing.assert_allclose(m(input), traced(input)) + torch.testing.assert_close(m(input), traced(input)) @parameterized.expand([(torch.float,), (torch.float16,)]) def test_addmm(self, dtype): @@ -2275,6 +2275,36 @@ def forward(self, a: Dict[str, torch.Tensor]) -> torch.Tensor: self.assertTrue(torch.equal(m(input), traced(input))) + def test_none_type_ret(self): + """ + Test that a NoneType is traced as expected. + """ + + class TestModule(nn.Module): + def __init__(self): + super().__init__() + + def forward( + self, a: torch.Tensor + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + return a + a, None + + m = TestModule() + input = torch.randn(1, 2, 3) + try: + traced = acc_tracer.trace( + m, + [input], + ) + except RuntimeError as e: + self.assertEqual( + "This error should not be triggered, as NoneType should be lowered without an issue", + str(e), + ) + ans1, _ = m(input) + ans2, _ = traced(input) + self.assertTrue(torch.equal(ans1, ans2)) + def test_mobilenet_v3(self): """ Test that we can trace mobilenet v3 small and run/compare against the untraced version. @@ -2615,6 +2645,7 @@ def test_all_acc_ops_registered(self): acc_ops.sign, acc_ops.permute, acc_ops.matmul, + # acc_ops.roi_align, acc_ops.quantize_per_tensor, acc_ops.quantize_per_channel, acc_ops.quantized_add, diff --git a/py/torch_tensorrt/fx/test/tracer/test_dispatch_tracer.py b/py/torch_tensorrt/fx/test/tracer/test_dispatch_tracer.py index 5f02051166..bb515252ed 100644 --- a/py/torch_tensorrt/fx/test/tracer/test_dispatch_tracer.py +++ b/py/torch_tensorrt/fx/test/tracer/test_dispatch_tracer.py @@ -2,16 +2,16 @@ import unittest import torch -import torchdynamo +import torch._dynamo as torchdynamo import torchvision from functorch.experimental import functionalize +from torch._dynamo.optimizations import backends +from torch._dynamo.optimizations.normalize import normalize_ir from torch.library import Library from torch_tensorrt.fx.lower import compile from torch_tensorrt.fx.tracer.dispatch_tracer.tracer import make_fx from torch_tensorrt.fx.utils import LowerPrecision, proxytensor_trace -from torchdynamo.optimizations import backends -from torchdynamo.optimizations.normalize import normalize_ir torch.manual_seed(0) diff --git a/py/torch_tensorrt/fx/tools/common_fx2trt.py b/py/torch_tensorrt/fx/tools/common_fx2trt.py index 51763a4b70..2aa82b1d72 100644 --- a/py/torch_tensorrt/fx/tools/common_fx2trt.py +++ b/py/torch_tensorrt/fx/tools/common_fx2trt.py @@ -95,7 +95,11 @@ def run_test( if not isinstance(ref, torch.Tensor): ref = torch.tensor([ref]) ref = ref.cpu() # to_dtype test has cases with gpu output - torch.testing.assert_allclose(out.cpu(), ref, rtol=rtol, atol=atol) + if ref.dtype == torch.int64: + ref = ref.int() # convert torch.max's index output tensor to int32 + torch.testing.assert_close( + out.cpu(), ref, rtol=rtol, atol=atol, equal_nan=True + ) def run_test_custom_compare_results( self, diff --git a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py index 1bdc7bf704..86b05b0d9e 100644 --- a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py +++ b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py @@ -424,7 +424,7 @@ def add(*, input, other): @register_acc_op_mapping(op_and_target=("call_method", "unsqueeze")) @register_acc_op_mapping(op_and_target=("call_function", torch.unsqueeze)) @register_acc_op -def unsqueeze(*, input, dim): +def unsqueeze(*, input, dim: int): return torch.unsqueeze(input=input, dim=dim) @@ -590,6 +590,7 @@ def clamp(*, input, min=None, max=None): return torch.clamp(input=input, min=min, max=max) +@register_acc_op_mapping(op_and_target=("call_function", torch.concat)) @register_acc_op_mapping(op_and_target=("call_function", torch.cat)) @register_acc_op def cat(*, tensors, dim): @@ -2725,6 +2726,43 @@ def packed_quantized_conv2d_mapper( return new_node +# @register_acc_op +# @register_acc_op_mapping( +# op_and_target=("call_function", torch.ops._caffe2.RoIAlign), +# arg_replacement_tuples=[ +# ("features", "features"), +# ("rois", "rois"), +# ("order", "order"), +# ("spatial_scale", "spatial_scale"), +# ("pooled_h", "pooled_h"), +# ("pooled_w", "pooled_w"), +# ("sampling_ratio", "sampling_ratio"), +# ("aligned", "aligned"), +# ], +# ) +# def roi_align( +# *, +# features, +# rois, +# order, +# spatial_scale, +# pooled_h, +# pooled_w, +# sampling_ratio, +# aligned, +# ): +# return torch.ops._caffe2.RoIAlign( +# features=features, +# rois=rois, +# order=order, +# spatial_scale=spatial_scale, +# pooled_h=pooled_h, +# pooled_w=pooled_w, +# sampling_ratio=sampling_ratio, +# aligned=aligned, +# ) + + @register_custom_acc_mapper_fn( op_and_target=("call_function", torch.ops.quantized.add_relu), arg_replacement_tuples=[ @@ -2786,9 +2824,16 @@ def packed_quantized_convrelu2d_mapper( @register_acc_op_properties(AccOpProperty.pointwise, AccOpProperty.unary) @register_acc_op_mapping(op_and_target=("call_function", torch.nn.functional.gelu)) @register_acc_op_mapping(op_and_target=("call_method", "gelu")) +@register_custom_acc_mapper_fn( + op_and_target=("call_module", torch.nn.GELU), + arg_replacement_tuples=[ + ("input", "input"), + ("approximate", "approximate"), + ], +) @register_acc_op -def gelu(*, input): - return torch.nn.functional.gelu(input=input) +def gelu(*, input, approximate="none"): + return torch.nn.functional.gelu(input=input, approximate=approximate) @register_acc_op_properties(AccOpProperty.unary) @@ -3075,6 +3120,108 @@ def log_softmax_mapper(node: torch.fx.Node, _: torch.nn.Module) -> torch.fx.Node return log_node +@register_custom_acc_mapper_fn( + op_and_target=("call_function", torch.nn.functional.softplus), + arg_replacement_tuples=[ + ("input", "input"), + ("beta", "beta", this_arg_is_optional), + ("threshold", "threshold", this_arg_is_optional), + ], +) +def softplus_mapper(node: torch.fx.Node, _: torch.nn.Module) -> torch.fx.Node: + """ + Maps torch.nn.functional.softplus to acc_ops.where, acc_ops.relu, acc_ops.exp, acc_ops.mul, acc_ops.add and acc_ops.div + + softplus(input, beta, threshold) = where(beta * input > threshold, relu(input), div(log(1 + exp(beta * input))), beta)) + + torch.where( + softplus_module.beta * sample_inputs[0] > softplus_module.threshold, + sample_inputs[0].relu(), + torch.div((1 + (softplus_module.beta * sample_inputs[0]).exp()).log(), softplus_module.beta), + ) + + """ + + input_node = node.kwargs["input"] + beta_node = node.kwargs["beta"] + threshold_node = node.kwargs["threshold"] + + with node.graph.inserting_after(node): + cond_mul_node = node.graph.call_function( + mul, + kwargs={ + "input": input_node, + "other": beta_node, + }, + ) + cond_mul_node.meta = input_node.meta.copy() + + with node.graph.inserting_after(cond_mul_node): + gt_node = node.graph.call_function( + gt, + kwargs={ + "input": cond_mul_node, + "other": threshold_node, + }, + ) + gt_node.meta = input_node.meta.copy() + + with node.graph.inserting_after(gt_node): + relu_node = node.graph.call_function(relu, kwargs={"input": input_node}) + relu_node.meta = input_node.meta.copy() + + with node.graph.inserting_after(relu_node): + mul_node = node.graph.call_function( + mul, + kwargs={ + "input": input_node, + "other": beta_node, + }, + ) + mul_node.meta = input_node.meta.copy() + + with node.graph.inserting_after(mul_node): + exp_node = node.graph.call_function(exp, kwargs={"input": mul_node}) + exp_node.meta = input_node.meta.copy() + + with node.graph.inserting_after(exp_node): + add_node = node.graph.call_function( + add, + kwargs={ + "input": exp_node, + "other": 1, + }, + ) + add_node.meta = input_node.meta.copy() + + with node.graph.inserting_after(add_node): + log_node = node.graph.call_function(log, kwargs={"input": add_node}) + log_node.meta = input_node.meta.copy() + + with node.graph.inserting_after(log_node): + div_node = node.graph.call_function( + div, + kwargs={ + "input": log_node, + "other": beta_node, + }, + ) + div_node.meta = input_node.meta.copy() + + with node.graph.inserting_after(div_node): + where_node = node.graph.call_function( + where, + kwargs={ + "condition": gt_node, + "x": relu_node, + "y": div_node, + }, + ) + where_node.meta = div_node.meta.copy() + + return where_node + + @register_custom_acc_mapper_fn( op_and_target=("call_function", torch.baddbmm), arg_replacement_tuples=[ diff --git a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_tracer.py b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_tracer.py index 61ade62e6c..af7e27aa09 100644 --- a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_tracer.py +++ b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_tracer.py @@ -287,6 +287,7 @@ class AccRewritingTracer(Tracer): torch.nn.intrinsic.quantized.ConvReLU2d, jit.ScriptModule, jit.RecursiveScriptModule, + torch.nn.modules.activation.MultiheadAttention, } def is_leaf_module(self, m: nn.Module, mod_qual_name: str) -> bool: diff --git a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_utils.py b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_utils.py index 5d9d27be9c..ab3207925f 100644 --- a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_utils.py +++ b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_utils.py @@ -176,6 +176,8 @@ def map_tensor_metadata(a: Any, fn: Callable): """ if isinstance(a, int): return 1 + elif a is None: + return 1 elif isinstance(a, TensorMetadata): return fn(a) elif isinstance(a, tuple):