
Commit b6b690c

Wei and wwei6 authored
Changes done internally at Facebook (#1288)
bd46e8f292bf68fe6b87d2d5d206c89fda79a746 Shirong Wu <[email protected]> Disable group ln fuse pass
6ce1d3bc19d75b266e99355c96daeff7054dcbf8 Wei Wei <[email protected]> [fx2trt] set logging level to INFO at fx root
9d552dc3f69db9e4a249f80ef00803a9413e5d38 Wei Wei <[email protected]> [fx2trt] change OSS method lower_to_trt() to compile()

Co-authored-by: wwei6 <[email protected]>
1 parent 6f09709 commit b6b690c

File tree: 11 files changed, +46 −32 lines changed

docsrc/tutorials/getting_started_with_fx_path.rst

Lines changed: 3 additions & 3 deletions
@@ -34,7 +34,7 @@ Torch-TensorRT (FX Path) is in ``Beta`` phase and always recommended to work wit
 
 Converting a PyTorch Model to TensorRT Engine
 ---------------------------------------------
-In general, users are welcome to use the ``lower_to_trt()`` to finish the conversion from a model to tensorRT engine. It is a wrapper API that consists of the major steps needed to finish this converison. Please refer to ``lower_example.py`` file in ``examples/fx``.
+In general, users are welcome to use the ``compile()`` to finish the conversion from a model to tensorRT engine. It is a wrapper API that consists of the major steps needed to finish this converison. Please refer to ``lower_example.py`` file in ``examples/fx``.
 
 In this section, we will go through an example to illustrate the major steps that FX path uses. Users can refer to ``fx2trt_example.py`` file in ``examples/fx``.
 
@@ -60,9 +60,9 @@ symbolically traced variables cannot be used as inputs to control flow
 This means the model contains dynamic control flow. Please refer to section “Dynamic Control Flow” in `FX guide <https://pytorch.org/docs/stable/fx.html#dynamic-control-flow>`_.
 
 * **Step 2: Build TensorRT engine**
-There are `two different modes <https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#explicit-implicit-batch>`_ for how TensorRT handles batch dimension, explicit batch dimension and implicit batch dimension. This mode was used by early versions of TensorRT, and is now deprecated but continues to be supported for backwards compatibility. In explicit batch mode, all dimensions are explicit and can be dynamic, that is their length can change at execution time. Many new features, such as dynamic shapes and loops, are available only in this mode. User can still choose to use implicit batch mode when they set ``explicit_batch_dimension=False`` in ``lower_to_trt()``. We do not recommend to use it since it will lack of support in future TensorRT versions.
+There are `two different modes <https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#explicit-implicit-batch>`_ for how TensorRT handles batch dimension, explicit batch dimension and implicit batch dimension. This mode was used by early versions of TensorRT, and is now deprecated but continues to be supported for backwards compatibility. In explicit batch mode, all dimensions are explicit and can be dynamic, that is their length can change at execution time. Many new features, such as dynamic shapes and loops, are available only in this mode. User can still choose to use implicit batch mode when they set ``explicit_batch_dimension=False`` in ``compile()``. We do not recommend to use it since it will lack of support in future TensorRT versions.
 
-Explicit batch is the default mode and it must be set for dynamic shape. For most of vision task, user can choose to enable ``dynamic_batch`` in ``lower_to_trt()`` if they want to get the similar effects as implicit mode where only batch dimension changes. It has some requirements:
+Explicit batch is the default mode and it must be set for dynamic shape. For most of vision task, user can choose to enable ``dynamic_batch`` in ``compile()`` if they want to get the similar effects as implicit mode where only batch dimension changes. It has some requirements:
 1. Shapes of inputs, outputs and activations are fixed except batch dimension.
 2. Inputs, outputs and activations have batch dimension as the major dimension.
 3. All the operators in the model do not modify batch dimension (permute, transpose, split, etc.) or compute over batch dimension (sum, softmax, etc.).
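For reference, a minimal usage sketch of the renamed entry point described in the docs above (not part of this commit): it assumes torchvision and a CUDA-capable GPU are available, and the keyword arguments shown (max_batch_size, lower_precision, dynamic_batch) are carried over unchanged from the former lower_to_trt() signature.

import torch
import torchvision
from torch_tensorrt.fx.lower import compile  # formerly lower_to_trt
from torch_tensorrt.fx.utils import LowerPrecision

# Hypothetical example model and inputs; any traceable nn.Module works the same way.
model = torchvision.models.resnet18(pretrained=True).cuda().eval()
inputs = [torch.randn(32, 3, 224, 224, device="cuda")]

trt_model = compile(
    model,
    inputs,
    max_batch_size=32,                    # same keyword the benchmark scripts in this diff pass
    lower_precision=LowerPrecision.FP16,  # FP16 path also halves the sample inputs, as in lower.py
    dynamic_batch=True,                   # explicit batch mode with a dynamic batch dimension
)

with torch.no_grad():
    print(trt_model(*inputs).shape)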

examples/fx/lower_example.py

Lines changed: 2 additions & 2 deletions
@@ -4,7 +4,7 @@
 
 import torch
 import torchvision
-from torch_tensorrt.fx.lower import lower_to_trt
+from torch_tensorrt.fx.lower import compile
 from torch_tensorrt.fx.utils import LowerPrecision
 
 
@@ -183,7 +183,7 @@ def run_configuration_benchmark(
         time = benchmark_torch_function(conf.batch_iter, lambda: module(*input))
     elif not conf.jit:
         # Run lowering eager mode benchmark
-        lowered_module = lower_to_trt(
+        lowered_module = compile(
             module,
             input,
             max_batch_size=conf.batch_size,

examples/fx/torchdynamo_example.py

Lines changed: 2 additions & 2 deletions
@@ -5,7 +5,7 @@
 import torch
 import torchdynamo
 import torchvision
-from torch_tensorrt.fx.lower import lower_to_trt
+from torch_tensorrt.fx.lower import compile
 from torch_tensorrt.fx.utils import LowerPrecision
 from torchdynamo.optimizations import backends
 
@@ -197,7 +197,7 @@ def run_configuration_benchmark(
 
     if conf.trt:
         # Run lowering eager mode benchmark
-        lowered_module = lower_to_trt(
+        lowered_module = compile(
             module,
             input,
             max_batch_size=conf.batch_size,

py/torch_tensorrt/_compile.py

Lines changed: 2 additions & 2 deletions
@@ -7,7 +7,7 @@
 from enum import Enum
 
 import torch_tensorrt.fx
-from torch_tensorrt.fx.lower import lower_to_trt
+import torch_tensorrt.fx.lower
 from torch_tensorrt.fx.utils import LowerPrecision
 
 
@@ -140,7 +140,7 @@ def compile(
     else:
         raise ValueError(f"Precision {enabled_precisions} not supported on FX")
 
-    return lower_to_trt(
+    return torch_tensorrt.fx.lower.compile(
        module,
        inputs,
        lower_precision=lower_precision,

py/torch_tensorrt/fx/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -1,4 +1,6 @@
 from .converters import * # noqa: F403 F401
+import logging
+
 from .converter_registry import ( # noqa
     CONVERTERS,
     NO_EXPLICIT_BATCH_DIM_SUPPORT,
@@ -9,3 +11,5 @@
 from .input_tensor_spec import generate_input_specs, InputTensorSpec # noqa
 from .lower_setting import LowerSetting # noqa
 from .trt_module import TRTModule # noqa
+
+logging.basicConfig(level=logging.INFO)
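A brief note on the logging change above (a sketch, not part of the commit): because the package root now calls logging.basicConfig(level=logging.INFO), importing torch_tensorrt.fx configures logging so INFO-level lowering messages are emitted by default; an application that wants a different verbosity can adjust the relevant logger afterwards with the standard logging API.

import logging

import torch_tensorrt.fx  # import triggers logging.basicConfig(level=logging.INFO)

# Turn fx lowering chatter back down (or use logging.DEBUG for more detail).
logging.getLogger("torch_tensorrt.fx").setLevel(logging.WARNING)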

py/torch_tensorrt/fx/lower.py

Lines changed: 22 additions & 18 deletions
@@ -25,7 +25,7 @@
 Input = Sequence[Any]
 
 
-def lower_to_trt(
+def compile(
     module: nn.Module,
     input,
     max_batch_size: int = 2048,
@@ -216,28 +216,32 @@ def create(
             )
         )
 
-    @decorate_method(validate_inference(atol=1e-1, rtol=1e-1))
     def __call__(
         self,
         module: nn.Module,
         inputs: Input,
         additional_inputs: Optional[Input] = None,
     ) -> nn.Module:
-        module.eval()
-
-        if (
-            self.lower_pass_manager_builder.lower_setting.lower_precision
-            == LowerPrecision.FP16
-        ):
-            module.half()
-            inputs = tuple(
-                x.half() if x is not None and x.dtype == torch.float32 else x
-                for x in inputs
+        lower_setting = self.lower_pass_manager_builder.lower_setting
+        atol = lower_setting.correctness_atol
+        rtol = lower_setting.correctness_rtol
+
+        @validate_inference(atol=atol, rtol=rtol)
+        def do_lower(module: nn.Module, inputs: Input) -> nn.Module:
+            module.eval()
+            if (
+                self.lower_pass_manager_builder.lower_setting.lower_precision
+                == LowerPrecision.FP16
+            ):
+                module.half()
+                inputs = tuple(
+                    x.half() if x is not None and x.dtype == torch.float32 else x
+                    for x in inputs
+                )
+            pm = self.lower_pass_manager_builder.build_trt_lower_pipeline(
+                inputs, additional_inputs
             )
-        pm = self.lower_pass_manager_builder.build_trt_lower_pipeline(
-            inputs, additional_inputs
-        )
-
-        lower_result = pm(module)
+            lower_result = pm(module)
+            return lower_result
 
-        return lower_result
+        return do_lower(module, inputs)

py/torch_tensorrt/fx/lower_setting.py

Lines changed: 4 additions & 0 deletions
@@ -70,6 +70,8 @@ class LowerSetting(LowerSettingBasic):
     dynamic_batch: enable the dynamic shape in TRT with dim=-1 for the 1st dimension.
     tactic_sources: tactic sources for TensorRT kernel selection. Default to None,
                     meaning all possible tactic sources.
+    correctness_atol: absolute tolerance for correctness check
+    correctness_rtol: relative tolerance for correctness check
     """
 
     input_specs: List[InputTensorSpec] = dc.field(default_factory=list)
@@ -90,3 +92,5 @@ class LowerSetting(LowerSettingBasic):
     opt_profile_replica: int = 1
     dynamic_batch: bool = True
     tactic_sources: Optional[int] = None
+    correctness_atol: float = 0.1
+    correctness_rtol: float = 0.1
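For illustration, a hedged sketch of how the new tolerance fields could be used (construction details beyond the dataclass fields in this diff are assumptions): the 0.1 defaults match the hard-coded atol=1e-1 / rtol=1e-1 of the decorator removed from lower.py, and __call__ now reads the values from the active LowerSetting.

from torch_tensorrt.fx.lower_setting import LowerSetting

# Tighten the post-lowering correctness check from the 0.1 defaults.
setting = LowerSetting(
    correctness_atol=1e-2,  # absolute tolerance used by validate_inference
    correctness_rtol=1e-2,  # relative tolerance used by validate_inference
)
print(setting.correctness_atol, setting.correctness_rtol)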

py/torch_tensorrt/fx/test/converters/acc_op/test_type_as.py

Lines changed: 2 additions & 1 deletion
@@ -1,5 +1,6 @@
-import torch
 import unittest
+
+import torch
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from torch.testing._internal.common_utils import run_tests
 from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec

py/torch_tensorrt/fx/test/passes/test_graph_opts.py

Lines changed: 0 additions & 2 deletions
@@ -8,8 +8,6 @@
 import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer
 from torch_tensorrt.fx.passes.graph_opts import common_subexpression_elimination
 
-_LOGGER: logging.Logger = logging.getLogger(__name__)
-
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
py/torch_tensorrt/fx/test/trt_lower/test_diagnostics.py

Lines changed: 0 additions & 2 deletions
@@ -10,8 +10,6 @@
 
 import torch_tensorrt.fx.diagnostics as diag
 
-_LOGGER: logging.Logger = logging.getLogger(__name__)
-
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
py/torch_tensorrt/fx/tools/trt_profiler_sorted.py

Lines changed: 5 additions & 0 deletions
@@ -37,6 +37,11 @@ def profile_trt_module(
     layer_info = json.loads(trt_mod.get_layer_info())  # pyre-ignore[29]
     shape_map = {}
     for layer in layer_info["Layers"]:
+        # if type is str, it means verbose_profile is off in interpreter.run()
+        # Theorectically, we can print profiling information without shape information
+        # but we choose to not print profiling information so we can use verbose_profile to control it
+        if type(layer) is str:
+            return
         name = layer["Name"]
         input_str = ", ".join(
             [str(x.get("Dimensions", "[]")) for x in layer.get("Inputs", [])]
