diff --git a/docs/_sources/tutorials/ptq.rst.txt b/docs/_sources/tutorials/ptq.rst.txt
index b62457109f..75b27c8409 100644
--- a/docs/_sources/tutorials/ptq.rst.txt
+++ b/docs/_sources/tutorials/ptq.rst.txt
@@ -136,7 +136,7 @@ Then all thats required to setup the module for INT8 calibration is to set the f
 If you have an existing Calibrator implementation for TensorRT you may directly set the ``ptq_calibrator`` field with a pointer to your calibrator and it will work as well.
 From here not much changes in terms of how to execution works. You are still able to fully use LibTorch as the sole interface for inference. Data should remain
 in FP32 precision when it's passed into `trt_mod.forward`. There exists an example application in the Torch-TensorRT demo that takes you from training a VGG16 network on
-CIFAR10 to deploying in INT8 with Torch-TensorRT here: https://github.com/pytorch/TensorRT/tree/master/cpp/ptq
+CIFAR10 to deploying in INT8 with Torch-TensorRT here: https://github.com/pytorch/TensorRT/tree/master/examples/int8/ptq
 
 .. _writing_ptq_python:
 
@@ -194,8 +194,8 @@ to use ``CacheCalibrator`` to use in INT8 mode.
         calibrator=calibrator)
 
 If you already have an existing calibrator class (implemented directly using TensorRT API), you can directly set the calibrator field to your class which can be very convenient.
-For a demo on how PTQ can be performed on a VGG network using Torch-TensorRT API, you can refer to https://github.com/pytorch/TensorRT/blob/master/tests/py/test_ptq_dataloader_calibrator.py
-and https://github.com/pytorch/TensorRT/blob/master/tests/py/test_ptq_trt_calibrator.py
+For a demo on how PTQ can be performed on a VGG network using Torch-TensorRT API, you can refer to https://github.com/pytorch/TensorRT/blob/master/tests/py/ptq/test_ptq_dataloader_calibrator.py
+and https://github.com/pytorch/TensorRT/blob/master/tests/py/ptq/test_ptq_trt_calibrator.py
 
 Citations
 ^^^^^^^^^^^
diff --git a/py/torch_tensorrt/fx/converters/acc_ops_converters.py b/py/torch_tensorrt/fx/converters/acc_ops_converters.py
index 9135ebc98a..f5278b1d07 100644
--- a/py/torch_tensorrt/fx/converters/acc_ops_converters.py
+++ b/py/torch_tensorrt/fx/converters/acc_ops_converters.py
@@ -2854,8 +2854,12 @@ def add_clamp(network, input, val, op, name):
     else:
         acc_ops_clamp_shape = (1,) * len(input.shape)  # broadcast all dimensions
         acc_ops_clamp_tensor = (
-            val
-            * torch.ones(acc_ops_clamp_shape, dtype=torch_dtype_from_trt(input.dtype))
+            (
+                val
+                * torch.ones(
+                    acc_ops_clamp_shape, dtype=torch_dtype_from_trt(input.dtype)
+                )
+            )
             .cpu()
             .numpy()
         )
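Editor's note (illustrative only, not part of the patch): the add_clamp change above regroups the expression so that val is multiplied with the torch tensor first and only the final product is converted to NumPy; without the added parentheses, the method calls bind to torch.ones(...) alone and the multiplication happens on an already-converted NumPy array. A minimal standalone sketch of the new grouping, using the float32 boundary value exercised by the new test (the shape and dtype here are assumed for illustration):

    import torch

    val = -3.4028234663852886e38      # most negative finite float32, as in the new test param
    acc_ops_clamp_shape = (1, 1, 1)   # stands in for (1,) * len(input.shape)
    acc_ops_clamp_tensor = (
        (val * torch.ones(acc_ops_clamp_shape, dtype=torch.float32))
        .cpu()
        .numpy()
    )
    print(acc_ops_clamp_tensor.dtype, acc_ops_clamp_tensor.ravel()[0])  # float32 -3.4028235e+38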
diff --git a/py/torch_tensorrt/fx/passes/lower_basic_pass.py b/py/torch_tensorrt/fx/passes/lower_basic_pass.py
index f7f554e1c6..844fa24238 100644
--- a/py/torch_tensorrt/fx/passes/lower_basic_pass.py
+++ b/py/torch_tensorrt/fx/passes/lower_basic_pass.py
@@ -1,10 +1,13 @@
 import copy
+import logging
 import operator
 import warnings
-from typing import Any
+from typing import Any, Optional
 
 import torch
 import torch.fx
+import torch.fx as fx
+import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils
 from torch.fx.experimental.const_fold import split_const_subgraphs
 
 from ..observer import observable
@@ -13,6 +16,8 @@
 from ..tracer.acc_tracer.acc_utils import get_attr
 from .pass_utils import log_before_after, validate_inference
 
+_LOGGER = logging.getLogger(__name__)
+
 # Create an alias for module input type to avoid littering pyre-ignore for Any
 # throughout the file.
 Input = Any
@@ -460,3 +465,146 @@ def transform_setitem(gm: torch.fx.GraphModule, input: Input):
     gm.graph.lint()
     gm.recompile()
     return gm
+
+
+def fix_reshape_batch_dim(mod: fx.GraphModule) -> fx.GraphModule:
+    """\
+    TRT cannot reason about shape patterns like x.reshape(y.size(0), -1, 256),
+    since the dynamic shape of the reshape comes from the dynamic shape of
+    another node (y). The compilation will fail with various memory related
+    errors, depending on the size of the input tensor.
+
+    This pass fixes the issue by finding this reshape pattern, checking that:
+
+        x.size(0) == y.size(0)
+
+    And then replaces reshape's batch size from y.size(0) to x.size(0).
+    """
+
+    def get_reshape_batch_size_as_node(maybe_reshape: fx.Node) -> Optional[fx.Node]:
+        """\
+        Try to find the reshape op's batch size as an input node.
+
+        Match below graph structure and return `node_y`:
+            node_x.reshape({"acc_out_ty": {"shape": (node_y, ...)}})
+        """
+        if (
+            maybe_reshape.op != "call_function"
+            or maybe_reshape.target != acc_ops.reshape
+        ):
+            return None
+        shape = getattr(maybe_reshape.kwargs["acc_out_ty"], "shape", None)
+        if not shape:
+            return None
+        batch_size = shape[0]
+        if isinstance(batch_size, fx.Node):
+            return batch_size
+        return None
+
+    def get_reshape_batch_size_inferred_source(
+        batch_size_node: fx.Node,
+    ) -> Optional[fx.Node]:
+        """\
+        Given a node representing the batch size used for reshape op, we want
+        to know if it is coming from below pattern:
+
+            batch_size_node = src.size()[0]
+
+        or in IR graph:
+
+            src -> size(input=_) -> getitem(input=_, idx=0)
+                                    ^ ~~~ batch_size_node
+
+        If so, return `src`. Otherwise, return `None`.
+        """
+        if (
+            batch_size_node.op != "call_function"
+            or batch_size_node.target != acc_ops.getitem
+            or batch_size_node.kwargs["idx"] != 0
+        ):
+            return None
+        maybe_size: fx.Node = batch_size_node.all_input_nodes[0]
+        if maybe_size.op != "call_function" or maybe_size.target != acc_ops.size:
+            return None
+        return maybe_size.all_input_nodes[0]
+
+    maybe_reshape: fx.Node
+    for maybe_reshape in mod.graph.nodes:
+        reshape_batch_size: Optional[fx.Node] = get_reshape_batch_size_as_node(
+            maybe_reshape
+        )
+        if not reshape_batch_size:
+            continue
+        reshape_batch_size_inferred_source: Optional[
+            fx.Node
+        ] = get_reshape_batch_size_inferred_source(reshape_batch_size)
+        if not reshape_batch_size_inferred_source:
+            continue
+
+        reshape_input: fx.Node = maybe_reshape.kwargs["input"]
+        if reshape_input == reshape_batch_size_inferred_source:
+            continue
+
+        if not _is_batch_size_equal(reshape_input, reshape_batch_size_inferred_source):
+            continue
+
+        _LOGGER.info(
+            f"{fix_reshape_batch_dim}: Found bad pattern: y.reshape((x, ...)). Reshape node: {maybe_reshape}"
+        )
+
+        # Step 1: create a node to compute batch size, using the tensor which
+        # is being reshaped: reshape_input.size()[0]. This batch size is now
+        # derived from reshape_input, the same node as the reshape op's input.
+        with mod.graph.inserting_before(maybe_reshape):
+            reshape_batch_size_2: fx.Node = maybe_reshape.graph.call_function(
+                acc_ops.getitem,
+                kwargs={
+                    "idx": 0,
+                    "input": maybe_reshape.graph.call_function(
+                        acc_ops.size,
+                        kwargs={
+                            "input": reshape_input,
+                        },
+                    ),
+                },
+            )
+
+        # Step 2: update `maybe_reshape`'s shape argument to be
+        # (reshape_batch_size_2, *DONT_CARE_JUST_COPY_OVER)
+        maybe_reshape.kwargs = {
+            **maybe_reshape.kwargs,
+            "acc_out_ty": acc_utils.build_raw_tensor_meta(
+                shape=(
+                    reshape_batch_size_2,
+                    *(maybe_reshape.kwargs["acc_out_ty"].shape[1:]),
+                )
+            ),
+        }
+
+    mod.graph.eliminate_dead_code()
+    mod.recompile()
+    return mod
+
+
+def _is_batch_size_equal(x: fx.Node, y: fx.Node) -> bool:
+    """\
+    Check that x.size(0) == y.size(0)
+    """
+    x_size, y_size = _get_shape(x), _get_shape(y)
+    return (
+        x_size
+        and y_size
+        # now both are non-empty
+        and x_size[0] == y_size[0]
+    )
+
+
+def _get_shape(node: fx.Node) -> Optional[torch.Size]:
+    if (
+        not getattr(node, "meta", None)
+        or not node.meta.get("tensor_meta", None)
+        or not getattr(node.meta["tensor_meta"], "shape", None)
+    ):
+        # shape info not available
+        return None
+    return node.meta["tensor_meta"].shape
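Editor's note (illustrative only, not part of the patch): at the eager level, the rewrite performed by fix_reshape_batch_dim corresponds to the following sketch, based on the pass docstring; the pass itself operates on acc-traced fx graphs, not on source code, and only fires when tensor metadata shows x.size(0) == y.size(0):

    import torch

    def forward(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # TRT-unfriendly: the reshape's batch dim is taken from another node (y)
        return x.reshape(y.size(0), -1, 256)

    def forward_after_pass(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # After the pass, the batch dim is derived from the reshaped tensor itself
        return x.reshape(x.size(0), -1, 256)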
diff --git a/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py b/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py
index 877029cd44..c4bb927b85 100644
--- a/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py
+++ b/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py
@@ -17,6 +17,7 @@
 from .graph_opts import common_subexpression_elimination
 from .lower_basic_pass import (
+    fix_reshape_batch_dim,
     replace_mutable_op,
     replace_op_with_indices,
     run_const_fold,
@@ -112,6 +113,7 @@ def graph_optimization_pass(self) -> PassManager:
         passes.append(
             inplace_wrapper(lambda m: FUSE_PASSES_POST_OBSERVER.observe(m, self._input))
         )
+        passes.append(fix_reshape_batch_dim)
 
         return PassManager.build_from_passlist(passes)
diff --git a/py/torch_tensorrt/fx/passes/pass_utils.py b/py/torch_tensorrt/fx/passes/pass_utils.py
index 9db173f1e1..78e9ec1b22 100644
--- a/py/torch_tensorrt/fx/passes/pass_utils.py
+++ b/py/torch_tensorrt/fx/passes/pass_utils.py
@@ -1,3 +1,4 @@
+import io
 import logging
 import tempfile
 from functools import wraps
@@ -233,15 +234,23 @@ def log_before_after(pass_: PassFunc) -> PassFunc:
     def pass_with_before_after_log(
         module: fx.GraphModule, input: Input
     ) -> fx.GraphModule:
+        before_io = io.StringIO()
+        after_io = io.StringIO()
         with tempfile.NamedTemporaryFile(
             mode="w",
             encoding="utf-8",
             delete=False,
         ) as f:
-            _LOGGER.info(f"== Log pass {pass_} before/after graph to {f.name}")
             print(f"[{pass_}] Before:\n{module.graph}", file=f)
+            print(module.graph, file=before_io)
+
             module = pass_(module, input)
             print(f"[{pass_}] After:\n{module.graph}", file=f)
+            print(module.graph, file=after_io)
+            t = before_io.getvalue() == after_io.getvalue()
+            _LOGGER.info(
+                f"== Log pass {pass_} before/after graph to {f.name}, before/after are the same = {t}"
+            )
         return module
 
     return pass_with_before_after_log
diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_clamp.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_clamp.py
index 7c166c1fe0..e59153d5c9 100644
--- a/py/torch_tensorrt/fx/test/converters/acc_op/test_clamp.py
+++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_clamp.py
@@ -12,6 +12,7 @@ class TestClampConverter(AccTestCase):
             param("min", min=0.5),
             param("max", max=0.5),
             param("minBiggerThanMax", min=1, max=0),
+            param("float32Boundary", min=-3.4028234663852886e38),
         ]
     )
     def test_clamp(
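Editor's note (illustrative only, not part of the patch): the pass_utils change above captures the graph printout before and after each pass into in-memory buffers and logs whether they are identical, which makes no-op passes easy to spot in the log. A standalone sketch of the comparison mechanism, detached from fx (the printed strings are placeholders):

    import io

    before_io, after_io = io.StringIO(), io.StringIO()
    print("graph(): return x", file=before_io)  # stands in for the pre-pass graph printout
    print("graph(): return x", file=after_io)   # stands in for the post-pass graph printout
    t = before_io.getvalue() == after_io.getvalue()
    print(f"before/after are the same = {t}")   # True here, i.e. the pass changed nothing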
diff --git a/py/torch_tensorrt/fx/test/passes/test_fix_reshape_batch_dim.py b/py/torch_tensorrt/fx/test/passes/test_fix_reshape_batch_dim.py
new file mode 100644
index 0000000000..bd04692ad5
--- /dev/null
+++ b/py/torch_tensorrt/fx/test/passes/test_fix_reshape_batch_dim.py
@@ -0,0 +1,51 @@
+# Owner(s): ["oncall: gpu_enablement"]
+
+import logging
+from copy import deepcopy
+
+import torch
+import torch.fx as fx
+import torch.nn as nn
+
+from torch.testing._internal.common_utils import run_tests, TestCase
+from torch_tensorrt.fx.passes.lower_basic_pass import fix_reshape_batch_dim
+from torch_tensorrt.fx.tracer.acc_tracer import acc_tracer
+
+_LOGGER = logging.getLogger(__name__)
+
+
+class TestFixReshapeBatchDim(TestCase):
+    def test_fix_reshape_batch_dim(self):
+        class Repro(nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x, y):
+                return y.view(x.size(0), -1, 3)
+
+        mod = Repro()
+        modt = fx.symbolic_trace(mod)
+        inp = [
+            torch.rand([10, 60]),
+            torch.rand([10, 60]),
+        ]
+        mod(*inp)
+        mod_acc_traced = acc_tracer.trace(modt, inp)
+        mod_fixed = fix_reshape_batch_dim(deepcopy(mod_acc_traced))
+
+        expected_graph = r"""
+graph():
+    %x : [#users=0] = placeholder[target=x]
+    %y : [#users=2] = placeholder[target=y]
+    %size : [#users=1] = call_function[target=torch_tensorrt.fx.tracer.acc_tracer.acc_ops.size](args = (), kwargs = {input: %y})
+    %getitem_1 : [#users=1] = call_function[target=torch_tensorrt.fx.tracer.acc_tracer.acc_ops.getitem](args = (), kwargs = {idx: 0, input: %size})
+    %reshape : [#users=1] = call_function[target=torch_tensorrt.fx.tracer.acc_tracer.acc_ops.reshape](args = (), kwargs = {input: %y, acc_out_ty: ((%getitem_1, -1, 3), None, None, None, None, None, None)})
+    return reshape
+"""
+        assert (
+            str(mod_fixed.graph).strip() == expected_graph.strip()
+        ), f"Unexpected fixed graph. \nActual: {str(mod_fixed.graph)} \nExpected: {expected_graph}"
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/py/torch_tensorrt/fx/test/tracer/test_acc_tracer.py b/py/torch_tensorrt/fx/test/tracer/test_acc_tracer.py
index c3779ef933..23b7329669 100644
--- a/py/torch_tensorrt/fx/test/tracer/test_acc_tracer.py
+++ b/py/torch_tensorrt/fx/test/tracer/test_acc_tracer.py
@@ -2566,6 +2566,31 @@ def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
                 # Make sure we didn't convert to the acc version
                 self.assertEqual(node.target, operator.getitem)
 
+    def test_detach(self):
+        class TestModule(nn.Module):
+            def forward(self, x: torch.Tensor) -> torch.Tensor:
+                return torch.detach(x)
+
+        m = TestModule()
+        sample_inputs = [torch.randn(8)]
+        traced = acc_tracer.trace(m, sample_inputs)
+
+        placeholder = output = None
+        for node in traced.graph.nodes:
+            if node.op == "placeholder":
+                assert placeholder is None
+                placeholder = node
+            elif node.op == "output":
+                assert output is None
+                output = node
+            else:
+                raise RuntimeError(f"Unexpected Node {node.format_node()}")
+
+        self.assertIsNotNone(placeholder)
+        self.assertIsNotNone(output)
+
+        self.assertTrue(torch.equal(m(*sample_inputs), traced(*sample_inputs)))
+
     def test_all_acc_ops_registered(self):
         self.assertEqual(
             acc_normalizer._acc_ops,
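Editor's note (illustrative only, not part of the patch): the new tracer test covers the function form torch.detach(x), which the mapping added in acc_ops.py below routes through the same mapper as the method form, so the node is dropped from the traced graph entirely (detach is a no-op for inference). An eager-level sanity check of the equivalence being relied on:

    import torch

    x = torch.randn(8, requires_grad=True)
    assert torch.equal(torch.detach(x), x.detach())  # function and method forms agree
    assert not torch.detach(x).requires_grad         # detach only severs autograd tracking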
diff --git a/py/torch_tensorrt/fx/test/tracer/test_dispatch_tracer.py b/py/torch_tensorrt/fx/test/tracer/test_dispatch_tracer.py
index ce2832e9a7..5f02051166 100644
--- a/py/torch_tensorrt/fx/test/tracer/test_dispatch_tracer.py
+++ b/py/torch_tensorrt/fx/test/tracer/test_dispatch_tracer.py
@@ -162,7 +162,7 @@ def f(x):
         inputs = [torch.ones(32, 3, 224, 224)]
         inputs = [i.cuda().half() for i in inputs]
         torchdynamo.reset()
-        dynamo_aten_mod = torchdynamo.optimize(backends.fx2trt_aten_compiler_fp16)(mod)
+        dynamo_aten_mod = torchdynamo.optimize(backends.fx2trt_compiler_fp16)(mod)
         dynamo_aten_output = dynamo_aten_mod(*inputs)
 
         torchdynamo.reset()
diff --git a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py
index 8309db3cf3..1bdc7bf704 100644
--- a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py
+++ b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py
@@ -582,7 +582,9 @@ def stack_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node:
 
 @register_acc_op_properties(AccOpProperty.pointwise)
 @register_acc_op_mapping(op_and_target=("call_function", torch.clamp))
+@register_acc_op_mapping(op_and_target=("call_function", torch.clip))
 @register_acc_op_mapping(op_and_target=("call_method", "clamp"))
+@register_acc_op_mapping(op_and_target=("call_method", "clip"))
 @register_acc_op
 def clamp(*, input, min=None, max=None):
     return torch.clamp(input=input, min=min, max=max)
@@ -818,6 +820,10 @@ def matmul(*, input, other):
 @register_custom_acc_mapper_fn(
     op_and_target=("call_method", "detach"), arg_replacement_tuples=[("input", "input")]
 )
+@register_custom_acc_mapper_fn(
+    op_and_target=("call_function", torch.detach),
+    arg_replacement_tuples=[("input", "input")],
+)
 def dropout_mapper(node: torch.fx.Node, mod: nn.Module):
     """
     Remove dropout node and directly map its input to output.
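Editor's note (illustrative only, not part of the patch): the new acc_ops mappings route torch.clip and Tensor.clip to the same acc_ops.clamp op as their clamp counterparts; torch.clip is documented as an alias of torch.clamp, so this only widens the tracer's coverage. An eager-level check of the aliasing:

    import torch

    x = torch.tensor([-2.0, 0.5, 2.0])
    assert torch.equal(torch.clip(x, -1, 1), torch.clamp(x, -1, 1))
    assert torch.equal(x.clip(min=0.0), x.clamp(min=0.0))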