diff --git a/.circleci/config.yml b/.circleci/config.yml
index d9aa118955..347dd77294 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -263,7 +263,7 @@ commands:
     parameters:
       torch-build:
         type: string
-        default: "2.0.0.dev20230120+cu117"
+        default: "2.0.0.dev20230129+cu117"
       torch-build-index:
         type: string
         default: "https://download.pytorch.org/whl/nightly/cu117"
@@ -1026,7 +1026,7 @@ parameters:
   # Nightly platform config
   torch-build:
     type: string
-    default: "2.0.0.dev20230120+cu117"
+    default: "2.0.0.dev20230129+cu117"
   torch-build-index:
     type: string
     default: "https://download.pytorch.org/whl/nightly/cu117"
diff --git a/py/torch_tensorrt/fx/converters/aten_ops_converters.py b/py/torch_tensorrt/fx/converters/aten_ops_converters.py
index c79f618be3..943eb203b3 100644
--- a/py/torch_tensorrt/fx/converters/aten_ops_converters.py
+++ b/py/torch_tensorrt/fx/converters/aten_ops_converters.py
@@ -298,8 +298,6 @@ def aten_ops_sub(
     return acc_ops_converters.acc_ops_sub(network, target, None, kwargs_new, name)
 
 
-@tensorrt_converter(torch.ops.aten._unsafe_view.default)
-@tensorrt_converter(torch.ops.aten._reshape_alias.default)
 @tensorrt_converter(torch.ops.aten.view.default)
 def aten_ops_reshape(
     network: TRTNetwork,
@@ -308,11 +306,33 @@ def aten_ops_reshape(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    kwargs_new = {
-        "input": args[0],
-        "acc_out_ty": acc_utils.build_raw_tensor_meta(shape=args[1]),
-    }
-    return acc_ops_converters.acc_ops_reshape(network, target, None, kwargs_new, name)
+    input_val = args[0]
+    # for the case where input_val is a TRTTensor
+    input_val = get_trt_tensor(network, input_val, f"{name}_input_val")
+    shape = args[1]
+
+    layer = network.add_shuffle(input_val)
+
+    if all(isinstance(s, int) for s in shape):
+        layer.reshape_dims = tuple(shape)
+    else:
+        # Convert all the dimensions to TRT tensors.
+        trt_shape = []
+
+        for i, s in enumerate(shape):
+            if isinstance(s, TRTTensor):
+                trt_shape.append(s)
+            else:
+                a = get_trt_tensor(network, s, f"{name}_{i}")
+                trt_shape.append(a)
+
+        shape_layer = network.add_concatenation(inputs=trt_shape)
+        shape_layer.axis = 0
+        shape_layer.name = f"{name}_output_shape"
+        layer.set_input(1, shape_layer.get_output(0))
+
+    set_layer_name(layer, target, name)
+    return layer.get_output(0)
 
 
 @tensorrt_converter(torch.ops.aten.cat.default)
@@ -345,3 +365,104 @@ def aten_ops_expand(
     return acc_ops_converters.acc_ops_expand_tensor(
         network, target, None, kwargs_new, name
     )
+
+
+@tensorrt_converter(operator.floordiv)
+def aten_ops_operator_floordiv(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    kwargs_new = {
+        "input": args[0],
+        "other": args[1],
+    }
+    return acc_ops_converters.acc_ops_floor_div(network, target, None, kwargs_new, name)
+
+
+@tensorrt_converter(operator.mul)
+def aten_ops_operator_mul(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    kwargs_new = {
+        "input": args[0],
+        "other": args[1],
+    }
+    return acc_ops_converters.acc_ops_mul(network, target, None, kwargs_new, name)
+
+
+@tensorrt_converter(operator.add)
+def aten_ops_operator_add(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    kwargs_new = {
+        "input": args[0],
+        "other": args[1],
+    }
+    return acc_ops_converters.acc_ops_add(network, target, None, kwargs_new, name)
+
+
+@tensorrt_converter(operator.sub)
+def aten_ops_operator_sub(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    kwargs_new = {
+        "input": args[0],
+        "other": args[1],
+    }
+    return acc_ops_converters.acc_ops_sub(network, target, None, kwargs_new, name)
+
+
+@tensorrt_converter(torch.ops.aten.sym_numel)
+def aten_ops_sym_numel(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    shape_layer = network.add_shape(args[0])
+    set_layer_name(shape_layer, target, "_shape_layer")
+    reduce_layer = network.add_reduce(
+        shape_layer.get_output(0),
+        trt.ReduceOperation.PROD,
+        axes=get_axes_for_reduce_op(0, False),
+        keep_dims=True,
+    )
+    set_layer_name(reduce_layer, target, "_reduce_layer")
+    return reduce_layer.get_output(0)
+
+
+@tensorrt_converter(torch.ops.aten.sym_size)
+def aten_ops_sym_size(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    shape_layer = network.add_shape(args[0])
+    ind = args[1]
+    set_layer_name(shape_layer, target, "_shape_layer")
+    slice_layer = network.add_slice(
+        input=shape_layer.get_output(0),
+        start=[ind],
+        shape=[1],
+        stride=[1],
+    )
+    set_layer_name(slice_layer, target, "_slice_layer")
+    return slice_layer.get_output(0)
diff --git a/py/torch_tensorrt/fx/converters/convolution.py b/py/torch_tensorrt/fx/converters/convolution.py
index 7d8eb1589e..6af940200a 100644
--- a/py/torch_tensorrt/fx/converters/convolution.py
+++ b/py/torch_tensorrt/fx/converters/convolution.py
@@ -1,8 +1,9 @@
 # @manual=//deeplearning/trt/python:py_tensorrt
+import logging
+
 import numpy as np
 import tensorrt as trt
 import torch
-import logging
 
 from ..converter_registry import tensorrt_converter
diff --git a/py/torch_tensorrt/fx/lower.py b/py/torch_tensorrt/fx/lower.py
index 2541143fb6..f96f1db6b9 100644
--- a/py/torch_tensorrt/fx/lower.py
+++ b/py/torch_tensorrt/fx/lower.py
@@ -7,6 +7,7 @@
 import torch
 import torch.fx as fx
 import torch.nn as nn
+import torch_tensorrt.fx.tracer.dispatch_tracer.aten_tracer as aten_tracer
 from torch.fx.passes.splitter_base import SplitResult
 
 from .fx2trt import TRTInterpreter, TRTInterpreterResult
@@ -18,8 +19,7 @@
 from .tracer.acc_tracer import acc_tracer
 
 from .trt_module import TRTModule
 
-from .utils import LowerPrecision, proxytensor_trace
-
+from .utils import LowerPrecision
 
 logger = logging.getLogger(__name__)
@@ -259,7 +259,9 @@ def create(
         return cls(
             lower_pass_manager_builder=LowerPassManagerBuilder(
                 lower_setting=lower_setting,
-                trace_func=lambda module, inputs: proxytensor_trace(module, inputs),
+                trace_func=lambda module, inputs: aten_tracer.opt_trace(
+                    module, inputs
+                ),
                 split_func=split_func,
                 lower_func=default_lower_pass(interpreter_builder),
             )
@@ -308,14 +310,6 @@ def do_lower(module: nn.Module, inputs: Input) -> nn.Module:
             pm = self.lower_pass_manager_builder.build_trt_lower_pipeline(
                 inputs, additional_inputs
             )
-            if lower_setting.is_aten:
-                pm = self.lower_pass_manager_builder.build_aten2trt_lower_pipeline(
-                    inputs, additional_inputs
-                )
-            else:
-                pm = self.lower_pass_manager_builder.build_trt_lower_pipeline(
-                    inputs, additional_inputs
-                )
             lower_result = pm(module)
             return lower_result
diff --git a/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py b/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py
index 514a52fab8..61052b21af 100644
--- a/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py
+++ b/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py
@@ -127,6 +127,31 @@ def graph_optimization_pass(self) -> PassManager:
 
         return PassManager.build_from_passlist(passes)
 
+    def graph_optimization_pass_aten(self) -> PassManager:
+        passes = []
+
+        for p in self.lower_setting.customized_fuse_pass.passes:
+            passes.append(wrapper(p, self._input))
+        for p in self.lower_setting.lower_basic_fuse_pass.passes:
+            passes.append(wrapper(p, self._input))
+        # TODO: fix this pass for the aten graph
+        # if (
+        #     hasattr(self.lower_setting, "lower_precision")
+        #     and self.lower_setting.lower_precision is LowerPrecision.FP16
+        # ) or (
+        #     hasattr(self.lower_setting, "precision")
+        #     and self.lower_setting.precision is LowerPrecision.FP16
+        # ):
+        #     passes.append(wrapper(fix_clamp_numerical_limits_to_fp16, self._input))
+
+        passes.append(
+            inplace_wrapper(lambda m: FUSE_PASSES_POST_OBSERVER.observe(m, self._input))
+        )
+        # TODO: we most likely do not need it for aten
+        # passes.append(fix_reshape_batch_dim)
+
+        return PassManager.build_from_passlist(passes)
+
     def _split_pass(self) -> PassManager:
         passes = [
             partial(
@@ -259,8 +284,7 @@ def build_aten2trt_lower_pipeline(
         passes.append(
             wrapper(self._trace_func, self._input),
         )
-        passes.append(self._default_replace_mutable_op_pass())
-        passes.append(self.graph_optimization_pass())
+        passes.append(self.graph_optimization_pass_aten())
         passes.append(self._split_pass())
         passes.append(self._trt_lower_pass())
diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_maxpool_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_maxpool_aten.py
index fac55ad46a..97c5251da9 100644
--- a/py/torch_tensorrt/fx/test/converters/aten_op/test_maxpool_aten.py
+++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_maxpool_aten.py
@@ -73,7 +73,7 @@ def forward(self, x):
             # param("ceil_mode", 1, ceil_mode=True),
         ]
     )
-    @unittest.skip("PT tracer issue")
+    @unittest.skip("PT2 tracer issue")
    def test_max_pool3d(
         self,
         test_name,
@@ -95,6 +95,7 @@ def forward(self, x):
         inputs = [torch.randn(1, 3, 32, 32, 32)]
         self.run_test(TestModule(), inputs, expected_ops={})
 
+    @unittest.skip("PT2 tracer issue")
     def test_max_pool3d_with_dynamic_shape(self):
         class TestModule(torch.nn.Module):
             def __init__(self):
@@ -118,7 +119,7 @@ def forward(self, x):
     @parameterized.expand(
         [
             ("default", 1),
-            param("stride", 2, stride=()),
+            # param("stride", 2, stride=()),  # PT2 tracer issue
         ]
     )
     def test_stride_none_max_pool2d(
@@ -147,7 +148,7 @@ def forward(self, x):
             param("stride", 2, stride=()),
         ]
     )
-    @unittest.skip("PT tracer issue")
+    @unittest.skip("PT2 tracer issue")
     def test_stride_none_max_pool3d(
         self,
         test_name,
@@ -209,6 +210,7 @@ def forward(self, x):
             param("stride", 2, stride=()),
         ]
     )
+    @unittest.skip("PT2 tracer issue")
     def test_stride_none_max_pool3d_with_dynamic_shape(
         self,
         test_name,
diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_reshape_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_reshape_aten.py
index 96c8fe7423..538e575d6e 100644
--- a/py/torch_tensorrt/fx/test/converters/aten_op/test_reshape_aten.py
+++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_reshape_aten.py
@@ -1,7 +1,7 @@
 import unittest
 
+import tensorrt as trt
 import torch
-import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
 from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
@@ -14,7 +14,10 @@ class TestReshapeConverter(DispatchTestCase):
             ((1, 10, -1),),
         ]
     )
-    @unittest.skip("Need support")
+    @unittest.skipIf(
+        trt.__version__ < "8.5",
+        "Shape tensors are well supported in TensorRT 8.5 and later",
+    )
     def test_reshape(self, target_shape):
         class TestModule(torch.nn.Module):
             def __init__(self, target_shape):
@@ -31,58 +34,68 @@ def forward(self, x):
             expected_ops={torch.ops.aten.view.default},
         )
 
-    ## TODO: proxytensor tracer does not support output size containing -1. If dim=0 is set to -1 for dynamic batch,
-    ## then it is becomes fixed acoording to the input. For ex. input (-1, 2, 3), output size (-1, 6), then
-    ## proxytensor tracer output is (32, 6) if sample input is (32, 2, 3). But fx tracer could keep the output size as (-1, 6)
-    # @parameterized.expand(
-    #     [
-    #         ((-1, 2),),
-    #         ((1, 2, -1),),
-    #     ]
-    # )
-    # def test_reshape_with_dynamic_shape(self, target_shape):
-    #     class TestModule(torch.nn.Module):
-    #         def __init__(self, target_shape):
-    #             super().__init__()
-    #             self.target_shape = target_shape
+    @parameterized.expand(
+        [
+            ((-1, 10),),
+            ((-1, 5),),
+            ((2, 2, -1),),
+        ]
+    )
+    @unittest.skipIf(
+        trt.__version__ < "8.5",
+        "Shape tensors are well supported in TensorRT 8.5 and later",
+    )
+    def test_reshape_with_dynamic_shape(self, target_shape):
+        class TestModule(torch.nn.Module):
+            def __init__(self, target_shape):
+                super().__init__()
+                self.target_shape = target_shape
 
-    #     def forward(self, x):
-    #         return torch.reshape(x, self.target_shape)
+            def forward(self, x):
+                return torch.reshape(x, self.target_shape)
 
-    #     input_specs = [
-    #         InputTensorSpec(
-    #             shape=(-1, -1, -1),
-    #             dtype=torch.float32,
-    #             shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))],
-    #         ),
-    #     ]
-    #     self.run_test_with_dynamic_shape(
-    #         TestModule(target_shape), input_specs, expected_ops={torch.ops.aten._reshape_alias.default}
-    #     )
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, 2, 5),
+                dtype=torch.float32,
+                shape_ranges=[((1, 2, 5), (10, 2, 5), (10, 2, 5))],
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            TestModule(target_shape),
+            input_specs,
+            expected_ops={torch.ops.aten.view.default},
+        )
 
-    # def test_reshape_with_dynamic_shape_size(self):
-    #     class TestModule(torch.nn.Module):
-    #         def forward(self, x, y):
-    #             shape_y = y.shape
-    #             t = shape_y[1]
-    #             return torch.reshape(x, [-1, t, 3])
+    @unittest.skipIf(
+        trt.__version__ < "8.5",
+        "Shape tensors are well supported in TensorRT 8.5 and later",
+    )
+    def test_reshape_with_dynamic_shape_size(self):
+        class TestModule(torch.nn.Module):
+            def forward(self, x, y):
+                shape_y = y.shape
+                t = shape_y[1]
+                return torch.reshape(x, [-1, t, 3])
 
-    #     input_specs = [
-    #         InputTensorSpec(
-    #             shape=(-1, 5, 6),
-    #             dtype=torch.float32,
-    #             shape_ranges=[((1, 5, 6), (2, 5, 6), (3, 5, 6))],
-    #         ),
-    #         InputTensorSpec(
-    #             shape=(-1, 5),
-    #             dtype=torch.float32,
-    #             shape_ranges=[((1, 5), (1, 5), (3, 5))],
-    #         ),
-    #     ]
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, 5, 6),
+                dtype=torch.float32,
+                shape_ranges=[((1, 5, 6), (3, 5, 6), (3, 5, 6))],
+            ),
+            InputTensorSpec(
+                shape=(-1, 5),
+                dtype=torch.float32,
+                shape_ranges=[((1, 5), (3, 5), (3, 5))],
+            ),
+        ]
 
-    #     self.run_test_with_dynamic_shape(
-    #         TestModule(), input_specs, expected_ops={acc_ops.reshape}
-    #     )
+        self.run_test_with_dynamic_shape(
+            TestModule(),
+            input_specs,
+            expected_ops={torch.ops.aten.view.default},
+        )
 
 
 if __name__ == "__main__":
diff --git a/py/torch_tensorrt/fx/test/passes/test_fuse_permute_linear_trt.py b/py/torch_tensorrt/fx/test/passes/test_fuse_permute_linear_trt.py
index 32386185a6..eeeb9c9eeb 100644
--- a/py/torch_tensorrt/fx/test/passes/test_fuse_permute_linear_trt.py
+++ b/py/torch_tensorrt/fx/test/passes/test_fuse_permute_linear_trt.py
@@ -1,8 +1,9 @@
 # Owner(s): ["oncall: gpu_enablement"]
 
+import unittest
+
 import torch
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
-import unittest
 from torch.testing._internal.common_utils import run_tests
 from torch_tensorrt.fx.passes.lower_basic_pass import (
     fuse_permute_linear,
diff --git a/py/torch_tensorrt/fx/test/tracer/test_acc_tracer.py b/py/torch_tensorrt/fx/test/tracer/test_acc_tracer.py
index 709df1cd2f..633359127f 100644
--- a/py/torch_tensorrt/fx/test/tracer/test_acc_tracer.py
+++ b/py/torch_tensorrt/fx/test/tracer/test_acc_tracer.py
@@ -2692,7 +2692,6 @@ def test_all_acc_ops_registered(self):
                 acc_ops.sign,
                 acc_ops.permute,
                 acc_ops.matmul,
-                # acc_ops.roi_align,
                 acc_ops.quantize_per_tensor,
                 acc_ops.quantize_per_channel,
                 acc_ops.quantized_add,
diff --git a/py/torch_tensorrt/fx/test/tracer/test_dispatch_tracer.py b/py/torch_tensorrt/fx/test/tracer/test_dispatch_tracer.py
index 39a1a46a34..e160626cf2 100644
--- a/py/torch_tensorrt/fx/test/tracer/test_dispatch_tracer.py
+++ b/py/torch_tensorrt/fx/test/tracer/test_dispatch_tracer.py
@@ -65,118 +65,21 @@ def forward(self, x, y):
         ref_output = mod(*inputs_new)
         torch.testing.assert_close(output, ref_output)
 
-    def test_simple(self):
-        class TestModule(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.relu = torch.nn.ReLU(inplace=True)
-
-            def forward(self, x, y):
-                y = y + x
-                y = y.mul(x)
-                y = y + x
-                y = y + x
-                y = y / x
-                y = y + x
-                y = y + x
-                y = y / x
-                y = y + x
-                y = self.relu(y)
-                return y
-
-        mod = TestModule()
-        mod = mod.cuda().half().eval()
-
-        def f(x, y):
-            return mod(x, y)
-
-        inputs = [torch.randn(2, 5), torch.ones(2, 5)]
-        inputs = [i.cuda().half() for i in inputs]
-        ref_output = f(*inputs)
-
-        mod = compile(
-            mod,
-            inputs,
-            max_batch_size=100,
-            explicit_batch_dimension=True,
-            lower_precision=LowerPrecision.FP16,
-            verbose_log=False,
-            timing_cache_prefix="",
-            save_timing_cache=False,
-            cuda_graph_batch_size=-1,
-            dynamic_batch=True,
-            is_aten=True,
-        )
-        output = mod(*inputs)
-        torch.testing.assert_close(output, ref_output)
-
-    def test_resnet18_aten(self):
-        mod = torchvision.models.resnet18()
-        mod = mod.cuda().half().eval()
-
-        inputs = [torch.ones(32, 3, 224, 224)]
-        inputs = [i.cuda().half() for i in inputs]
-
-        aten_mod = compile(
-            mod,
-            inputs,
-            max_batch_size=32,
-            explicit_batch_dimension=True,
-            lower_precision=LowerPrecision.FP16,
-            verbose_log=False,
-            timing_cache_prefix="",
-            save_timing_cache=False,
-            cuda_graph_batch_size=-1,
-            dynamic_batch=False,
-            is_aten=True,
-        )
-        aten_output = aten_mod(*inputs)
-        fx_mod = compile(
-            mod,
-            inputs,
-            max_batch_size=32,
-            explicit_batch_dimension=True,
-            lower_precision=LowerPrecision.FP16,
-            verbose_log=False,
-            timing_cache_prefix="",
-            save_timing_cache=False,
-            cuda_graph_batch_size=-1,
-            dynamic_batch=False,
-            is_aten=False,
-        )
-        fx_output = fx_mod(*inputs)
-        # Kernel selection is tricky in TRT with big variance as shown below:
-        # Mismatched elements: 30816 / 32000 (96.3%)
-        # Greatest absolute difference: 0.05859375 at index (0, 499) (up to 1e-05 allowed)
-        # Greatest relative difference: 3.293713681986265 at index (0, 142) (up to 0.001 allowed)
-        # so we choose to use cosine similarity
-        cos_val = torch.nn.functional.cosine_similarity(
-            aten_output.flatten(), fx_output.flatten(), dim=0, eps=1e-4
-        )
-        self.assertTrue(cos_val.detach().cpu().numpy() > 0.999)
-
     def test_resnet18_dynamo(self):
         mod = torchvision.models.resnet18()
         mod = mod.cuda().half().eval()
 
-        def f(x):
-            return mod(x)
-
         inputs = [torch.ones(32, 3, 224, 224)]
         inputs = [i.cuda().half() for i in inputs]
-        torchdynamo.reset()
-        dynamo_aten_mod = torchdynamo.optimize(backends.fx2trt_compiler_fp16)(mod)
-        dynamo_aten_output = dynamo_aten_mod(*inputs)
+        ref_output = mod(*inputs)
 
         torchdynamo.reset()
-
         dynamo_mod = torchdynamo.optimize(backends.fx2trt_compiler_fp16)(mod)
         dynamo_output = dynamo_mod(*inputs)
-
-        cos = torch.nn.CosineSimilarity(dim=0, eps=1e-4)
-        cos_val = cos(dynamo_output.flatten(), dynamo_aten_output.flatten())
-
-        self.assertTrue(cos_val.cpu().numpy() > 0.999)
+        cos_val = torch.nn.functional.cosine_similarity(
+            dynamo_output.flatten(), ref_output.flatten(), dim=0, eps=1e-4
+        )
+        self.assertTrue(cos_val.detach().cpu().numpy() > 0.999)
 
 
 class DispatchTracerTest(unittest.TestCase):
diff --git a/py/torch_tensorrt/fx/test/tracer/test_resnet.py b/py/torch_tensorrt/fx/test/tracer/test_resnet.py
new file mode 100644
index 0000000000..1103c8f623
--- /dev/null
+++ b/py/torch_tensorrt/fx/test/tracer/test_resnet.py
@@ -0,0 +1,98 @@
+import unittest
+
+import torch
+
+import torch._dynamo.config
+import torchvision
+from torch_tensorrt.fx.lower import compile
+from torch_tensorrt.fx.utils import LowerPrecision
+
+
+class ResnetTest(unittest.TestCase):
+    def test_resnet18_aten(self):
+        mod = torchvision.models.resnet18()
+        mod = mod.cuda().half().eval()
+
+        inputs = [torch.ones(32, 3, 224, 224)]
+        inputs = [i.cuda().half() for i in inputs]
+
+        aten_mod = compile(
+            mod,
+            inputs,
+            max_batch_size=32,
+            explicit_batch_dimension=True,
+            lower_precision=LowerPrecision.FP16,
+            verbose_log=False,
+            timing_cache_prefix="",
+            save_timing_cache=False,
+            cuda_graph_batch_size=-1,
+            dynamic_batch=False,
+            is_aten=True,
+        )
+        aten_output = aten_mod(*inputs)
+        aten_output = aten_output[0]
+        fx_mod = compile(
+            mod,
+            inputs,
+            max_batch_size=32,
+            explicit_batch_dimension=True,
+            lower_precision=LowerPrecision.FP16,
+            verbose_log=False,
+            timing_cache_prefix="",
+            save_timing_cache=False,
+            cuda_graph_batch_size=-1,
+            dynamic_batch=False,
+            is_aten=False,
+        )
+        fx_output = fx_mod(*inputs)
+        # Kernel selection is tricky in TRT with big variance as shown below:
+        # Mismatched elements: 30816 / 32000 (96.3%)
+        # Greatest absolute difference: 0.05859375 at index (0, 499) (up to 1e-05 allowed)
+        # Greatest relative difference: 3.293713681986265 at index (0, 142) (up to 0.001 allowed)
+        # so we choose to use cosine similarity
+        cos_val = torch.nn.functional.cosine_similarity(
+            aten_output.flatten(), fx_output.flatten(), dim=0, eps=1e-4
+        )
+        self.assertTrue(cos_val.detach().cpu().numpy() > 0.999)
+
+    def test_resnet18_aten_dynamic(self):
+        mod = torchvision.models.resnet18()
+        mod = mod.cuda().half().eval()
+
+        inputs = [torch.ones(32, 3, 224, 224)]
+        inputs = [i.cuda().half() for i in inputs]
+
+        aten_mod = compile(
+            mod,
+            inputs,
+            max_batch_size=32,
+            explicit_batch_dimension=True,
+            lower_precision=LowerPrecision.FP16,
+            verbose_log=False,
+            timing_cache_prefix="",
+            save_timing_cache=False,
+            cuda_graph_batch_size=-1,
+            dynamic_batch=True,
+            is_aten=True,
+        )
+        aten_output = aten_mod(*inputs)
+        aten_output = aten_output[0]
+        fx_mod = compile(
+            mod,
+            inputs,
+            max_batch_size=32,
+            explicit_batch_dimension=True,
+            lower_precision=LowerPrecision.FP16,
+            verbose_log=False,
+            timing_cache_prefix="",
+            save_timing_cache=False,
+            cuda_graph_batch_size=-1,
+            dynamic_batch=True,
+            is_aten=False,
+        )
+        fx_output = fx_mod(*inputs)
+
+        cos_val = torch.nn.functional.cosine_similarity(
+            aten_output.flatten(), fx_output.flatten(), dim=0, eps=1e-4
+        )
+        self.assertTrue(cos_val.detach().cpu().numpy() > 0.999)
diff --git a/py/torch_tensorrt/fx/tools/common_fx2trt.py b/py/torch_tensorrt/fx/tools/common_fx2trt.py
index bd22e8bb4e..30d6dc96c9 100644
--- a/py/torch_tensorrt/fx/tools/common_fx2trt.py
+++ b/py/torch_tensorrt/fx/tools/common_fx2trt.py
@@ -353,7 +353,6 @@ def generate_graph(
         # Combine with customized passes specific to any model
         if customized_passes:
             passes_list.extend(customized_passes)
-
         fx_module, _ = aten_tracer.trace(mod, original_inputs)
         for passes in passes_list:
             pr: PassResult = passes(fx_module)
@@ -361,7 +360,7 @@
         fx_module(*original_inputs)
 
         fx_module = run_const_fold(fx_module)
-        print(fx_module.graph)
+        _LOGGER.info(f"FX graph= {fx_module.graph}")
         if len(expected_ops):
             self.assert_has_op(fx_module, expected_ops)
@@ -429,10 +428,10 @@ def run_test_with_dynamic_shape(
         rtol=1e-03,
         atol=1e-03,
     ):
-
+        mod.eval()
         inputs = InputTensorSpec.create_inputs_from_specs(input_specs)
+        mod = self.generate_graph(mod, inputs, expected_ops, unexpected_ops, None)
 
-        mod = proxytensor_trace(mod, inputs)
         interp = TRTInterpreter(
             mod,
             input_specs,
diff --git a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py
index bf693114fc..8abad9c509 100644
--- a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py
+++ b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py
@@ -2743,43 +2743,6 @@ def packed_quantized_conv2d_mapper(
     return new_node
 
 
-# @register_acc_op
-# @register_acc_op_mapping(
-#     op_and_target=("call_function", torch.ops._caffe2.RoIAlign),
-#     arg_replacement_tuples=[
-#         ("features", "features"),
-#         ("rois", "rois"),
-#         ("order", "order"),
-#         ("spatial_scale", "spatial_scale"),
-#         ("pooled_h", "pooled_h"),
-#         ("pooled_w", "pooled_w"),
-#         ("sampling_ratio", "sampling_ratio"),
-#         ("aligned", "aligned"),
-#     ],
-# )
-# def roi_align(
-#     *,
-#     features,
-#     rois,
-#     order,
-#     spatial_scale,
-#     pooled_h,
-#     pooled_w,
-#     sampling_ratio,
-#     aligned,
-# ):
-#     return torch.ops._caffe2.RoIAlign(
-#         features=features,
-#         rois=rois,
-#         order=order,
-#         spatial_scale=spatial_scale,
-#         pooled_h=pooled_h,
-#         pooled_w=pooled_w,
-#         sampling_ratio=sampling_ratio,
-#         aligned=aligned,
-#     )
-
-
 @register_custom_acc_mapper_fn(
     op_and_target=("call_function", torch.ops.quantized.add_relu),
     arg_replacement_tuples=[
diff --git a/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py b/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py
index 5d81dec6b0..b35c6958a8 100644
--- a/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py
+++ b/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py
@@ -6,6 +6,21 @@
 import torch
 import torch._dynamo as torchdynamo
 from torch._dynamo.guards import Guard
+from torch.fx.passes.infra.pass_base import PassResult
+
+from torch_tensorrt.fx.passes.lower_basic_pass_aten import (
+    compose_bmm,
+    compose_chunk,
+    compose_getitem_slice,
+    remove_ops,
+    replace_aten_op_with_indices,
+    replace_aten_reshape_alias_with_replace,
+    replace_builtin_ops,
+    replace_inplace_ops,
+    replace_native_layernorm_with_layernorm,
+    replace_transpose_mm_op_with_linear,
+    run_const_fold,
+)
 from typing_extensions import TypeAlias
 
 Value: TypeAlias = Union[
@@ -112,3 +127,34 @@ def dynamo_trace(
 def trace(f, args, *rest):
     graph_module, guards = dynamo_trace(f, args, True, "symbolic")
     return graph_module, guards
+
+
+def opt_trace(f, args, *rest):
+    """
+    Optimized trace with the necessary passes, which re-compose or replace some ops.
+    These passes should be general purpose and functional.
+    """
+    passes_list = [
+        compose_bmm,
+        compose_chunk,
+        compose_getitem_slice,
+        replace_aten_reshape_alias_with_replace,
+        replace_aten_op_with_indices,
+        replace_transpose_mm_op_with_linear,  # after compose_bmm
+        replace_native_layernorm_with_layernorm,
+        remove_ops,
+        replace_builtin_ops,  # after replace_native_layernorm_with_layernorm
+        replace_inplace_ops,  # remove it once functionalization is enabled
+    ]
+
+    fx_module, _ = trace(f, args)
+    print(fx_module.graph)
+    for passes in passes_list:
+        pr: PassResult = passes(fx_module)
+        fx_module = pr.graph_module
+
+    fx_module(*args)
+
+    fx_module = run_const_fold(fx_module)
+    print(fx_module.graph)
+    return fx_module
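
A note on the dynamic branch of the new `aten_ops_reshape` converter above: when any target dimension is only known at runtime, the converter builds a 1-D shape tensor by concatenating one single-element tensor per output dimension and wires it into the shuffle layer's second input. Below is a minimal hand-built sketch of that same layer pattern outside the converter framework; the builder setup and tensor names are illustrative, not part of this diff, and TensorRT 8.5+ is assumed for shape-tensor support.

```python
import numpy as np
import tensorrt as trt

# Illustrative sketch only: reshape (-1, 2, 5) -> (N, 10) with a runtime
# batch dimension, mirroring the dynamic branch of aten_ops_reshape.
logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)
x = network.add_input("x", trt.float32, (-1, 2, 5))

# Runtime shape of x as a 1-D Int32 tensor; slice out the dynamic dim 0.
shape = network.add_shape(x).get_output(0)
batch = network.add_slice(shape, start=[0], shape=[1], stride=[1]).get_output(0)

# Static dims become 1-element constants (what get_trt_tensor does for ints).
ten = network.add_constant((1,), trt.Weights(np.array([10], dtype=np.int32)))

# Concatenate the per-dimension tensors into the target shape tensor...
concat = network.add_concatenation([batch, ten.get_output(0)])
concat.axis = 0

# ...and feed it to the shuffle layer's second input instead of setting
# a static reshape_dims, exactly as the converter's else-branch does.
shuffle = network.add_shuffle(x)
shuffle.set_input(1, concat.get_output(0))
```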
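The new `aten_ops_sym_size` and `aten_ops_sym_numel` converters reduce to a similarly small shape-layer pattern: a shape layer yields the runtime shape as a 1-D tensor, a slice picks out one dimension, and a PROD reduce over axis 0 collapses the shape into the element count. A hedged sketch of that pattern on a hand-built network follows; names are illustrative, and the `axes` bitmask is what `get_axes_for_reduce_op(0, False)` would compute.

```python
import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)
x = network.add_input("x", trt.float32, (-1, 3, 224, 224))

# aten.sym_size(x, 1): slice element 1 out of the runtime shape tensor.
shape = network.add_shape(x)
dim1 = network.add_slice(shape.get_output(0), start=[1], shape=[1], stride=[1])

# aten.sym_numel(x): multiply all shape elements together (PROD over axis 0).
numel = network.add_reduce(
    shape.get_output(0),
    trt.ReduceOperation.PROD,
    axes=1 << 0,  # bitmask selecting axis 0
    keep_dims=True,
)
```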
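Finally, since `lower.py` now routes `trace_func` through `aten_tracer.opt_trace` rather than `proxytensor_trace`, here is a quick usage sketch of the new entry point. The toy module is hypothetical, but the call mirrors how the lowering pipeline invokes it.

```python
import torch
import torch_tensorrt.fx.tracer.dispatch_tracer.aten_tracer as aten_tracer

# Hypothetical toy module; any dynamo-traceable nn.Module should work.
class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x).reshape(-1, x.shape[1])

mod = Toy().eval()
inputs = [torch.randn(8, 4)]

# dynamo symbolic trace plus the re-compose/replace passes listed above;
# the returned module holds an aten-level graph, which is what the
# build_aten2trt_lower_pipeline stage consumes next.
fx_module = aten_tracer.opt_trace(mod, inputs)
print(fx_module.graph)
```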