diff --git a/py/torch_tensorrt/fx/Dynamic_Shape_Support.md b/py/torch_tensorrt/fx/Dynamic_Shape_Support.md new file mode 100644 index 0000000000..eb4454340e --- /dev/null +++ b/py/torch_tensorrt/fx/Dynamic_Shape_Support.md @@ -0,0 +1,137 @@ +# PyTorch Operations Dynamic Shape Support Summary + + + + | Operation | Test Method | Supports Dynamic Shape | Shape | Num of dynamic dims | Reason | +| --- | --- | --- | --- | --- | --- | +| adaptive_avgpool | | partially | (-1, -1, 256, 256) | 2 | AdaptiveAvgPool2d and AdaptiveAvgPool3d currently don't support dynamic shapes for the last two dims. | +| any | | no | | | torch.zeros(tuple(\[*input_t.shape\])). Trying to create tensor with negative dimension -1: \[-1, -1, -1, -1\] | +| as_strided | | no | | | RuntimeError: setStorage: sizes \[2, 3\], strides \[1, 2\], storage offset 0, and itemsize 8 requiring a storage size of 48 are out of bounds for storage of size 16 | +| avg_pool | avg_pool2d | yes | (-1, -1, -1, -1) | 4 | | +| | avg_pool1d | partially | (-1, 3, 3) | 1 | | +| batchnorm | | partially | (-1, 3, -1, -1) | 3 | "Channel dim can't be dynamic for batch norm." | +| binary_ops | | yes | (-1, -1, -1, -1) | 4 | | +| cat | | yes | (-1, -1, -1, -1) | 4 | | +| chunk | | partially | (-1, 1, 3, -1) | any (not chunk dim) | AssertionError: Can't chunk on dynamic shape dimension! | +| clamp | | yes | (-1, -1, -1, -1) | 4 | | +| convolution | conv2d | partially | (-1, 3, -1, -1) | 3 | AssertionError: Channel dim can't be dynamic for convolution. | +| | conv1d | partially | (-1, 3, 3) | 1 | | +| | conv3d | partially | (-1, -1, -1, -1) | 4 | AssertionError: Channel dim can't be dynamic for convolution. | +| dequantize | | yes | (-1, -1, -1, -1) | 4 | | +| einsum | | yes | (-1, -1, -1, -1) | 4 | | +| elu | | yes | (-1, -1, -1, -1) | 4 | | +| embedding | | yes | (-1, -1, -1, -1) | 4 | | +| eq | SimpleConverter | yes | (-1, -1, -1, -1) | 4 | | +| | ConstInputConverter | yes | (-1, -1, -1, -1) | 4 | | +| | EqMethodConverter | no | limitation in converter | | RuntimeError: Trying to create tensor with negative dimension -1: \[-1, -1, -1, -1\] | +| | EqOperatorConverter | no | limitation in converter | | RuntimeError: Trying to create tensor with negative dimension -1: \[-1, -1, -1, -1\] | +| | EqOperatorConstant | partially | (3, -1) | 1 | | +| | EqConverter | no | limitation in converter | | RuntimeError: Trying to create tensor with negative dimension -1: \[-1, -1, -1, -1\] | +| expand | | no | | | Dynamic shape is not suitable for the expand operation. |
+| flatten | | yes | (-1, -1, -1, -1, -1) | 5 | | +| gelu | | yes | (-1, -1, -1, -1) | 4 | | +| getitem | | yes | (-1, -1, -1, -1) | 4 | | +| gt | EqOperatorSimpleConverter | yes | (-1, -1, -1, -1) | 4 | | +| | ConstInputConverter | yes | (-1, -1, -1, -1) | 4 | | +| | GtConverter | no | limitation in converter | | RuntimeError: Trying to create tensor with negative dimension -1: \[-1, -1, -1, -1\] | +| | GtMethodConverter | no | limitation in converter | | RuntimeError: Trying to create tensor with negative dimension -1: \[-1, -1, -1, -1\] | +| | GtOperator | no | limitation in converter | | RuntimeError: Trying to create tensor with negative dimension -1: \[-1, -1, -1, -1\] | +| | EqOperator | no | limitation in converter | | RuntimeError: Trying to create tensor with negative dimension -1: \[-1, -1, -1, -1\] | +| hardsigmoid | | yes | (-1, -1, -1, -1) | 4 | | +| hardtanh | | yes | (-1, -1, -1, -1) | 4 | | +| interpolate | | yes | (-1, -1, -1, -1) | 4 | | +| isinf | | yes | (-1, -1, -1, -1) | 4 | | +| leaky_relu | | yes | (-1, -1, -1, -1) | 4 | | +| linear | | partially | (-1, 3, 5) | 1 | AssertionError: Currently we only support one dynamic dim for linear and it can't be the last dim. | +| logical_and | | yes | (-1, -1, -1, -1) | 4 | | +| logical_or | | yes | (-1, -1, -1, -1) | 4 | | +| logical_xor | | yes | (-1, -1, -1, -1) | 4 | | +| lt | | yes | (-1, -1, -1, -1) | 4 | | +| masked_fill | | no | limitation in converter | | RuntimeError: Trying to create tensor with negative dimension -1: \[-1, -1, -1, -1\] | +| mat_mul | | yes | batch dim | | | +| max | MaxFullReduce | yes | (-1, -1, -1, -1) | 4 | | +| | MaxDimReduce | yes | (-1, -1, -1, -1) | 4 | | +| | MaxMethod | yes | (-1, -1, -1, -1) | 4 | | +| maximum | | yes | (-1, -1, -1, -1) | 4 | | +| maxpool | max_pool1d | partially | (1, 1, -1) | 1 | shape is not set to (-1, -1, -1) because a reshape dimension with more than one -1 wildcard is not allowed when adding the unsqueeze layer | +| | max_pool2d | yes | (-1, -1, -1, -1) | 4 | | +| | max_pool3d | yes | (-1, -1, -1, -1, -1) | 5 | | +| min | MinFullReduce | yes | (-1, -1, -1, -1) | 4 | | +| | MinDimReduce | yes | (-1, -1, -1, -1) | 4 | | +| | MinMethod | yes | (-1, -1, -1, -1) | 4 | | +| minimum | | yes | (-1, -1, -1, -1) | 4 | | +| narrow | | partially | (-1, 3, -1, -1) | 3 | AssertionError: Can't chunk on dynamic shape dimension! | +| ne | NeFunctionConverter | yes | (-1, -1, -1, -1) | 4 | | +| | NeMethodConverter | yes | (-1, -1, -1, -1) | 4 | | +| | NeOperatorConverter | yes | (-1, -1, -1, -1) | 4 | | +| | ConstInputConverter | yes | (-1, -1, -1, -1) | 4 | | +| | NeOperatorConstantConverter | partially | (3, -1) | 1 | | +| new_ones | | yes | (-1, -1, -1, -1) | 4 | | +| numel | | no | limitation in converter | | RuntimeError: numel does not support dynamic shapes. | +| pad | | no | limitation in converter | | test\_pad\_with\_dynamic\_shape\_four\_dimensions\_0\_2d (deeplearning.trt.torch\_tensorrt.py.torch\_tensorrt.fx.test.converters.acc\_op.test\_pad.TestPadConverter) ... \[07/15/2022-09:23:18\] \[TRT\] \[E\] 2: \[intInterval.cpp::max::26\] Error Code 2: Internal Error (Assertion !empty() failed. | +| permute | | yes | (-1, -1, -1, -1) | 4 | | +| prod | | yes | (-1, -1, -1, -1) | 4 | | +| quantize\_per\_tensor | | yes | (-1, -1, -1, -1) | 4 | | +| reduce op | | yes | (-1, -1, -1, -1) | 4 | | +| relu | | yes | (-1, -1, -1, -1) | 4 | | +| repeat interleave | | partially | (-1, 3, 2) | 1 | AssertionError: Currently we don't support unsqueeze with more than one dynamic dims. |
+| reshape | | yes | (-1, -1, -1, -1) | 4 | | +| selu | | yes | (-1, -1, -1, -1) | 4 | | +| sigmoid | | yes | (-1, -1, -1, -1) | 4 | | +| silu | | yes | (-1, -1, -1, -1) | 4 | | +| size | | yes | (-1, -1, -1, -1) | 4 | | +| softmax | | yes | (-1, -1, -1, -1) | 4 | | +| softsign | | yes | (-1, -1, -1, -1) | 4 | | +| split | | partially | (-1, 10, -1) | 2 | AssertionError: Can't chunk on dynamic shape dimension! | +| squeeze | | partially | (1, -1, 2) | 1 | AssertionError: Currently more than one dynamic dim for input to squeeze is not supported. | +| std | | yes | (-1, -1, -1, -1) | 4 | | +| tanh | | yes | (-1, -1, -1, -1) | 4 | | +| tile | | yes | (-1, -1, -1, -1) | 4 | | +| to_dtype | int | yes | (-1, -1, -1, -1) | 4 | | +| | float | yes | (-1, -1, -1, -1) | 4 | | +| topk | | yes | (-1, -1, -1, -1) | 4 | | +| transpose_convolution | conv_transpose2d | partially | (-1, 3, -1, -1) | 3 | | +| | conv_transpose3d | partially | (-1, 3, -1, -1, -1) | 4 | | +| type_as | | yes | (-1, -1, -1, -1) | 4 | RuntimeError: ShapeProp error for: node=%type\_1 : \[#users=1\] = call\_method\[target=type\](args = (%input_1,), kwargs = {dtype: torch.float32}) with meta={} | +| unary ops | | yes | (-1, -1, -1, -1) | 4 | | +| unsqueeze | | partially | (-1, 2, 3) | 1 | AssertionError: Currently we don't support unsqueeze with more than one dynamic dims. | +| where | | no | limitation in converter | | torch.broadcast_shape can not handle -1 dimension in shape \[-1, 2, 2\] | + + + +Binary ops include the following operations: +|Binary Ops | +|----------| +|add | +|sub | +|div | +|mul | +|floor_div | +|fmod | +|floor_divide| +|pow | + + +Unary ops include the following operations: +|Unary Ops | +|----------| +|rsqrt | +|sin | +|cos | +|tan | +|sinh | +|cosh | +|asin | +|acos | +|atan | +|abs | +|neg | +|reciprocal| +|sqrt | +|log | +|exp | +|floor | +|ceil | +|sign | + +Note: For more information about the test method, please refer to the operation test files. Additionally, the test files include information about errors encountered during dynamic shape testing. diff --git a/py/torch_tensorrt/fx/converters/acc_ops_converters.py b/py/torch_tensorrt/fx/converters/acc_ops_converters.py index 95436e762e..68334ebe44 100644 --- a/py/torch_tensorrt/fx/converters/acc_ops_converters.py +++ b/py/torch_tensorrt/fx/converters/acc_ops_converters.py @@ -1,4 +1,5 @@ # flake8: noqa +import logging import math import operator import warnings @@ -22,6 +23,9 @@ from .converter_utils import * # noqa: F403 +_LOGGER: logging.Logger = logging.getLogger(__name__) + + @tensorrt_converter(acc_ops.conv1d) def acc_ops_conv1d( network: TRTNetwork, @@ -641,7 +645,7 @@ def acc_ops_layer_norm(network, target, args, kwargs, name): try: normalized_shape = np.array(kwargs["normalized_shape"], dtype=np.int32) except TypeError: - print("Unable to convert normalized_shape to a field, fall back to []") + _LOGGER.error("Unable to convert normalized_shape to a field, fall back to []") normalized_shape = np.array([], dtype=np.int32) normalized_shape_filed = trt.PluginField( @@ -657,7 +661,9 @@ def acc_ops_layer_norm(network, target, args, kwargs, name): else: plugin = get_trt_plugin("LayerNormDynamic", field_collection, "1", "fx2trt") except AssertionError: - print("Unable to find layer norm plugin, fall back to TensorRT implementation.") + _LOGGER.error( + "Unable to find layer norm plugin, fall back to TensorRT implementation."
+ ) return layer_norm(network, target, args, kwargs, name) layer = network.add_plugin_v2([input_val], plugin) layer.name = name diff --git a/py/torch_tensorrt/fx/lower.py b/py/torch_tensorrt/fx/lower.py index 6d052fc34e..470f78c407 100644 --- a/py/torch_tensorrt/fx/lower.py +++ b/py/torch_tensorrt/fx/lower.py @@ -197,6 +197,7 @@ def create( cls, lower_setting: LowerSetting, interpreter_builder: Callable = create_lower_trt_interpreter, + split_func: Callable = default_split_function, ) -> "Lowerer": """Instantiate a `Lowerer` instance.""" @@ -209,7 +210,7 @@ def create( ast_rewriter_allow_list=lower_setting.ast_rewriter_allow_list, leaf_module_list=lower_setting.leaf_module_list, ), - split_func=default_split_function, + split_func=split_func, lower_func=default_lower_pass(interpreter_builder), ) ) diff --git a/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py b/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py index 98c6314f18..937737b60d 100644 --- a/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py +++ b/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py @@ -1,4 +1,5 @@ import datetime +import logging from functools import partial, wraps from typing import Any, Callable, Optional, Sequence @@ -17,6 +18,10 @@ from .lower_basic_pass import run_const_fold + +_LOGGER: logging.Logger = logging.getLogger(__name__) + + Input = Sequence[Any] @@ -143,7 +148,7 @@ def lower_func(split_result: SplitResult) -> nn.Module: # Only acc submodules will be lowered. if not submod_name.startswith(split_result.non_acc_submodule_prefix): - print("Now lowering submodule", submod_name) + _LOGGER.info("Now lowering submodule", submod_name) lowering_start_time = datetime.datetime.now() self.lower_setting.input_specs = generate_input_specs( @@ -160,7 +165,7 @@ def lower_func(split_result: SplitResult) -> nn.Module: LOWER_SPLIT_POST_OBSERVER.observe( submod_name, lowered_module, submod_inputs ) - print( + _LOGGER.info( f"Lowering submodule {submod_name} elapsed time", datetime.datetime.now() - lowering_start_time, ) @@ -179,7 +184,7 @@ def lower_func(split_result: SplitResult) -> nn.Module: # Only acc submodules will be lowered. 
if not submod_name.startswith(split_result.non_acc_submodule_prefix): - print("Now lowering submodule", submod_name) + _LOGGER.info("Now lowering submodule %s", submod_name) lowering_start_time = datetime.datetime.now() lowered_module = self._lower_func( @@ -189,7 +194,7 @@ def lower_func(split_result: SplitResult) -> nn.Module: LOWER_SPLIT_POST_OBSERVER.observe( submod_name, lowered_module, submod_inputs ) - print( + _LOGGER.info( f"Lowering submodule {submod_name} elapsed time", datetime.datetime.now() - lowering_start_time, ) diff --git a/py/torch_tensorrt/fx/passes/pass_utils.py b/py/torch_tensorrt/fx/passes/pass_utils.py index b075b744bc..d430a67408 100644 --- a/py/torch_tensorrt/fx/passes/pass_utils.py +++ b/py/torch_tensorrt/fx/passes/pass_utils.py @@ -116,7 +116,7 @@ def pass_with_before_after_log( encoding="utf-8", delete=False, ) as f: - print(f"== Log pass {pass_} before/after graph to {f.name}") + _LOGGER.info(f"== Log pass {pass_} before/after graph to {f.name}") print(f"[{pass_}] Before:\n{module.graph}", file=f) module = pass_(module, input) print(f"[{pass_}] After:\n{module.graph}", file=f) diff --git a/py/torch_tensorrt/fx/test/passes/test_graph_opts.py b/py/torch_tensorrt/fx/test/passes/test_graph_opts.py index 1f852b0497..c91c456eb3 100644 --- a/py/torch_tensorrt/fx/test/passes/test_graph_opts.py +++ b/py/torch_tensorrt/fx/test/passes/test_graph_opts.py @@ -1,3 +1,4 @@ +import logging import unittest from collections import Counter from typing import Callable, Dict, List @@ -8,13 +9,16 @@ from torch_tensorrt.fx.passes.graph_opts import common_subexpression_elimination +_LOGGER: logging.Logger = logging.getLogger(__name__) + + def debug_print_graph_module(mod_graph: torch.fx.GraphModule) -> None: """ Helper func to print model's graph in plain and tabular format, also print code.
""" - print(mod_graph.graph) + _LOGGER.info(mod_graph.graph) mod_graph.graph.print_tabular() - print(mod_graph.code) + _LOGGER.info(mod_graph.code) @torch.fx.wrap @@ -46,7 +50,7 @@ def _test_opt_with_module( before_results = module(*inputs) mod_traced = acc_tracer.trace(module, inputs) before_node_list = list(mod_traced.graph.nodes) - print("Model before opt.") + _LOGGER.info("Model before opt.") debug_print_graph_module(mod_traced) # Apply Opt @@ -55,7 +59,7 @@ def _test_opt_with_module( # After Opt after_results = mod_traced(*inputs) after_node_list = list(mod_traced.graph.nodes) - print("Model after opt.") + _LOGGER.info("Model after opt.") mod_traced.recompile() debug_print_graph_module(mod_traced) diff --git a/py/torch_tensorrt/fx/test/tracer/test_acc_tracer.py b/py/torch_tensorrt/fx/test/tracer/test_acc_tracer.py index ab7f932acf..3abba43ccb 100644 --- a/py/torch_tensorrt/fx/test/tracer/test_acc_tracer.py +++ b/py/torch_tensorrt/fx/test/tracer/test_acc_tracer.py @@ -1,5 +1,5 @@ # Owner(s): ["oncall: fx"] - +import logging import unittest from typing import Callable, List @@ -16,6 +16,8 @@ torch.manual_seed(0) +_LOGGER: logging.Logger = logging.getLogger(__name__) + class AccTracerTest(unittest.TestCase): def _make_model_unit_test( @@ -258,7 +260,7 @@ def forward(self, a: torch.Tensor) -> torch.Tensor: torch.randn(1, 3, 1, 1), scale=0.01, zero_point=3, dtype=torch.quint8 ) traced = acc_tracer.trace(m, [input]) - print(traced.graph) + _LOGGER.info(traced.graph) ph = weight_attr = bias_attr = conv = None for node in traced.graph.nodes: if node.op == "placeholder": @@ -626,7 +628,7 @@ def run_embedding_bag_test(is_4bit, use_weights): ) traced = acc_tracer.trace(m, inputs) - print(traced.graph) + _LOGGER.info(traced.graph) expected_target = ( acc_ops.embedding_bag_4bit_rowwise_offsets diff --git a/py/torch_tensorrt/fx/test/trt_lower/test_diagnostics.py b/py/torch_tensorrt/fx/test/trt_lower/test_diagnostics.py index 467f3ca9af..3ce3b7ade8 100644 --- a/py/torch_tensorrt/fx/test/trt_lower/test_diagnostics.py +++ b/py/torch_tensorrt/fx/test/trt_lower/test_diagnostics.py @@ -1,6 +1,7 @@ # Owner(s): ["oncall: gpu_enablement"] import functools import glob +import logging import os import shutil import tempfile @@ -10,6 +11,9 @@ import torch_tensorrt.fx.diagnostics as diag +_LOGGER: logging.Logger = logging.getLogger(__name__) + + def reset_diag(fn): @functools.wraps(fn) def reset(*a, **kw): @@ -53,7 +57,7 @@ def boom() -> str: zip_fn = collector._last_zip_path_for_test assert os.path.exists(zip_fn) with tempfile.TemporaryDirectory() as tempdir: - print(f"Unpacking into {tempdir}") + _LOGGER.info(f"Unpacking into {tempdir}") shutil.unpack_archive(zip_fn, tempdir) _check_file(tempdir, "aaa", "hello") _check_file(tempdir, "bbb", "world") @@ -78,7 +82,7 @@ def test_condition_func_name(self): zip_fn = collector._last_zip_path_for_test assert os.path.exists(zip_fn) with tempfile.TemporaryDirectory() as tempdir: - print(f"Unpacking into {tempdir}") + _LOGGER.info(f"Unpacking into {tempdir}") shutil.unpack_archive(zip_fn, tempdir) _check_file(tempdir, "aaa", "hello") @@ -160,7 +164,7 @@ def _test_cond( if should_collect: assert os.path.exists(zip_fn) with tempfile.TemporaryDirectory() as tempdir: - print(f"Unpacking into {tempdir}") + _LOGGER.info(f"Unpacking into {tempdir}") shutil.unpack_archive(zip_fn, tempdir) _check_file(tempdir, "aaa", "hello") else: diff --git a/py/torch_tensorrt/fx/test/trt_lower/test_observer.py b/py/torch_tensorrt/fx/test/trt_lower/test_observer.py index 8a621c476a..58f23c0a13 
100644 --- a/py/torch_tensorrt/fx/test/trt_lower/test_observer.py +++ b/py/torch_tensorrt/fx/test/trt_lower/test_observer.py @@ -1,5 +1,6 @@ # Owner(s): ["oncall: gpu_enablement"] import functools +import logging import typing as t from contextlib import contextmanager from unittest import TestCase @@ -7,6 +8,8 @@ import torch_tensorrt.fx.observer as ob from torch_tensorrt.fx.observer import observable +_LOGGER: logging.Logger = logging.getLogger(__name__) + def set_observer_callback_rethrow(fn): """ @@ -36,7 +39,7 @@ def foo(x, y, z): @verify_execution def log_pre(ctx: ob.ObserveContext) -> None: - print(f"calling log: {ctx}") + _LOGGER.info(f"calling log: {ctx}") assert ctx.callable is foo.orig_func assert ctx.args == (1, 2) assert ctx.kwargs == {"z": 3} @@ -44,7 +47,7 @@ def log_pre(ctx: ob.ObserveContext) -> None: @verify_execution def log_post(ctx: ob.ObserveContext) -> None: - print(f"calling log: {ctx}") + _LOGGER.info(f"calling log: {ctx}") assert ctx.callable is foo.orig_func assert ctx.args == (1, 2) assert ctx.kwargs == {"z": 3} @@ -57,11 +60,11 @@ def log_post(ctx: ob.ObserveContext) -> None: @verify_execution def log_pre(ctx: ob.ObserveContext) -> None: - print(f"calling log: {ctx}") + _LOGGER.info(f"calling log: {ctx}") @verify_execution def log_post(ctx: ob.ObserveContext) -> None: - print(f"calling log: {ctx}") + _LOGGER.info(f"calling log: {ctx}") foo.observers.pre.add(log_pre) foo.observers.post.add(log_post) @@ -71,11 +74,11 @@ def log_post(ctx: ob.ObserveContext) -> None: @verify_execution def f1(ctx: ob.ObserveContext) -> None: - print(f"calling f1: {ctx}") + _LOGGER.info(f"calling f1: {ctx}") @verify_execution def f2(ctx: ob.ObserveContext) -> None: - print(f"calling f2: {ctx}") + _LOGGER.info(f"calling f2: {ctx}") # Test that we can register the same observation point twice with foo.observers.pre.add(f1): @@ -91,7 +94,7 @@ def foo(x, y, z): @verify_execution def log_pre(ctx: ob.ObserveContext) -> None: - print(f"calling log: {ctx}") + _LOGGER.info(f"calling log: {ctx}") raise CallbackError("TEST CALLBACK EXCEPTION") with foo.observers.pre.add(log_pre): diff --git a/py/torch_tensorrt/fx/test/trt_lower/trt_splitter_test.py b/py/torch_tensorrt/fx/test/trt_lower/trt_splitter_test.py index 76da8c2430..483a45d639 100644 --- a/py/torch_tensorrt/fx/test/trt_lower/trt_splitter_test.py +++ b/py/torch_tensorrt/fx/test/trt_lower/trt_splitter_test.py @@ -10,7 +10,7 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from torch.fx.passes import splitter_base from torch.testing._internal.common_utils import run_tests, TestCase -from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter +from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter, TRTSplitterSetting from torch_tensorrt.fx.tracer.acc_tracer import acc_tracer @@ -1126,6 +1126,47 @@ def test_splitter(splitter): test_splitter(splitter) + def test_exclude_support_node_by_name(self): + class TestModule(torch.nn.Module): + def forward(self, a): + b = torch.sin(a) + c = torch.relu(b) + d = torch.cos(c) + e = torch.sigmoid(d) + f = torch.tanh(e) + return f + + mod = acc_tracer.trace(TestModule(), [torch.randn(2, 3)]) + + # Set sin, cos and tanh as acc node and split with settings + class CustomOpSupport(op_support.OperatorSupport): + _support_dict = { + "acc_ops.sin": None, + "acc_ops.cos": None, + "acc_ops.relu": None, + "acc_ops.sigmoid": None, + "acc_ops.tanh": None, + } + + # For unsupport relu node, this would cut graph into acc_0, gpu_1 and acc_2 + # as three sub graphs. 
+ settings = TRTSplitterSetting() + settings.exclude_support_node_name = {"relu"} + splitter = TRTSplitter( + mod, + (torch.randn(2, 3),), + op_support_with_support_dict( + { + "acc_ops.sin": None, + "acc_ops.cos": None, + "acc_ops.relu": None, + } + ), + settings, + ) + res = splitter.generate_split_results() + self.assertEqual(len(res), 3) + def op_support_with_support_dict(support_dict: dict) -> op_support.OperatorSupportBase: return op_support.OperatorSupport(support_dict) diff --git a/py/torch_tensorrt/fx/tools/common_fx2trt.py b/py/torch_tensorrt/fx/tools/common_fx2trt.py index d6c635b402..a2ef83b57c 100644 --- a/py/torch_tensorrt/fx/tools/common_fx2trt.py +++ b/py/torch_tensorrt/fx/tools/common_fx2trt.py @@ -1,3 +1,4 @@ +import logging import time import unittest from typing import Callable, List, Tuple @@ -13,6 +14,8 @@ from torch_tensorrt.fx.passes.pass_utils import chain_passes from torch_tensorrt.fx.utils import LowerPrecision +_LOGGER: logging.Logger = logging.getLogger(__name__) + def fetch_attr(mod, target): """ @@ -65,7 +68,7 @@ def run_test( start = time.perf_counter() interpreter_result = interpreter.run(lower_precision=precision) sec = time.perf_counter() - start - print("Interpreter run time(s):", sec) + _LOGGER.info(f"Interpreter run time(s): {sec}") trt_mod = TRTModule( interpreter_result.engine, interpreter_result.input_names, @@ -81,7 +84,9 @@ def run_test( outputs = trt_mod(*cuda_inputs) end_event.record() torch.cuda.synchronize() - print("TRT run time(s)=", (start_event.elapsed_time(end_event) * 1.0e-3)) + _LOGGER.info( + f"TRT run time(s)= {(start_event.elapsed_time(end_event) * 1.0e-3)}" + ) if isinstance(outputs, torch.Tensor): ref_outputs = [ref_outputs] diff --git a/py/torch_tensorrt/fx/tools/engine_layer_visualize.py b/py/torch_tensorrt/fx/tools/engine_layer_visualize.py index 3008c8e087..7ca6702fb2 100644 --- a/py/torch_tensorrt/fx/tools/engine_layer_visualize.py +++ b/py/torch_tensorrt/fx/tools/engine_layer_visualize.py @@ -1,10 +1,11 @@ import argparse +import logging import re from typing import Any, Dict, List, NamedTuple, Optional, Tuple import pydot - +_LOGGER: logging.Logger = logging.getLogger(__name__) """ log_file is generated by tensorrt verbose logger during building engine. profile_file is generated by tensorrt profiler. @@ -106,7 +107,7 @@ def build_edge(layer, graph, reformat_layers, output_name2node, layer_name2node) ) ) else: - print(f"Missing node {input_name}") + _LOGGER.info(f"Missing node {input_name}") from_node = input_name else: from_node = output_name2node[input_name] @@ -213,5 +214,5 @@ def build_edge(layer, graph, reformat_layers, output_name2node, layer_name2node) est_reformat_time += float(layer.time[:-2]) est_total_time += float(layer.time[:-2]) - print(f"Time Cost on Reformatting: {est_reformat_time} ms") - print(f"Total Time Cost: {est_total_time} ms") + _LOGGER.info(f"Time Cost on Reformatting: {est_reformat_time} ms") + _LOGGER.info(f"Total Time Cost: {est_total_time} ms") diff --git a/py/torch_tensorrt/fx/tools/trt_minimizer.py b/py/torch_tensorrt/fx/tools/trt_minimizer.py index 23039bea34..308687e0c9 100644 --- a/py/torch_tensorrt/fx/tools/trt_minimizer.py +++ b/py/torch_tensorrt/fx/tools/trt_minimizer.py @@ -1,3 +1,4 @@ +import logging from typing import Any, Callable, Tuple import torch @@ -6,6 +7,8 @@ from ..
import InputTensorSpec, TRTInterpreter, TRTModule +_LOGGER: logging.Logger = logging.getLogger(__name__) + def lower_mod_default( mod: torch.fx.GraphModule, inputs: Tensors, batch_size: Any = 2048 @@ -64,5 +67,5 @@ def run_b(self, mod, inputs): def get_nodes(self, start=None, end=None, enable_print=False): nodes = self._collect_nodes(start, end) if enable_print: - print(f"Nodes fetched from start {start} to end {end} as: {nodes}") + _LOGGER.info(f"Nodes fetched from start {start} to end {end} as: {nodes}") return nodes diff --git a/py/torch_tensorrt/fx/tools/trt_profiler_sorted.py b/py/torch_tensorrt/fx/tools/trt_profiler_sorted.py index 2d5ce0b419..59d2f49042 100644 --- a/py/torch_tensorrt/fx/tools/trt_profiler_sorted.py +++ b/py/torch_tensorrt/fx/tools/trt_profiler_sorted.py @@ -1,4 +1,5 @@ import json +import logging import operator from typing import List, Mapping, Optional @@ -7,6 +8,8 @@ from .. import TRTModule +_LOGGER: logging.Logger = logging.getLogger(__name__) + class SortedTRTProfiler(trt.IProfiler): def __init__(self): @@ -22,7 +25,7 @@ def print_sorted_profile( additional_info = {} if additional_info is None else additional_info for k, v in sorted(self.layers.items(), key=operator.itemgetter(1)): additional_str = additional_info.get(k, "") - print(f"{k} {additional_str}: {v}ms") + _LOGGER.info(f"{k} {additional_str}: {v}ms") def profile_trt_module( diff --git a/py/torch_tensorrt/fx/tools/trt_splitter.py b/py/torch_tensorrt/fx/tools/trt_splitter.py index 32dc4a1853..7fbca8d99a 100644 --- a/py/torch_tensorrt/fx/tools/trt_splitter.py +++ b/py/torch_tensorrt/fx/tools/trt_splitter.py @@ -16,7 +16,10 @@ from ..tools.trt_minimizer import TensorRTMinimizer -def create_trt_operator_support(use_implicit_batch_dim=True) -> ops.OperatorSupportBase: +def create_trt_operator_support( + use_implicit_batch_dim=True, + exclude_support_node_name: set = (), +) -> ops.OperatorSupportBase: """Creates an `OperatorSupportBase` instance used for TRT splitting purpose.""" # Create an `OperatorSupport` that declares a node supported if it # finds a registered TRT converter. @@ -30,6 +33,7 @@ def create_trt_operator_support(use_implicit_batch_dim=True) -> ops.OperatorSupp supported_if_converter_registered = ops.OperatorSupport(support_dict=support_dict) return ops.chain( + ops.OpSupports.decline_if_node_in_names(exclude_support_node_name), # 1. Node is not supported if it has args with int64 dtype: ops.OpSupports.decline_if_input_dtype(torch.int64), # 2. Node is supported if it has TRT converter: @@ -45,6 +49,7 @@ def __init__(self): # During split, we'll split out the operators that # don't support the batch dim. self.use_implicit_batch_dim: bool = True + self.exclude_support_node_name: set = set() class TRTSplitter(splitter_base._SplitterBase): @@ -59,7 +64,7 @@ def __init__( settings = TRTSplitterSetting() if not operator_support: operator_support = create_trt_operator_support( - settings.use_implicit_batch_dim + settings.use_implicit_batch_dim, settings.exclude_support_node_name ) super().__init__( module, diff --git a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_normalizer.py b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_normalizer.py index 408744726e..fd2c26ac2f 100644 --- a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_normalizer.py +++ b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_normalizer.py @@ -1,4 +1,5 @@ import inspect +import logging import re from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Tuple, Union @@ -8,6 +9,7 @@ from . 
import acc_utils +_LOGGER: logging.Logger = logging.getLogger(__name__) # Need to keep up-to-date with https://fburl.com/codesearch/7r2hhh53 ALIAS_MAP = { "input": ("input", "x", "a", "x1"), @@ -417,7 +419,7 @@ def normalize_to_acc_op( node, normalization_info.arg_replacement_tuples ) except Exception: - print( + _LOGGER.error( f"Error during kwarg normalization for: {node.format_node()}; " f"arg_replacement_tuples={normalization_info.arg_replacement_tuples}" ) @@ -441,7 +443,7 @@ def normalize_to_acc_op( node, normalization_info, normalized_args, normalized_kwargs ) except Exception: - print(f"Error during normalization for node: {node.format_node()}") + _LOGGER.error(f"Error during normalization for node: {node.format_node()}") raise # If there are any dead nodes left after normalization, eliminate them now. diff --git a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py index be6b6700a1..d1a5322316 100644 --- a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py +++ b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py @@ -528,6 +528,7 @@ def stack_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node: return cat_node +@register_acc_op_properties(AccOpProperty.pointwise) @register_acc_op_mapping(op_and_target=("call_function", torch.clamp)) @register_acc_op_mapping(op_and_target=("call_method", "clamp")) @register_acc_op @@ -724,6 +725,13 @@ def square_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node: return new_node +@register_acc_op_mapping( + op_and_target=("call_method", "mm"), + arg_replacement_tuples=[ + ("input", "input"), + ("mat2", "other"), + ], +) @register_acc_op_mapping( op_and_target=("call_function", operator.matmul), arg_replacement_tuples=[ diff --git a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_tracer.py b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_tracer.py index c535b062ee..57f7d0e7ea 100644 --- a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_tracer.py +++ b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_tracer.py @@ -439,7 +439,9 @@ def __init__(self, orig): if k == "_modules": for mod_k, mod_v in v.items(): if getattr(mod_v, "_base_class_origin", type(mod_v)) in leaf_module_list: # type: ignore[operator] - print(f"Skip rewriting leaf module {type(mod_v)}") + _LOGGER.info( + f"Skip rewriting leaf module {type(mod_v)}" + ) self._modules[mod_k] = mod_v else: self._modules[mod_k] = rewrite_module(mod_v) diff --git a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_utils.py b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_utils.py index 425fafddac..4c3a79dc4c 100644 --- a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_utils.py +++ b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_utils.py @@ -1,5 +1,6 @@ import inspect import json +import logging import os import re from typing import Any, Callable, Dict, List, Optional, Tuple, Union @@ -12,6 +13,8 @@ from torch.fx.passes import graph_drawer from torch.fx.passes.shape_prop import TensorMetadata +_LOGGER: logging.Logger = logging.getLogger(__name__) + def get_target_from_module(mod: torch.nn.Module, target: str): """ @@ -92,13 +95,13 @@ def draw_graph(traced: torch.fx.GraphModule, fname: str, figname: str = "fx_grap base, ext = os.path.splitext(fname) if not ext: ext = ".svg" - print(f"Writing FX graph to file: {base}{ext}") + _LOGGER.info(f"Writing FX graph to file: {base}{ext}") g = graph_drawer.FxGraphDrawer(traced, figname) x = g.get_main_dot_graph() try: getattr(x, "write_" + ext.lstrip("."))(fname) except OSError as e: - print(f"Failed to write the FX graph due to: {e}") + 
_LOGGER.error(f"Failed to write the FX graph due to: {e}") def get_model_info_str(gm: torch.fx.GraphModule, header: Optional[str] = None):