
Commit dfbf9b5

Author: Wei
Merge pull request #1143 from pytorch/fb-sync-wwei6
[fx2trt] Engineholder feature improvement, test fixes
2 parents: 17898d0 + a9d36f2

83 files changed: +504 / -196 lines
(Large commits hide some content by default; only a subset of the changed files is shown below.)

.circleci/config.yml

Lines changed: 1 addition & 1 deletion
@@ -47,7 +47,7 @@ jobs:
           command: |
             pip3 install nvidia-pyindex
             pip3 install nvidia-tensorrt==8.2.4.2
-            pip3 install --pre torch==1.13.0.dev20220618 torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu113
+            pip3 install --pre torch==1.13.0.dev20220621 torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu113
             pip3 install pytest parameterized expecttest
             # install torch_tensorrt
             mv WORKSPACE.ci WORKSPACE

py/torch_tensorrt/fx/converters/acc_ops_converters.py

Lines changed: 6 additions & 4 deletions
@@ -92,10 +92,12 @@ def acc_ops_conv1d(
         kernel=weight,
         bias=bias,
     )
-    padding = kwargs["padding"]
-    padding = padding + (0,)
-    stride = extend_attr_to_tuple(kwargs["stride"], 1)
-    dilation = extend_attr_to_tuple(kwargs["dilation"], 1)
+    # expand params to 2d for computation
+    padding = list(kwargs["padding"])
+    padding.append(0)
+    stride = extend_attr_to_tuple(kwargs["stride"], 2)
+    dilation = extend_attr_to_tuple(kwargs["dilation"], 2)
+
     set_layer_name(layer, target, name)
     layer.stride_nd = stride
     layer.padding_nd = padding

py/torch_tensorrt/fx/converters/convolution.py

Lines changed: 3 additions & 3 deletions
@@ -32,13 +32,13 @@ def common_conv(network, mod, dimension, input_val, layer_name, is_quantized):
         unsqueeze_layer.name = f"{layer_name}_unsqueeze"
         input_val = unsqueeze_layer.get_output(0)

-        padding = padding + (0,)
         kernel = np.expand_dims(kernel, -1)
         kernel_size = kernel.shape[2:]
         if bias is not None:
             bias = bias[None]
-            # bias = np.expand_dims(bias, -1)
-
+        stride = (stride[0], 1)
+        padding = (padding[0], 0)
+        dilation = (dilation[0], 1)
     layer = network.add_convolution_nd(
         input=input_val,
         num_output_maps=mod.out_channels,
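
Both conv1d changes follow the same idea: the 1D convolution is lowered through TensorRT's N-dimensional convolution layer by unsqueezing the input to add a trailing singleton dimension and extending every spatial attribute to rank 2, so stride, padding, dilation, and kernel all stay rank-consistent. A minimal sketch of that parameter expansion in plain Python/NumPy (toy values, not the converter's actual helpers):

```python
import numpy as np

# Toy conv1d attributes as they might arrive from the traced graph
padding, stride, dilation = (2,), (1,), (1,)
weight = np.ones((8, 4, 3), dtype=np.float32)  # (out_channels, in_channels, kernel)

# Extend each spatial attribute to rank 2 so add_convolution_nd sees a 2D conv
padding_2d = list(padding) + [0]        # no padding along the dummy dimension
stride_2d = (stride[0], 1)              # stride 1 along the dummy dimension
dilation_2d = (dilation[0], 1)
kernel_2d = np.expand_dims(weight, -1)  # shape becomes (8, 4, 3, 1)

print(padding_2d, stride_2d, dilation_2d, kernel_2d.shape)
# [2, 0] (1, 1) (1, 1) (8, 4, 3, 1)
```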

py/torch_tensorrt/fx/input_tensor_spec.py

Lines changed: 5 additions & 2 deletions
@@ -66,7 +66,10 @@ def from_tensors(cls, tensors: Sequence[torch.Tensor]) -> List["InputTensorSpec"]:

     @classmethod
     def from_tensors_with_dynamic_batch_size(
-        cls, tensors: Sequence[torch.Tensor], batch_size_range: Tuple[int, int, int]
+        cls,
+        tensors: Sequence[torch.Tensor],
+        batch_size_range: Tuple[int, int, int],
+        opt_profile_replica: int = 1,
     ) -> List["InputTensorSpec"]:
         """
         Produce a list of InputTenosrSpec named tuples which would contain
@@ -93,7 +96,7 @@ def from_tensors_with_dynamic_batch_size(
             ), f"The {i}th tensor (shape: {tensor.shape}) doesn't have the correct batch size: {batch_size}."
             shape = list(tensor.shape)
             shape[0] = -1
-            shape_ranges: List[ShapeRange] = [tuple(tuple([bs] + shape[1:]) for bs in batch_size_range)]  # type: ignore[list-item]
+            shape_ranges: List[ShapeRange] = [tuple(tuple([bs] + shape[1:]) for bs in batch_size_range)] * opt_profile_replica  # type: ignore[list-item]
             input_specs.append(
                 cls(tuple(shape), tensor.dtype, tensor.device, shape_ranges)
             )
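
The new `opt_profile_replica` argument only replicates the per-tensor shape range, so the interpreter can later register more than one optimization profile from the same spec. A standalone sketch of that one changed line (the tensor shape and batch range below are illustrative):

```python
import torch

tensor = torch.randn(32, 3, 224, 224)  # illustrative input tensor
batch_size_range = (1, 32, 64)          # (min, opt, max) batch size -- illustrative
opt_profile_replica = 2

shape = list(tensor.shape)
shape[0] = -1                           # mark the batch dimension as dynamic
shape_ranges = [tuple(tuple([bs] + shape[1:]) for bs in batch_size_range)] * opt_profile_replica

print(len(shape_ranges))                # 2 identical (min, opt, max) entries
print(shape_ranges[0])
# ((1, 3, 224, 224), (32, 3, 224, 224), (64, 3, 224, 224))
```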

py/torch_tensorrt/fx/lower.py

Lines changed: 1 addition & 0 deletions
@@ -100,6 +100,7 @@ def __call__(self, mod, input, split_name) -> TRTInterpreterResult:
                     self.lower_setting.max_batch_size,
                     self.lower_setting.max_batch_size,
                 ),
+                self.lower_setting.opt_profile_replica,
             )
             if self.lower_setting.explicit_batch_dimension
             else InputTensorSpec.from_tensors(input)
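
Here the lowering workflow simply forwards `lower_setting.opt_profile_replica` into the spec builder. A hedged usage sketch of the updated classmethod, assuming the import path mirrors the file location `py/torch_tensorrt/fx/input_tensor_spec.py` (the input shape and batch range are illustrative):

```python
import torch
from torch_tensorrt.fx.input_tensor_spec import InputTensorSpec

inputs = [torch.randn(8, 3, 224, 224)]  # illustrative input batch
specs = InputTensorSpec.from_tensors_with_dynamic_batch_size(
    inputs,
    (1, 8, 16),             # (min, opt, max) batch size -- illustrative
    opt_profile_replica=2,  # request two copies of the optimization profile
)
print(specs[0].shape)              # (-1, 3, 224, 224): dynamic batch dimension
print(len(specs[0].shape_ranges))  # 2
```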

py/torch_tensorrt/fx/lower_setting.py

Lines changed: 3 additions & 0 deletions
@@ -69,6 +69,8 @@ class LowerSetting(LowerSettingBasic):
     how presets are applied. Refer to
     `caffe2.torch.fb.model_transform.fx2trt.presets.ESUHMLowererPreset` on how
     to add a preset.
+    opt_profile_replica (int): the number of opt profile set for TensorRT engine, this field is
+                               only used by explicit batch dim with dynamic shape mode.
     """

     input_specs: List[InputTensorSpec] = dc.field(default_factory=list)
@@ -86,3 +88,4 @@ class LowerSetting(LowerSettingBasic):
     save_timing_cache: bool = False
     cuda_graph_batch_size: int = -1
     preset_lowerer: str = ""
+    opt_profile_replica: int = 1
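
With the new field in place, a lowering configuration that asks for extra optimization profiles might look like the sketch below. `explicit_batch_dimension` and `max_batch_size` are the existing settings referenced by the lowering code above; the values are illustrative:

```python
from torch_tensorrt.fx.lower_setting import LowerSetting

# Illustrative settings: explicit batch dimension with dynamic shapes,
# asking the engine builder for two replicas of the optimization profile.
setting = LowerSetting(
    explicit_batch_dimension=True,
    max_batch_size=64,
    opt_profile_replica=2,
)
print(setting.opt_profile_replica)  # 2
```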

py/torch_tensorrt/fx/passes/lower_basic_pass.py

Lines changed: 1 addition & 3 deletions
@@ -31,9 +31,7 @@ def skip_folding_quant_dequant(node: torch.fx.Node):
             return True
         return False

-    const_split_mod = split_const_subgraphs(
-        traced_mod, skip_folding_quant_dequant, device_for_folded_attrs="cuda"
-    )
+    const_split_mod = split_const_subgraphs(traced_mod, skip_folding_quant_dequant)
     const_split_mod.run_folding()
     return const_split_mod
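
For reference, `split_const_subgraphs` is the constant-folding entry point from `torch.fx.experimental.const_fold`; this change simply drops the explicit `device_for_folded_attrs` argument. A minimal, hedged sketch of the folding flow on a toy module (the module and shapes are illustrative):

```python
import torch
from torch.fx.experimental.const_fold import split_const_subgraphs


class AddConst(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("w", torch.ones(4))

    def forward(self, x):
        # (self.w + 1) does not depend on x, so it is a foldable constant subgraph
        return x + (self.w + 1)


traced = torch.fx.symbolic_trace(AddConst())
folded = split_const_subgraphs(traced)  # no skip callback: fold whatever is constant
folded.run_folding()                    # compute and store the folded constants
print(folded(torch.zeros(4)))           # -> a tensor of 2s
```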

py/torch_tensorrt/fx/test/converters/acc_op/test_adaptive_avgpool.py

Lines changed: 1 addition & 1 deletion
@@ -1,8 +1,8 @@
 import torch
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from parameterized import parameterized
-from torch.testing._internal.common_fx2trt import AccTestCase, InputTensorSpec
 from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec


 class TestAdaptiveAvgPoolConverter(AccTestCase):
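
All of the test edits in this commit are the same mechanical move: the shared `AccTestCase`/`InputTensorSpec` helpers now live in `torch_tensorrt.fx.tools.common_fx2trt` instead of `torch.testing._internal.common_fx2trt`. A hedged sketch of a converter test written against the relocated helper, assuming `AccTestCase.run_test(mod, inputs, expected_ops=...)` as used throughout these suites (the module under test and the expected op are illustrative, and running it requires a TensorRT-capable GPU):

```python
import torch
import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase


class TestReluConverter(AccTestCase):
    def test_relu(self):
        class TestModule(torch.nn.Module):
            def forward(self, x):
                return torch.relu(x)

        inputs = [torch.randn(1, 3, 8, 8)]
        # Lowers the module with fx2trt, runs it under TensorRT, and compares
        # the result against eager PyTorch while checking the traced acc op.
        self.run_test(TestModule(), inputs, expected_ops={acc_ops.relu})


if __name__ == "__main__":
    run_tests()
```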

py/torch_tensorrt/fx/test/converters/acc_op/test_any.py

Lines changed: 1 addition & 1 deletion
@@ -2,8 +2,8 @@
 import torch.nn as nn
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from parameterized import parameterized
-from torch.testing._internal.common_fx2trt import AccTestCase
 from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase


 class TestAnyConverters(AccTestCase):

py/torch_tensorrt/fx/test/converters/acc_op/test_as_strided.py

Lines changed: 1 addition & 1 deletion
@@ -2,8 +2,8 @@
 import torch.nn as nn
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from parameterized import parameterized
-from torch.testing._internal.common_fx2trt import AccTestCase
 from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase


 class TestConverter(AccTestCase):

py/torch_tensorrt/fx/test/converters/acc_op/test_avgpool.py

Lines changed: 1 addition & 1 deletion
@@ -1,8 +1,8 @@
 import torch
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from parameterized import param, parameterized
-from torch.testing._internal.common_fx2trt import AccTestCase
 from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase


 class TestAvgPoolConverter(AccTestCase):

py/torch_tensorrt/fx/test/converters/acc_op/test_batchnorm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import torch
22
import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
3-
from torch.testing._internal.common_fx2trt import AccTestCase, InputTensorSpec
43
from torch.testing._internal.common_utils import run_tests
4+
from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec
55

66

77
class TestBatchNormConverter(AccTestCase):

py/torch_tensorrt/fx/test/converters/acc_op/test_binary_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55

66
import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
77
from parameterized import parameterized
8-
from torch.testing._internal.common_fx2trt import AccTestCase, InputTensorSpec
98
from torch.testing._internal.common_utils import run_tests
9+
from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec
1010

1111
NEED_TEST_BOTH_CONSTANTS_CASE = True
1212

py/torch_tensorrt/fx/test/converters/acc_op/test_cat.py

Lines changed: 1 addition & 1 deletion
@@ -1,8 +1,8 @@
 import torch
 import torch.nn as nn
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
-from torch.testing._internal.common_fx2trt import AccTestCase, InputTensorSpec
 from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec


 class TestCatConverter(AccTestCase):

py/torch_tensorrt/fx/test/converters/acc_op/test_chunk.py

Lines changed: 1 addition & 1 deletion
@@ -2,8 +2,8 @@
 import torch.nn as nn
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from parameterized import parameterized
-from torch.testing._internal.common_fx2trt import AccTestCase, InputTensorSpec
 from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec


 class TestChunkConverter(AccTestCase):

py/torch_tensorrt/fx/test/converters/acc_op/test_clamp.py

Lines changed: 1 addition & 1 deletion
@@ -1,8 +1,8 @@
 import torch
 import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
 from parameterized import param, parameterized
-from torch.testing._internal.common_fx2trt import AccTestCase
 from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase


 class TestClampConverter(AccTestCase):
