From fba515cd369552b9dba58c7557b1e963298d5218 Mon Sep 17 00:00:00 2001 From: Wei Wei Date: Wed, 27 Jul 2022 09:36:32 -0700 Subject: [PATCH] Changes done internally at Facebook 6703b98dff0695d91026f057b951dba1355825fa Shreyansh Prajapati Test dynamic shape support for acc_ops.prod c822345d6d673e1653c2208435e34ab400bada3d Jason Park Add support for generic torch ops to be used in training. e5758602a0592d6c2b71d6d66a0398c4dd9b5e20 Shreyansh Prajapati Test dynamic shape support for repeat interleave c13c633f04df162500eed477c0569eb2b81eb070 Shreyansh Prajapati Test dynamic shape support for reduce ops 863476cf43b210922b88585b8f196dd84fbebb56 Shreyansh Prajapati Test dynamic shape support for acc_op.convolution 68dff39793e5c30c20010919a855bb3d984015d7 Ruichao Xiao [fbcode][GPU][DHEN]fuse split squeeze cat as reshape f8b920769507ebd2ff02419b4aece25451298a95 Ruichao Xiao [fbcode][DHEN][GPU] reorder and merge cats whose input is a sublist of another cat 5b6a8d2d6be979983a52ac96225fefb510c3817c Andrew Or [Quant][fx] Rename convert_to_reference to convert_to_reference_fx 996a0e080b8a8bc0b292a7c2ac92f41f6db33a2e Shreyansh Prajapati Test dynamic shape support for acc_op.expand 084631fe74b304fbb9481ca15fd452a3714fb1b8 Shreyansh Prajapati Test dynamic shape support for acc_op.to_dtype b3195e76329ccddbb5c4640cfa884d0e457d2d34 Shreyansh Prajapati Test dynamic shape support for std a5d964e62bdf769cf8c2e67321138b33e1f524a7 Shreyansh Prajapati Test dynamic shape support for acc_op.tile 3d33d45b2fc7f10f25c22946ba474b227e4b6529 Shreyansh Prajapati Test dynamic shape support for squeeze 09085abf63d7e7732e2cd66e600e8afc6d58964f Shreyansh Prajapati Test dynamic shape support for acc_op.topk 65edc7ea12899e9bd2af42c890a64de853d9b7fe Huamin Li temporarily skip gelu tests d11e521f9b90554ca86912a49920afa4406bb40d Shirong Wu Suppress accuracy check for remove_reshape_with_batch_size_change 6d948298b2327d229e010a34f1c221b11d2eb504 Ankur Singla [GPULowering] Suppress accuracy check for fuse_unsqueeze_cat_sum e780b647fc9571b77d9f41c963041a6ac3d66f33 Janet Yang Lower xrayvideo2022 to fx2trt 433c7207fef16b1fdff985546ea969c39fa83e7c generatedunixname89002005287564 [Codemod][Remove @noautodeps and @autodeps-skip tags] deeplearning/trt 1/2 66fdb65cffa925660c77b4758388399db3cbfe48 Scott Wolchok [fx2ait] Minor Python cleanup in acc_ops_getitem 188132ecb2c19bcbf83cb2dc381f6e3798629f87 generatedunixname89002005324833 [AutoAccept][Codemod][FBSourceBuckFormatLinter] Daily `arc lint --take BUCKFORMAT` 4536bae4686dd01f2149541ea7fb330e178a4969 Wei Wei [fx2trt] support sub 064602e666f86c110d931cd90a8536112a19b4ad Shreyansh Prajapati Test dynamic shape support for acc_ops.interpolate 9dfd0ee0cecb1975e3f53c44de237d67ca443ec5 Shreyansh Prajapati Test dynamic shape support for unary_ops 39b9efad8d5d82463a2016d135c0cf277de1c3c6 Shreyansh Prajapati Test dynamic shape support for unsqueeze 2bb17667d1dabc95391950426fc1f921eb3d0959 Shreyansh Prajapati Test dynamic shape support for acc_ops.split 64dfb7b096686cb2fd33197340dc72f30d525456 Shirong Wu Group LN trt plugin 438f670e28df59b0734baa092a514fba3d75eb4f Shreyansh Prajapati Test dynamic shape support for acc_ops.avgpool df0fe32dae4343827bd9b37b72daae761b02f228 Shreyansh Prajapati Test dynamic shape support for acc_ops masked fill 44fe735d3493ea2d05a56b49093e4a23dd63a98e Shreyansh Prajapati Test dynamic shaope support for acc_ops.pad 4f931acca706d8ce79045ceafef2ea0486609149 Wei Wei [fx2trt] torch.max dynamic shape test bf6f6cbe217d26a95ca9122574adf7de3966db9e Shreyansh Prajapati Change the name 
of the test from full_reduce to dim_reduce 1c5680ed107d9206f3514eff4069a3f6c870ba8c Shreyansh Prajapati Test dynamic shape support for acc_ops.type_as 33e4c175a4f5fec78ac0b1c8eb262ca777c7aaba Shreyansh Prajapati Test dynamic shape support for acc_ops.min f37be34bcef9716080b8bafbd1f4ad72e412c44c Wei Wei [fx2trt] plugin for grid_sample 57b5cc6a0f4839686ae360361a3a13b424794ee7 generatedunixname89002005367269 [AutoAccept][Codemod][FBSourceBlackLinter] Daily `arc lint --take BLACK` eb741cc5e5a7babdc94e72d411670905f54da3e0 Shreyansh Prajapati Updated the dynamic shape support for narrow op 521c36b96a14741ae89d7af6cbb658120bcec2ea Shreyansh Prajapati Removing the comment for 4 dims dynamic shape support after analysis e947343375967fe9efb0a16fdb9f63bff1449328 Shreyansh Prajapati Updated the pad test for dynamic batch for analysis 3d64087014e91bc301a315eae43683b1aa2b66bc Oleg Khabinov [trt_bc] Some improvements dfd937a56fa01aca88a89b46176befdac4c202c4 Shreyansh Prajapati Updated the test for as_strided op for analysis 11d76d0420dcaa4bb8890dcdeb86b6e534af831c Bangsheng Tang [gpu][infer] replace fx2trt_layer_norm with fbgemm layer_norm 932046ff6ea6dead114c0222b23ca3854690cffa Wei Wei [fx2trt] bridge the dynamic batch and fixed shape f911463393d8a671cfee6de6d1b5ef4d4f3991a6 Shirong Wu group swish LN plugin ea65970f23dd7a468e5bc43240f2a9bfa07c9b3b Shirong Wu Create backend specific lower pass 38183e4a724e5514db2be7193cf4897b59759252 Alex Beloi [fx] run acc_linter.lint in acc_tracer.trace 088abb6a790a62ca9f8515298a54117cc7fa31d4 Alex Beloi [fx] re-add pointwise property to acc_ops.clamp 9905c34f2bd28e9b64f10336f9ac326cc39eb60d Oleg Khabinov [trt] Comment out torch.ops.fbgemm dependency in TRT converters 8252e779476d2ff22ad78185af97a526b2f70fe3 Alex Beloi [fx] add operator test suite to test_acc_tracer.py 7b93a89c903bc0b6c59efb73a510c3dce8ef793a Shirong Wu Add option for lower and trt_splitter e08dabcbcd8c3e8ae92484e14cf07bb26993a8d6 Wei Wei [fx2trt] convert print to logging 3d61dc169b8a7dd1aecad35891a628e44e2c5a02 Shreyansh Prajapati Readme.md file for dynamic shape support --- py/torch_tensorrt/fx/Dynamic_Shape_Support.md | 137 ++++++++++++++++++ .../fx/converters/acc_ops_converters.py | 10 +- py/torch_tensorrt/fx/lower.py | 3 +- .../fx/passes/lower_pass_manager_builder.py | 13 +- py/torch_tensorrt/fx/passes/pass_utils.py | 2 +- .../fx/test/passes/test_graph_opts.py | 12 +- .../fx/test/tracer/test_acc_tracer.py | 8 +- .../fx/test/trt_lower/test_diagnostics.py | 10 +- .../fx/test/trt_lower/test_observer.py | 17 ++- .../fx/test/trt_lower/trt_splitter_test.py | 43 +++++- py/torch_tensorrt/fx/tools/common_fx2trt.py | 9 +- .../fx/tools/engine_layer_visualize.py | 9 +- py/torch_tensorrt/fx/tools/trt_minimizer.py | 5 +- .../fx/tools/trt_profiler_sorted.py | 5 +- py/torch_tensorrt/fx/tools/trt_splitter.py | 9 +- .../fx/tracer/acc_tracer/acc_normalizer.py | 6 +- .../fx/tracer/acc_tracer/acc_ops.py | 8 + .../fx/tracer/acc_tracer/acc_tracer.py | 4 +- .../fx/tracer/acc_tracer/acc_utils.py | 7 +- 19 files changed, 276 insertions(+), 41 deletions(-) create mode 100644 py/torch_tensorrt/fx/Dynamic_Shape_Support.md diff --git a/py/torch_tensorrt/fx/Dynamic_Shape_Support.md b/py/torch_tensorrt/fx/Dynamic_Shape_Support.md new file mode 100644 index 0000000000..eb4454340e --- /dev/null +++ b/py/torch_tensorrt/fx/Dynamic_Shape_Support.md @@ -0,0 +1,137 @@ +# PyTorch Operations Dynamic Shape Support Summary + + + + | Operation | Test Method | Supports Dynamic Shape | Shape | Num of dimensions | Reason | +| --- | --- | --- 
| --- | --- | --- | +| adaptive_avgpool | | partially | (-1, -1, 256, 256) | 2 | AdaptiveAvgPool2d and AdaptiveAvgPool3d currently don't support dynamic shapes for the last two dims. | +| any | | no | | | torch.zeros(tuple(\[*input_t.shape\])). Trying to create tensor with negative dimension -1: \[-1, -1, -1, -1\] | +| as_strided | | no | | | RuntimeError: setStorage: sizes \[2, 3\], strides \[1, 2\], storage offset 0, and itemsize 8 requiring a storage size of 48 are out of bounds for storage of size 16 | +| avg_pool | avg_pool2d | yes | (-1, -1, -1, -1) | 4 | | +| | avg_pool1d | partially | (-1, 3, 3) | 1 | | +| batchnorm | | partially | (-1, 3, -1, -1) | 3 | "Channel dim can't be dynamic for batch norm." | +| binary_ops | | yes | (-1, -1, -1, -1) | 4 | | +| cat | | yes | (-1, -1, -1, -1) | 4 | | +| chunk | | partially | (-1, 1, 3, -1) | any (not chunk dim) | AssertionError: Can't chunk on dynamic shape dimension! | +| clamp | | yes | (-1, -1, -1, -1) | | | +| convolution | conv2d | partially | (-1, 3, -1, -1) | 3 | AssertionError: Channel dim can't be dynamic for convolution. | +| | conv1d | partially | (-1, 3, 3) | 1 | | +| | conv3d | partially | (-1, 3, -1, -1, -1) | 4 | AssertionError: Channel dim can't be dynamic for convolution. | +| dequantize | | yes | (-1, -1, -1, -1) | 4 | | +| einsum | | yes | (-1, -1, -1, -1) | 4 | | +| elu | | yes | (-1, -1, -1, -1) | 4 | | +| embedding | | yes | (-1, -1, -1, -1) | 4 | | +| eq | SimpleConverter | yes | (-1, -1, -1, -1) | 4 | | +| | ConstInputConverter | yes | (-1, -1, -1, -1) | 4 | | +| | EqMethodConverter | no | limitation in converter | | RuntimeError: Trying to create tensor with negative dimension -1: \[-1, -1, -1, -1\] | +| | EqOperatorConverter | no | limitation in converter | | RuntimeError: Trying to create tensor with negative dimension -1: \[-1, -1, -1, -1\] | +| | EqOperatorConstant | partially | (3,-1) | 1 | | +| | EqConverter | no | limitation in converter | | RuntimeError: Trying to create tensor with negative dimension -1: \[-1, -1, -1, -1\] | +| expand | | no | | | Dynamic shape is not suitable for the expand operation. | +| flatten | | yes | (-1, -1, -1, -1, -1) | 5 | | +| gelu | | yes | (-1, -1, -1, -1) | 4 | | +| getitem | | yes | (-1, -1, -1, -1) | 4 | | +| gt | EqOperatorSimpleConverter | yes | (-1, -1, -1, -1) | 4 | | +| | ConstInputConverter | yes | (-1, -1, -1, -1) | 4 | | +| | GtConverter | no | limitation in converter | | RuntimeError: Trying to create tensor with negative dimension -1: \[-1, -1, -1, -1\] | +| | GtMethodConverter | no | limitation in converter | | RuntimeError: Trying to create tensor with negative dimension -1: \[-1, -1, -1, -1\] | +| | GtOperator | no | limitation in converter | | RuntimeError: Trying to create tensor with negative dimension -1: \[-1, -1, -1, -1\] | +| | EqOperator | no | limitation in converter | | RuntimeError: Trying to create tensor with negative dimension -1: \[-1, -1, -1, -1\] | +| hardsigmoid | | yes | (-1, -1, -1, -1) | 4 | | +| hardtanh | | yes | (-1, -1, -1, -1) | 4 | | +| interpolate | | yes | (-1, -1, -1, -1) | 4 | | +| isinf | | yes | (-1, -1, -1, -1) | 4 | | +| leaky_relu | | yes | (-1, -1, -1, -1) | 4 | | +| linear | | partially | (-1, 3, 5) | 1 | AssertionError: Currently we only support one dynamic dim for linear and it can't be the last dim.
| logical_and | | yes | (-1, -1, -1, -1) | 4 | | +| logical_or | | yes | (-1, -1, -1, -1) | 4 | | +| logical_xor | | yes | (-1, -1, -1, -1) | 4 | | +| lt | | yes | (-1, -1, -1, -1) | 4 | | +| masked_fill | | no | limitation in converter | | RuntimeError: Trying to create tensor with negative dimension -1: \[-1, -1, -1, -1\] | +| mat_mul | | yes | batch dim | | | +| max | MaxFullReduce | yes | (-1, -1, -1, -1) | 4 | | +| | MaxDimReduce | yes | (-1, -1, -1, -1) | 4 | | +| | MaxMethod | yes | (-1, -1, -1, -1) | 4 | | +| maximum | | yes | (-1, -1, -1, -1) | 4 | | +| maxpool | max_pool1d | partially | (1, 1, -1) | 1 | shape is not set to (-1, -1, -1) because a reshape dimension with more than one -1 wildcard is not allowed when adding the unsqueeze layer | +| | max_pool2d | yes | (-1, -1, -1, -1) | 4 | | +| | max_pool3d | yes | (-1, -1, -1, -1, -1) | 5 | | +| min | MinFullReduce | yes | (-1, -1, -1, -1) | 4 | | +| | MinDimReduce | yes | (-1, -1, -1, -1) | 4 | | +| | MinMethod | yes | (-1, -1, -1, -1) | 4 | | +| minimum | | yes | (-1, -1, -1, -1) | 4 | | +| narrow | | partially | (-1, 3, -1, -1) | 3 | AssertionError: Can't chunk on dynamic shape dimension! | +| ne | NeFunctionConverter | yes | (-1, -1, -1, -1) | 4 | | +| | NeMethodConverter | yes | (-1, -1, -1, -1) | 4 | | +| | NeOperatorConverter | yes | (-1, -1, -1, -1) | 4 | | +| | ConstInputConverter | yes | (-1, -1, -1, -1) | 4 | | +| | NeOperatorConstantConverter | partially | (3, -1) | 1 | | +| new_ones | | yes | (-1, -1, -1, -1) | 4 | | +| numel | | no | limitation in converter | | RuntimeError: numel does not support dynamic shapes. | +| pad | | no | limitation in converter | | test\_pad\_with\_dynamic\_shape\_four\_dimensions\_0\_2d (deeplearning.trt.torch\_tensorrt.py.torch\_tensorrt.fx.test.converters.acc\_op.test\_pad.TestPadConverter) ... \[07/15/2022-09:23:18\] \[TRT\] \[E\] 2: \[intInterval.cpp::max::26\] Error Code 2: Internal Error (Assertion !empty() failed. | +| permute | | yes | (-1, -1, -1, -1) | 4 | | +| prod | | yes | (-1, -1, -1, -1) | 4 | | +| quantize\_per\_tensor | | yes | (-1, -1, -1, -1) | 4 | | +| reduce op | | yes | (-1, -1, -1, -1) | 4 | | +| relu | | yes | (-1, -1, -1, -1) | 4 | | +| repeat interleave | | partially | (-1, 3, 2) | 1 | AssertionError: Currently we don't support unsqueeze with more than one dynamic dims. | +| reshape | | yes | (-1, -1, -1, -1) | 4 | | +| selu | | yes | (-1, -1, -1, -1) | 4 | | +| sigmoid | | yes | (-1, -1, -1, -1) | 4 | | +| silu | | yes | (-1, -1, -1, -1) | 4 | | +| size | | yes | (-1, -1, -1, -1) | 4 | | +| softmax | | yes | (-1, -1, -1, -1) | 4 | | +| softsign | | yes | (-1, -1, -1, -1) | 4 | | +| split | | partially | (-1, 10, -1) | 2 | AssertionError: Can't chunk on dynamic shape dimension! | +| squeeze | | partially | (1, -1, 2) | 1 | AssertionError: Currently more than one dynamic dim for input to squeeze is not supported.
| +| std | | yes | (-1, -1, -1, -1) | 4 | | +| tanh | | yes | (-1, -1, -1, -1) | 4 | | +| tile | | yes | (-1, -1, -1, -1) | 4 | | +| to_dtype | int | yes | (-1, -1, -1, -1) | 4 | | +| | float | yes | (-1, -1, -1, -1) | 4 | | +| topk | | yes | (-1, -1, -1, -1) | 4 | | +| transpose_convolution | conv_transpose2d | partially | (-1, 3, -1, -1) | 3 | | +| | conv_transpose3d | partially | (-1, 3, -1, -1, -1) | 4 | | +| type_as | | yes | (-1, -1, -1, -1) | 4 | RuntimeError: ShapeProp error for: node=%type\_1 : \[#users=1\] = call\_method\[target=type\](args = (%input_1,), kwargs = {dtype: torch.float32}) with meta={} | +| unary ops | | yes | (-1, -1, -1, -1) | 4 | | +| unsqueeze | | partially | (-1, 2, 3) | 1 | AssertionError: Currently we don't support unsqueeze with more than one dynamic dims. | +| where | | no | limitation in converter | | torch.broadcast_shape can not handle -1 dimension in shape \[-1, 2, 2\] | + + + +Binary Ops Include following operations: +|Binary Ops | +|----------| +|add | +|sub | +|div | +|mul | +|floor_div | +|fmod | +|floor_divide| +|pow | + + +Unary Ops Include following operations: +|Unary Ops | +|----------| +|rsqrt | +|sin | +|cos | +|tan | +|sinh | +|cosh | +|asin | +|acos | +|atan | +|abs | +|neg | +|reciprocal| +|sqrt | +|log | +|exp | +|floor | +|ceil | +|sign | + +Note: For more information about the test method, please refer to the operation test files. Additionally, test files include information about errors encountered during dynamic shape testing. diff --git a/py/torch_tensorrt/fx/converters/acc_ops_converters.py b/py/torch_tensorrt/fx/converters/acc_ops_converters.py index 95436e762e..68334ebe44 100644 --- a/py/torch_tensorrt/fx/converters/acc_ops_converters.py +++ b/py/torch_tensorrt/fx/converters/acc_ops_converters.py @@ -1,4 +1,5 @@ # flake8: noqa +import logging import math import operator import warnings @@ -22,6 +23,9 @@ from .converter_utils import * # noqa: F403 +_LOGGER: logging.Logger = logging.getLogger(__name__) + + @tensorrt_converter(acc_ops.conv1d) def acc_ops_conv1d( network: TRTNetwork, @@ -641,7 +645,7 @@ def acc_ops_layer_norm(network, target, args, kwargs, name): try: normalized_shape = np.array(kwargs["normalized_shape"], dtype=np.int32) except TypeError: - print("Unable to convert normalized_shape to a field, fall back to []") + _LOGGER.error("Unable to convert normalized_shape to a field, fall back to []") normalized_shape = np.array([], dtype=np.int32) normalized_shape_filed = trt.PluginField( @@ -657,7 +661,9 @@ def acc_ops_layer_norm(network, target, args, kwargs, name): else: plugin = get_trt_plugin("LayerNormDynamic", field_collection, "1", "fx2trt") except AssertionError: - print("Unable to find layer norm plugin, fall back to TensorRT implementation.") + _LOGGER.error( + "Unable to find layer norm plugin, fall back to TensorRT implementation." 
+ ) return layer_norm(network, target, args, kwargs, name) layer = network.add_plugin_v2([input_val], plugin) layer.name = name diff --git a/py/torch_tensorrt/fx/lower.py b/py/torch_tensorrt/fx/lower.py index 6d052fc34e..470f78c407 100644 --- a/py/torch_tensorrt/fx/lower.py +++ b/py/torch_tensorrt/fx/lower.py @@ -197,6 +197,7 @@ def create( cls, lower_setting: LowerSetting, interpreter_builder: Callable = create_lower_trt_interpreter, + split_func: Callable = default_split_function, ) -> "Lowerer": """Instantiate a `Lowerer` instance.""" @@ -209,7 +210,7 @@ def create( ast_rewriter_allow_list=lower_setting.ast_rewriter_allow_list, leaf_module_list=lower_setting.leaf_module_list, ), - split_func=default_split_function, + split_func=split_func, lower_func=default_lower_pass(interpreter_builder), ) ) diff --git a/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py b/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py index 98c6314f18..937737b60d 100644 --- a/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py +++ b/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py @@ -1,4 +1,5 @@ import datetime +import logging from functools import partial, wraps from typing import Any, Callable, Optional, Sequence @@ -17,6 +18,10 @@ from .lower_basic_pass import run_const_fold + +_LOGGER: logging.Logger = logging.getLogger(__name__) + + Input = Sequence[Any] @@ -143,7 +148,7 @@ def lower_func(split_result: SplitResult) -> nn.Module: # Only acc submodules will be lowered. if not submod_name.startswith(split_result.non_acc_submodule_prefix): - print("Now lowering submodule", submod_name) + _LOGGER.info("Now lowering submodule", submod_name) lowering_start_time = datetime.datetime.now() self.lower_setting.input_specs = generate_input_specs( @@ -160,7 +165,7 @@ def lower_func(split_result: SplitResult) -> nn.Module: LOWER_SPLIT_POST_OBSERVER.observe( submod_name, lowered_module, submod_inputs ) - print( + _LOGGER.info( f"Lowering submodule {submod_name} elapsed time", datetime.datetime.now() - lowering_start_time, ) @@ -179,7 +184,7 @@ def lower_func(split_result: SplitResult) -> nn.Module: # Only acc submodules will be lowered. 
if not submod_name.startswith(split_result.non_acc_submodule_prefix): - print("Now lowering submodule", submod_name) + _LOGGER.info("Now lowering submodule", submod_name) lowering_start_time = datetime.datetime.now() lowered_module = self._lower_func( @@ -189,7 +194,7 @@ def lower_func(split_result: SplitResult) -> nn.Module: LOWER_SPLIT_POST_OBSERVER.observe( submod_name, lowered_module, submod_inputs ) - print( + _LOGGER.info( f"Lowering submodule {submod_name} elapsed time", datetime.datetime.now() - lowering_start_time, ) diff --git a/py/torch_tensorrt/fx/passes/pass_utils.py b/py/torch_tensorrt/fx/passes/pass_utils.py index b075b744bc..d430a67408 100644 --- a/py/torch_tensorrt/fx/passes/pass_utils.py +++ b/py/torch_tensorrt/fx/passes/pass_utils.py @@ -116,7 +116,7 @@ def pass_with_before_after_log( encoding="utf-8", delete=False, ) as f: - print(f"== Log pass {pass_} before/after graph to {f.name}") + _LOGGER.info(f"== Log pass {pass_} before/after graph to {f.name}") print(f"[{pass_}] Before:\n{module.graph}", file=f) module = pass_(module, input) print(f"[{pass_}] After:\n{module.graph}", file=f) diff --git a/py/torch_tensorrt/fx/test/passes/test_graph_opts.py b/py/torch_tensorrt/fx/test/passes/test_graph_opts.py index 1f852b0497..c91c456eb3 100644 --- a/py/torch_tensorrt/fx/test/passes/test_graph_opts.py +++ b/py/torch_tensorrt/fx/test/passes/test_graph_opts.py @@ -1,3 +1,4 @@ +import logging import unittest from collections import Counter from typing import Callable, Dict, List @@ -8,13 +9,16 @@ from torch_tensorrt.fx.passes.graph_opts import common_subexpression_elimination +_LOGGER: logging.Logger = logging.getLogger(__name__) + + def debug_print_graph_module(mod_graph: torch.fx.GraphModule) -> None: """ Helper func to print model's graph in plain and tabular format, also print code. 
""" - print(mod_graph.graph) + _LOGGER.info(mod_graph.graph) mod_graph.graph.print_tabular() - print(mod_graph.code) + _LOGGER.info(mod_graph.code) @torch.fx.wrap @@ -46,7 +50,7 @@ def _test_opt_with_module( before_results = module(*inputs) mod_traced = acc_tracer.trace(module, inputs) before_node_list = list(mod_traced.graph.nodes) - print("Model before opt.") + _LOGGER.info("Model before opt.") debug_print_graph_module(mod_traced) # Apply Opt @@ -55,7 +59,7 @@ def _test_opt_with_module( # After Opt after_results = mod_traced(*inputs) after_node_list = list(mod_traced.graph.nodes) - print("Model after opt.") + _LOGGER.info("Model after opt.") mod_traced.recompile() debug_print_graph_module(mod_traced) diff --git a/py/torch_tensorrt/fx/test/tracer/test_acc_tracer.py b/py/torch_tensorrt/fx/test/tracer/test_acc_tracer.py index ab7f932acf..3abba43ccb 100644 --- a/py/torch_tensorrt/fx/test/tracer/test_acc_tracer.py +++ b/py/torch_tensorrt/fx/test/tracer/test_acc_tracer.py @@ -1,5 +1,5 @@ # Owner(s): ["oncall: fx"] - +import logging import unittest from typing import Callable, List @@ -16,6 +16,8 @@ torch.manual_seed(0) +_LOGGER: logging.Logger = logging.getLogger(__name__) + class AccTracerTest(unittest.TestCase): def _make_model_unit_test( @@ -258,7 +260,7 @@ def forward(self, a: torch.Tensor) -> torch.Tensor: torch.randn(1, 3, 1, 1), scale=0.01, zero_point=3, dtype=torch.quint8 ) traced = acc_tracer.trace(m, [input]) - print(traced.graph) + _LOGGER.info(traced.graph) ph = weight_attr = bias_attr = conv = None for node in traced.graph.nodes: if node.op == "placeholder": @@ -626,7 +628,7 @@ def run_embedding_bag_test(is_4bit, use_weights): ) traced = acc_tracer.trace(m, inputs) - print(traced.graph) + _LOGGER.info(traced.graph) expected_target = ( acc_ops.embedding_bag_4bit_rowwise_offsets diff --git a/py/torch_tensorrt/fx/test/trt_lower/test_diagnostics.py b/py/torch_tensorrt/fx/test/trt_lower/test_diagnostics.py index 467f3ca9af..3ce3b7ade8 100644 --- a/py/torch_tensorrt/fx/test/trt_lower/test_diagnostics.py +++ b/py/torch_tensorrt/fx/test/trt_lower/test_diagnostics.py @@ -1,6 +1,7 @@ # Owner(s): ["oncall: gpu_enablement"] import functools import glob +import logging import os import shutil import tempfile @@ -10,6 +11,9 @@ import torch_tensorrt.fx.diagnostics as diag +_LOGGER: logging.Logger = logging.getLogger(__name__) + + def reset_diag(fn): @functools.wraps(fn) def reset(*a, **kw): @@ -53,7 +57,7 @@ def boom() -> str: zip_fn = collector._last_zip_path_for_test assert os.path.exists(zip_fn) with tempfile.TemporaryDirectory() as tempdir: - print(f"Unpacking into {tempdir}") + _LOGGER.info(f"Unpacking into {tempdir}") shutil.unpack_archive(zip_fn, tempdir) _check_file(tempdir, "aaa", "hello") _check_file(tempdir, "bbb", "world") @@ -78,7 +82,7 @@ def test_condition_func_name(self): zip_fn = collector._last_zip_path_for_test assert os.path.exists(zip_fn) with tempfile.TemporaryDirectory() as tempdir: - print(f"Unpacking into {tempdir}") + _LOGGER.info(f"Unpacking into {tempdir}") shutil.unpack_archive(zip_fn, tempdir) _check_file(tempdir, "aaa", "hello") @@ -160,7 +164,7 @@ def _test_cond( if should_collect: assert os.path.exists(zip_fn) with tempfile.TemporaryDirectory() as tempdir: - print(f"Unpacking into {tempdir}") + _LOGGER.info(f"Unpacking into {tempdir}") shutil.unpack_archive(zip_fn, tempdir) _check_file(tempdir, "aaa", "hello") else: diff --git a/py/torch_tensorrt/fx/test/trt_lower/test_observer.py b/py/torch_tensorrt/fx/test/trt_lower/test_observer.py index 8a621c476a..58f23c0a13 
100644 --- a/py/torch_tensorrt/fx/test/trt_lower/test_observer.py +++ b/py/torch_tensorrt/fx/test/trt_lower/test_observer.py @@ -1,5 +1,6 @@ # Owner(s): ["oncall: gpu_enablement"] import functools +import logging import typing as t from contextlib import contextmanager from unittest import TestCase @@ -7,6 +8,8 @@ import torch_tensorrt.fx.observer as ob from torch_tensorrt.fx.observer import observable +_LOGGER: logging.Logger = logging.getLogger(__name__) + def set_observer_callback_rethrow(fn): """ @@ -36,7 +39,7 @@ def foo(x, y, z): @verify_execution def log_pre(ctx: ob.ObserveContext) -> None: - print(f"calling log: {ctx}") + _LOGGER.info(f"calling log: {ctx}") assert ctx.callable is foo.orig_func assert ctx.args == (1, 2) assert ctx.kwargs == {"z": 3} @@ -44,7 +47,7 @@ def log_pre(ctx: ob.ObserveContext) -> None: @verify_execution def log_post(ctx: ob.ObserveContext) -> None: - print(f"calling log: {ctx}") + _LOGGER.info(f"calling log: {ctx}") assert ctx.callable is foo.orig_func assert ctx.args == (1, 2) assert ctx.kwargs == {"z": 3} @@ -57,11 +60,11 @@ def log_post(ctx: ob.ObserveContext) -> None: @verify_execution def log_pre(ctx: ob.ObserveContext) -> None: - print(f"calling log: {ctx}") + _LOGGER.info(f"calling log: {ctx}") @verify_execution def log_post(ctx: ob.ObserveContext) -> None: - print(f"calling log: {ctx}") + _LOGGER.info(f"calling log: {ctx}") foo.observers.pre.add(log_pre) foo.observers.post.add(log_post) @@ -71,11 +74,11 @@ def log_post(ctx: ob.ObserveContext) -> None: @verify_execution def f1(ctx: ob.ObserveContext) -> None: - print(f"calling f1: {ctx}") + _LOGGER.info(f"calling f1: {ctx}") @verify_execution def f2(ctx: ob.ObserveContext) -> None: - print(f"calling f2: {ctx}") + _LOGGER.info(f"calling f2: {ctx}") # Test that we can register the same observation point twice with foo.observers.pre.add(f1): @@ -91,7 +94,7 @@ def foo(x, y, z): @verify_execution def log_pre(ctx: ob.ObserveContext) -> None: - print(f"calling log: {ctx}") + _LOGGER.info(f"calling log: {ctx}") raise CallbackError("TEST CALLBACK EXCEPTION") with foo.observers.pre.add(log_pre): diff --git a/py/torch_tensorrt/fx/test/trt_lower/trt_splitter_test.py b/py/torch_tensorrt/fx/test/trt_lower/trt_splitter_test.py index 76da8c2430..483a45d639 100644 --- a/py/torch_tensorrt/fx/test/trt_lower/trt_splitter_test.py +++ b/py/torch_tensorrt/fx/test/trt_lower/trt_splitter_test.py @@ -10,7 +10,7 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from torch.fx.passes import splitter_base from torch.testing._internal.common_utils import run_tests, TestCase -from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter +from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter, TRTSplitterSetting from torch_tensorrt.fx.tracer.acc_tracer import acc_tracer @@ -1126,6 +1126,47 @@ def test_splitter(splitter): test_splitter(splitter) + def test_exclude_support_node_by_name(self): + class TestModule(torch.nn.Module): + def forward(self, a): + b = torch.sin(a) + c = torch.relu(b) + d = torch.cos(c) + e = torch.sigmoid(d) + f = torch.tanh(e) + return f + + mod = acc_tracer.trace(TestModule(), [torch.randn(2, 3)]) + + # Set sin, cos and tanh as acc node and split with settings + class CustomOpSupport(op_support.OperatorSupport): + _support_dict = { + "acc_ops.sin": None, + "acc_ops.cos": None, + "acc_ops.relu": None, + "acc_ops.sigmoid": None, + "acc_ops.tanh": None, + } + + # For unsupport relu node, this would cut graph into acc_0, gpu_1 and acc_2 + # as three sub graphs. 
+ settings = TRTSplitterSetting() + settings.exclude_support_node_name = {"relu"} + splitter = TRTSplitter( + mod, + (torch.randn(2, 3),), + op_support_with_support_dict( + { + "acc_ops.sin": None, + "acc_ops.cos": None, + "acc_ops.relu": None, + } + ), + settings, + ) + res = splitter.generate_split_results() + self.assertTrue(len(res), 3) + def op_support_with_support_dict(support_dict: dict) -> op_support.OperatorSupportBase: return op_support.OperatorSupport(support_dict) diff --git a/py/torch_tensorrt/fx/tools/common_fx2trt.py b/py/torch_tensorrt/fx/tools/common_fx2trt.py index d6c635b402..a2ef83b57c 100644 --- a/py/torch_tensorrt/fx/tools/common_fx2trt.py +++ b/py/torch_tensorrt/fx/tools/common_fx2trt.py @@ -1,3 +1,4 @@ +import logging import time import unittest from typing import Callable, List, Tuple @@ -13,6 +14,8 @@ from torch_tensorrt.fx.passes.pass_utils import chain_passes from torch_tensorrt.fx.utils import LowerPrecision +_LOGGER: logging.Logger = logging.getLogger(__name__) + def fetch_attr(mod, target): """ @@ -65,7 +68,7 @@ def run_test( start = time.perf_counter() interpreter_result = interpreter.run(lower_precision=precision) sec = time.perf_counter() - start - print("Interpreter run time(s):", sec) + _LOGGER.info(f"Interpreter run time(s): {sec}") trt_mod = TRTModule( interpreter_result.engine, interpreter_result.input_names, @@ -81,7 +84,9 @@ def run_test( outputs = trt_mod(*cuda_inputs) end_event.record() torch.cuda.synchronize() - print("TRT run time(s)=", (start_event.elapsed_time(end_event) * 1.0e-3)) + _LOGGER.info( + f"TRT run time(s)= {(start_event.elapsed_time(end_event) * 1.0e-3)}" + ) if isinstance(outputs, torch.Tensor): ref_outputs = [ref_outputs] diff --git a/py/torch_tensorrt/fx/tools/engine_layer_visualize.py b/py/torch_tensorrt/fx/tools/engine_layer_visualize.py index 3008c8e087..7ca6702fb2 100644 --- a/py/torch_tensorrt/fx/tools/engine_layer_visualize.py +++ b/py/torch_tensorrt/fx/tools/engine_layer_visualize.py @@ -1,10 +1,11 @@ import argparse +import logging import re from typing import Any, Dict, List, NamedTuple, Optional, Tuple import pydot - +_LOGGER: logging.Logger = logging.getLogger(__name__) """ log_file is generated by tensorrt verbose logger during building engine. profile_file is generated by tensorrt profiler. @@ -106,7 +107,7 @@ def build_edge(layer, graph, reformat_layers, output_name2node, layer_name2node) ) ) else: - print(f"Missing node {input_name}") + _LOGGER.info(f"Missing node {input_name}") from_node = input_name else: from_node = output_name2node[input_name] @@ -213,5 +214,5 @@ def build_edge(layer, graph, reformat_layers, output_name2node, layer_name2node) est_reformat_time += float(layer.time[:-2]) est_total_time += float(layer.time[:-2]) - print(f"Time Cost on Reformatting: {est_reformat_time} ms") - print(f"Total Time Cost: {est_total_time} ms") + _LOGGER.info(f"Time Cost on Reformatting: {est_reformat_time} ms") + _LOGGER.info(f"Total Time Cost: {est_total_time} ms") diff --git a/py/torch_tensorrt/fx/tools/trt_minimizer.py b/py/torch_tensorrt/fx/tools/trt_minimizer.py index 23039bea34..308687e0c9 100644 --- a/py/torch_tensorrt/fx/tools/trt_minimizer.py +++ b/py/torch_tensorrt/fx/tools/trt_minimizer.py @@ -1,3 +1,4 @@ +import logging from typing import Any, Callable, Tuple import torch @@ -6,6 +7,8 @@ from .. 
import InputTensorSpec, TRTInterpreter, TRTModule +_LOGGER: logging.Logger = logging.getLogger(__name__) + def lower_mod_default( mod: torch.fx.GraphModule, inputs: Tensors, batch_size: Any = 2048 @@ -64,5 +67,5 @@ def run_b(self, mod, inputs): def get_nodes(self, start=None, end=None, enable_print=False): nodes = self._collect_nodes(start, end) if enable_print: - print(f"Nodes fetched from start {start} to end {end} as: {nodes}") + _LOGGER.info(f"Nodes fetched from start {start} to end {end} as: {nodes}") return nodes diff --git a/py/torch_tensorrt/fx/tools/trt_profiler_sorted.py b/py/torch_tensorrt/fx/tools/trt_profiler_sorted.py index 2d5ce0b419..59d2f49042 100644 --- a/py/torch_tensorrt/fx/tools/trt_profiler_sorted.py +++ b/py/torch_tensorrt/fx/tools/trt_profiler_sorted.py @@ -1,4 +1,5 @@ import json +import logging import operator from typing import List, Mapping, Optional @@ -7,6 +8,8 @@ from .. import TRTModule +_LOGGER: logging.Logger = logging.getLogger(__name__) + class SortedTRTProfiler(trt.IProfiler): def __init__(self): @@ -22,7 +25,7 @@ def print_sorted_profile( additional_info = {} if additional_info is None else additional_info for k, v in sorted(self.layers.items(), key=operator.itemgetter(1)): additional_str = additional_info.get(k, "") - print(f"{k} {additional_str}: {v}ms") + _LOGGER.info(f"{k} {additional_str}: {v}ms") def profile_trt_module( diff --git a/py/torch_tensorrt/fx/tools/trt_splitter.py b/py/torch_tensorrt/fx/tools/trt_splitter.py index 32dc4a1853..7fbca8d99a 100644 --- a/py/torch_tensorrt/fx/tools/trt_splitter.py +++ b/py/torch_tensorrt/fx/tools/trt_splitter.py @@ -16,7 +16,10 @@ from ..tools.trt_minimizer import TensorRTMinimizer -def create_trt_operator_support(use_implicit_batch_dim=True) -> ops.OperatorSupportBase: +def create_trt_operator_support( + use_implicit_batch_dim=True, + exclude_support_node_name: set = (), +) -> ops.OperatorSupportBase: """Creates an `OperatorSupportBase` instance used for TRT splitting purpose.""" # Create an `OperatorSupport` that declares a node supported if it # finds a registered TRT converter. @@ -30,6 +33,7 @@ def create_trt_operator_support(use_implicit_batch_dim=True) -> ops.OperatorSupp supported_if_converter_registered = ops.OperatorSupport(support_dict=support_dict) return ops.chain( + ops.OpSupports.decline_if_node_in_names(exclude_support_node_name), # 1. Node is not supported if it has args with int64 dtype: ops.OpSupports.decline_if_input_dtype(torch.int64), # 2. Node is supported if it has TRT converter: @@ -45,6 +49,7 @@ def __init__(self): # During split, we'll split out the operators that # don't support the batch dim. self.use_implicit_batch_dim: bool = True + self.exclude_support_node_name: set = set() class TRTSplitter(splitter_base._SplitterBase): @@ -59,7 +64,7 @@ def __init__( settings = TRTSplitterSetting() if not operator_support: operator_support = create_trt_operator_support( - settings.use_implicit_batch_dim + settings.use_implicit_batch_dim, settings.exclude_support_node_name ) super().__init__( module, diff --git a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_normalizer.py b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_normalizer.py index 408744726e..fd2c26ac2f 100644 --- a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_normalizer.py +++ b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_normalizer.py @@ -1,4 +1,5 @@ import inspect +import logging import re from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Tuple, Union @@ -8,6 +9,7 @@ from . 
import acc_utils +_LOGGER: logging.Logger = logging.getLogger(__name__) # Need to keep up-to-date with https://fburl.com/codesearch/7r2hhh53 ALIAS_MAP = { "input": ("input", "x", "a", "x1"), @@ -417,7 +419,7 @@ def normalize_to_acc_op( node, normalization_info.arg_replacement_tuples ) except Exception: - print( + _LOGGER.error( f"Error during kwarg normalization for: {node.format_node()}; " f"arg_replacement_tuples={normalization_info.arg_replacement_tuples}" ) @@ -441,7 +443,7 @@ def normalize_to_acc_op( node, normalization_info, normalized_args, normalized_kwargs ) except Exception: - print(f"Error during normalization for node: {node.format_node()}") + _LOGGER.error(f"Error during normalization for node: {node.format_node()}") raise # If there are any dead nodes left after normalization, eliminate them now. diff --git a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py index be6b6700a1..d1a5322316 100644 --- a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py +++ b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py @@ -528,6 +528,7 @@ def stack_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node: return cat_node +@register_acc_op_properties(AccOpProperty.pointwise) @register_acc_op_mapping(op_and_target=("call_function", torch.clamp)) @register_acc_op_mapping(op_and_target=("call_method", "clamp")) @register_acc_op @@ -724,6 +725,13 @@ def square_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node: return new_node +@register_acc_op_mapping( + op_and_target=("call_method", "mm"), + arg_replacement_tuples=[ + ("input", "input"), + ("mat2", "other"), + ], +) @register_acc_op_mapping( op_and_target=("call_function", operator.matmul), arg_replacement_tuples=[ diff --git a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_tracer.py b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_tracer.py index c535b062ee..57f7d0e7ea 100644 --- a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_tracer.py +++ b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_tracer.py @@ -439,7 +439,9 @@ def __init__(self, orig): if k == "_modules": for mod_k, mod_v in v.items(): if getattr(mod_v, "_base_class_origin", type(mod_v)) in leaf_module_list: # type: ignore[operator] - print(f"Skip rewriting leaf module {type(mod_v)}") + _LOGGER.info( + f"Skip rewriting leaf module {type(mod_v)}" + ) self._modules[mod_k] = mod_v else: self._modules[mod_k] = rewrite_module(mod_v) diff --git a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_utils.py b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_utils.py index 425fafddac..4c3a79dc4c 100644 --- a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_utils.py +++ b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_utils.py @@ -1,5 +1,6 @@ import inspect import json +import logging import os import re from typing import Any, Callable, Dict, List, Optional, Tuple, Union @@ -12,6 +13,8 @@ from torch.fx.passes import graph_drawer from torch.fx.passes.shape_prop import TensorMetadata +_LOGGER: logging.Logger = logging.getLogger(__name__) + def get_target_from_module(mod: torch.nn.Module, target: str): """ @@ -92,13 +95,13 @@ def draw_graph(traced: torch.fx.GraphModule, fname: str, figname: str = "fx_grap base, ext = os.path.splitext(fname) if not ext: ext = ".svg" - print(f"Writing FX graph to file: {base}{ext}") + _LOGGER.info(f"Writing FX graph to file: {base}{ext}") g = graph_drawer.FxGraphDrawer(traced, figname) x = g.get_main_dot_graph() try: getattr(x, "write_" + ext.lstrip("."))(fname) except OSError as e: - print(f"Failed to write the FX graph due to: {e}") + 
_LOGGER.error(f"Failed to write the FX graph due to: {e}") def get_model_info_str(gm: torch.fx.GraphModule, header: Optional[str] = None):
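A note on the Dynamic_Shape_Support.md table added above: the shapes in its Shape column (e.g. `(-1, 3, -1, -1)`, where `-1` marks a dynamic dimension) are declared through `InputTensorSpec` when the converter tests or a lowering flow build the TensorRT engine. The snippet below is a minimal sketch and not part of this patch; it assumes the `InputTensorSpec` API exported by `torch_tensorrt.fx` (fields `shape`, `dtype`, `shape_ranges`), and the min/opt/max ranges shown are purely illustrative.

```python
import torch
from torch_tensorrt.fx import InputTensorSpec

# One 4D input with dynamic batch and spatial dims, mirroring the
# "(-1, 3, -1, -1)" convolution row in the table. shape_ranges gives the
# (min, opt, max) shapes that the TensorRT optimization profile is built from;
# the concrete values here are only examples.
input_specs = [
    InputTensorSpec(
        shape=(-1, 3, -1, -1),
        dtype=torch.float32,
        shape_ranges=[((1, 3, 32, 32), (4, 3, 64, 64), (8, 3, 128, 128))],
    ),
]

# In the acc_op converter tests, a spec list like this is typically handed to
# the dynamic-shape test helper, e.g.:
#     self.run_test_with_dynamic_shape(MyModule(), input_specs)
# (helper name as used by the fx test harness; treat it as an assumption here).
```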