Skip to content

Commit 32a6e67

Browse files
peri044 and zewenli98 authored
chore: Upgrade TensorRT version to TRT 10 EA (#2699)
Co-authored-by: Evan Li <[email protected]>
1 parent d859859 commit 32a6e67

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

63 files changed

+433
-392
lines changed

.github/scripts/install-torch-tensorrt.sh

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,13 @@ source ${BUILD_ENV_FILE}
55
${CONDA_RUN} ${PIP_INSTALL_TORCH} torchvision
66
${CONDA_RUN} python -m pip install pyyaml mpmath==1.3.0
77
export TRT_VERSION=$(${CONDA_RUN} python -c "import versions; versions.tensorrt_version()")
8-
${CONDA_RUN} python -m pip install /opt/torch-tensorrt-builds/torch_tensorrt*+${CU_VERSION}*.whl tensorrt~=${TRT_VERSION} tensorrt-bindings~=${TRT_VERSION} --extra-index-url=https://pypi.ngc.nvidia.com
8+
9+
# Install TensorRT manually
10+
wget -q -P /opt/torch-tensorrt-builds/ https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.0.0/TensorRT-10.0.0.6.Linux.x86_64-gnu.cuda-12.4.tar.gz
11+
tar -xzf /opt/torch-tensorrt-builds/TensorRT-10.0.0.6.Linux.x86_64-gnu.cuda-12.4.tar.gz -C /opt/torch-tensorrt-builds/
12+
python -m pip install /opt/torch-tensorrt-builds/TensorRT-10.0.0.6/python/tensorrt-10.0.0b6-cp${PYTHON_VERSION//./}-none-linux_x86_64.whl
13+
14+
# Install Torch-TensorRT
15+
${CONDA_RUN} python -m pip install /opt/torch-tensorrt-builds/torch_tensorrt*+${CU_VERSION}*.whl
916

1017
echo -e "Running test script";

.github/workflows/build-test.yml

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ on:
1515

1616
jobs:
1717
generate-matrix:
18-
uses: pytorch/test-infra/.github/workflows/generate_binary_build_matrix.yml@main
18+
uses: pytorch/test-infra/.github/workflows/generate_binary_build_matrix.yml@release/2.3
1919
with:
2020
package-type: wheel
2121
os: linux
@@ -37,11 +37,11 @@ jobs:
3737
- repository: pytorch/tensorrt
3838
pre-script: packaging/pre_build_script.sh
3939
env-var-script: packaging/env_vars.txt
40-
post-script: ""
41-
smoke-test-script: ""
40+
post-script: packaging/post_build_script.sh
41+
smoke-test-script: packaging/smoke_test_script.sh
4242
package-name: torch_tensorrt
4343
name: Build torch-tensorrt whl package
44-
uses: pytorch/test-infra/.github/workflows/build_wheels_linux.yml@main
44+
uses: pytorch/test-infra/.github/workflows/build_wheels_linux.yml@release/2.3
4545
with:
4646
repository: ${{ matrix.repository }}
4747
ref: ""
@@ -65,7 +65,8 @@ jobs:
6565
- repository: pytorch/tensorrt
6666
package-name: torch_tensorrt
6767
pre-script: packaging/pre_build_script.sh
68-
uses: pytorch/tensorrt/.github/workflows/linux-test.yml@main
68+
post-script: packaging/post_build_script.sh
69+
uses: pytorch/tensorrt/.github/workflows/linux-test.yml@release/2.3
6970
with:
7071
job-name: tests-py-torchscript-fe
7172
repository: "pytorch/tensorrt"
@@ -77,9 +78,11 @@ jobs:
7778
script: |
7879
export USE_HOST_DEPS=1
7980
export LD_LIBRARY_PATH=/usr/lib64:$LD_LIBRARY_PATH
81+
export LD_LIBRARY_PATH=/opt/torch-tensorrt-builds/TensorRT-10.0.0.6/lib:$LD_LIBRARY_PATH
8082
pushd .
8183
cd tests/modules
82-
${CONDA_RUN} python -m pip install --pre -r requirements.txt --use-deprecated=legacy-resolver
84+
# Don't use requirements.txt here as it contains tensorrt and torch which should have been installed by now.
85+
${CONDA_RUN} python -m pip install numpy packaging pyyaml transformers timm pybind11==2.6.2
8386
${CONDA_RUN} python hub.py
8487
popd
8588
pushd .
@@ -100,7 +103,8 @@ jobs:
100103
- repository: pytorch/tensorrt
101104
package-name: torch_tensorrt
102105
pre-script: packaging/pre_build_script.sh
103-
uses: pytorch/tensorrt/.github/workflows/linux-test.yml@main
106+
post-script: packaging/post_build_script.sh
107+
uses: pytorch/tensorrt/.github/workflows/linux-test.yml@release/2.3
104108
with:
105109
job-name: tests-py-dynamo-converters
106110
repository: "pytorch/tensorrt"
@@ -111,6 +115,7 @@ jobs:
111115
pre-script: ${{ matrix.pre-script }}
112116
script: |
113117
export USE_HOST_DEPS=1
118+
export LD_LIBRARY_PATH=/opt/torch-tensorrt-builds/TensorRT-10.0.0.6/lib:$LD_LIBRARY_PATH
114119
pushd .
115120
cd tests/py/dynamo
116121
${CONDA_RUN} python -m pip install --pre pytest-xdist timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
@@ -127,7 +132,8 @@ jobs:
127132
- repository: pytorch/tensorrt
128133
package-name: torch_tensorrt
129134
pre-script: packaging/pre_build_script.sh
130-
uses: pytorch/tensorrt/.github/workflows/linux-test.yml@main
135+
post-script: packaging/post_build_script.sh
136+
uses: pytorch/tensorrt/.github/workflows/linux-test.yml@release/2.3
131137
with:
132138
job-name: tests-py-dynamo-fe
133139
repository: "pytorch/tensorrt"
@@ -138,6 +144,7 @@ jobs:
138144
pre-script: ${{ matrix.pre-script }}
139145
script: |
140146
export USE_HOST_DEPS=1
147+
export LD_LIBRARY_PATH=/opt/torch-tensorrt-builds/TensorRT-10.0.0.6/lib:$LD_LIBRARY_PATH
141148
pushd .
142149
cd tests/py/dynamo
143150
${CONDA_RUN} python -m pip install --pre pytest timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
@@ -155,7 +162,8 @@ jobs:
155162
- repository: pytorch/tensorrt
156163
package-name: torch_tensorrt
157164
pre-script: packaging/pre_build_script.sh
158-
uses: pytorch/tensorrt/.github/workflows/linux-test.yml@main
165+
post-script: packaging/post_build_script.sh
166+
uses: pytorch/tensorrt/.github/workflows/linux-test.yml@release/2.3
159167
with:
160168
job-name: tests-py-dynamo-serde
161169
repository: "pytorch/tensorrt"
@@ -166,6 +174,7 @@ jobs:
166174
pre-script: ${{ matrix.pre-script }}
167175
script: |
168176
export USE_HOST_DEPS=1
177+
export LD_LIBRARY_PATH=/opt/torch-tensorrt-builds/TensorRT-10.0.0.6/lib:$LD_LIBRARY_PATH
169178
pushd .
170179
cd tests/py/dynamo
171180
${CONDA_RUN} python -m pip install --pre pytest timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
@@ -182,7 +191,8 @@ jobs:
182191
- repository: pytorch/tensorrt
183192
package-name: torch_tensorrt
184193
pre-script: packaging/pre_build_script.sh
185-
uses: pytorch/tensorrt/.github/workflows/linux-test.yml@main
194+
post-script: packaging/post_build_script.sh
195+
uses: pytorch/tensorrt/.github/workflows/linux-test.yml@release/2.3
186196
with:
187197
job-name: tests-py-torch-compile-be
188198
repository: "pytorch/tensorrt"
@@ -193,6 +203,7 @@ jobs:
193203
pre-script: ${{ matrix.pre-script }}
194204
script: |
195205
export USE_HOST_DEPS=1
206+
export LD_LIBRARY_PATH=/opt/torch-tensorrt-builds/TensorRT-10.0.0.6/lib:$LD_LIBRARY_PATH
196207
pushd .
197208
cd tests/py/dynamo
198209
${CONDA_RUN} python -m pip install --pre pytest-xdist timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
@@ -211,7 +222,8 @@ jobs:
211222
- repository: pytorch/tensorrt
212223
package-name: torch_tensorrt
213224
pre-script: packaging/pre_build_script.sh
214-
uses: pytorch/tensorrt/.github/workflows/linux-test.yml@main
225+
post-script: packaging/post_build_script.sh
226+
uses: pytorch/tensorrt/.github/workflows/linux-test.yml@release/2.3
215227
with:
216228
job-name: tests-py-dynamo-core
217229
repository: "pytorch/tensorrt"
@@ -222,6 +234,7 @@ jobs:
222234
pre-script: ${{ matrix.pre-script }}
223235
script: |
224236
export USE_HOST_DEPS=1
237+
export LD_LIBRARY_PATH=/opt/torch-tensorrt-builds/TensorRT-10.0.0.6/lib:$LD_LIBRARY_PATH
225238
pushd .
226239
cd tests/py/dynamo
227240
${CONDA_RUN} python -m pip install --pre pytest-xdist timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
@@ -251,6 +264,7 @@ jobs:
251264
pre-script: ${{ matrix.pre-script }}
252265
script: |
253266
export USE_HOST_DEPS=1
267+
export LD_LIBRARY_PATH=/opt/torch-tensorrt-builds/TensorRT-10.0.0.6/lib:$LD_LIBRARY_PATH
254268
pushd .
255269
cd tests/py/core
256270
${CONDA_RUN} python -m pip install --pre pytest-xdist timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,10 +116,10 @@ torch.jit.save(trt_ts_module, "trt_torchscript_module.ts") # save the TRT embedd
116116
The following dependencies were used to verify the test cases. Torch-TensorRT can work with other versions, but the tests are not guaranteed to pass.
117117

118118
- Bazel 5.2.0
119-
- Libtorch 2.3.0.dev (latest nightly) (built with CUDA 12.1)
119+
- Libtorch 2.3.0 (built with CUDA 12.1)
120120
- CUDA 12.1
121121
- cuDNN 8.9.5
122-
- TensorRT 8.6.1
122+
- TensorRT 10.0.0.6
123123

124124
## Prebuilt Binaries and Wheel files
125125

core/conversion/converters/converter_util.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,12 @@ nvinfer1::ITensor* addPadding(
3939
}
4040
}
4141

42+
// Returns the shape of `input_tensor` as a shape tensor cast to INT32.
//
// ctx          - conversion context; the shape layer is added to `ctx->net`.
// input_tensor - tensor whose shape is queried; must not be null.
// name         - forwarded to castITensor as the layer-name prefix for the cast and
//                included in the error message if the shape layer cannot be built.
//
// Fails through TORCHTRT_CHECK (rather than dereferencing a null pointer) when the
// input tensor is null or the network cannot create the shape layer.
nvinfer1::ITensor* getShapeOutput(ConversionCtx* ctx, nvinfer1::ITensor* input_tensor, const std::string& name) {
  TORCHTRT_CHECK(input_tensor, "Unable to create shape layer (" << name << "): input tensor is null");

  // addShape can return nullptr on failure; check before touching its output.
  auto* shape_layer = ctx->net->addShape(*input_tensor);
  TORCHTRT_CHECK(shape_layer, "Unable to create shape layer: " << name);

  // Callers combine this result with INT32 constant tensors (elementwise / concat),
  // so the shape is always handed back as INT32.
  return castITensor(ctx, shape_layer->getOutput(0), nvinfer1::DataType::kINT32, name);
}
47+
4248
nvinfer1::ITensor* addUnpadding(
4349
ConversionCtx* ctx,
4450
const torch::jit::Node* n,
@@ -134,7 +140,7 @@ nvinfer1::ILayer* add_elementwise(
134140
}
135141
auto otherStaticShapeMask = tensor_to_const(ctx, thOtherStaticShapeMask);
136142
auto otherDynamicShapeMask = tensor_to_const(ctx, thOtherDynamicShapeMask);
137-
auto selfShape = ctx->net->addShape(*self)->getOutput(0);
143+
nvinfer1::ITensor* selfShape = getShapeOutput(ctx, self, std::string(name + "_shape_cast").c_str());
138144
// size of dynamic dimension of other needs to be the same as that of
139145
// corresponding dimension of self
140146
auto otherDynamicShape =
@@ -348,7 +354,6 @@ nvinfer1::ITensor* normalize_indices(
348354
auto neg_itensor = tensor_to_const(ctx, neg);
349355
// find the indices that = -1
350356
auto signs = clamp(ctx, indices, neg_itensor, zero_itensor, "clamp layer for " + name);
351-
352357
// get the inputDim value where indices == -1, else 0
353358
auto mul = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPROD, signs, input_dim, "prod layer for " + name);
354359
TORCHTRT_CHECK(mul, "Unable to create mul layer in normalize_indices");

core/conversion/converters/converter_util.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,9 @@ nvinfer1::ITensor* castITensor(
6262
nvinfer1::DataType dtype,
6363
const std::string& layer_name_prefix = "");
6464

65+
// Get the shape of the input tensor and cast it to INT32 type.
// Adds a shape layer for `input_tensor` to `ctx->net` and passes its output through
// castITensor; `name` is forwarded as the cast's layer-name prefix (empty by default).
66+
nvinfer1::ITensor* getShapeOutput(ConversionCtx* ctx, nvinfer1::ITensor* input_tensor, const std::string& name = "");
67+
6568
// Freeze an at::Tensor in a IConstant layer
6669
nvinfer1::ITensor* tensor_to_const(ConversionCtx* ctx, at::Tensor t, const std::string& name = std::string());
6770

core/conversion/converters/impl/chunk.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ auto cat_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns()
1717
auto chunks = args[1].unwrapToInt();
1818
auto dim = args[2].unwrapToInt();
1919
bool dynamic_shape = ctx->input_is_dynamic;
20-
int size = in->getDimensions().nbDims;
2120
int maxDim = static_cast<int32_t>(in->getDimensions().d[dim]);
2221

2322
c10::ListTypePtr lt = n->output()->type()->expect<c10::ListType>();
@@ -41,9 +40,6 @@ auto cat_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns()
4140
size_.nbDims = nbdims;
4241
stride_.nbDims = nbdims;
4342

44-
int startIdx = 0;
45-
int endIdx = maxDim;
46-
4743
for (int i = 0; i < nbdims; i++) {
4844
start_.d[i] = 0;
4945
size_.d[i] = 0;

core/conversion/converters/impl/constant_pad.cpp

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -55,18 +55,15 @@ auto constant_pad_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns
5555
util::toDims(c10::IntArrayRef(stride)));
5656
TORCHTRT_CHECK(slice_layer, "Unable to create slice layer from node: " << *n);
5757
slice_layer->setName((util::node_info(n) + "_slice").c_str());
58-
slice_layer->setMode(nvinfer1::SliceMode::kFILL);
58+
slice_layer->setMode(nvinfer1::SampleMode::kFILL);
5959
slice_layer->setInput(4, *value_itensor);
6060

6161
if (ctx->input_is_dynamic) {
6262
// build the size using inetwork layers
63-
auto shape_layer = ctx->net->addShape(*in);
64-
TORCHTRT_CHECK(shape_layer, "Unable to create shape layer from node: " << *n);
65-
shape_layer->setName((util::node_info(n) + "_shape").c_str());
6663
auto total_padding_itensor = tensor_to_const(ctx, torch::tensor(total_padding, torch::kInt32));
67-
68-
auto add_layer = ctx->net->addElementWise(
69-
*shape_layer->getOutput(0), *total_padding_itensor, nvinfer1::ElementWiseOperation::kSUM);
64+
nvinfer1::ITensor* shapeOutput = getShapeOutput(ctx, in, (util::node_info(n) + "_shape").c_str());
65+
auto add_layer =
66+
ctx->net->addElementWise(*shapeOutput, *total_padding_itensor, nvinfer1::ElementWiseOperation::kSUM);
7067
TORCHTRT_CHECK(add_layer, "Unable to create add layer from node: " << *n);
7168
add_layer->setName((util::node_info(n) + "_add").c_str());
7269
slice_layer->setInput(2, *add_layer->getOutput(0));

core/conversion/converters/impl/conv_deconv.cpp

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ nvinfer1::ILayer* add_bias_layer(
3333
nvinfer1::Dims& input_dims,
3434
nvinfer1::Dims& output_padding,
3535
Weights& bias) {
36-
nvinfer1::ITensor* input_shape = ctx->net->addShape(*input_tensor)->getOutput(0);
36+
nvinfer1::ITensor* input_shape = getShapeOutput(ctx, input_tensor, std::string("bias_shape_cast").c_str());
3737
// Add padding layer
3838
nvinfer1::ITensor* start;
3939
nvinfer1::ITensor* totalPadding;
@@ -61,7 +61,7 @@ nvinfer1::ILayer* add_bias_layer(
6161
auto* sliceLayer = ctx->net->addSlice(*input_tensor, dummy, dummy, stride);
6262
sliceLayer->setInput(1, *start);
6363
sliceLayer->setInput(2, *size);
64-
sliceLayer->setMode(nvinfer1::SliceMode::kFILL);
64+
sliceLayer->setMode(nvinfer1::SampleMode::kFILL);
6565
nvinfer1::ITensor* slice_output = sliceLayer->getOutput(0);
6666

6767
nvinfer1::Dims constantDims;
@@ -146,9 +146,9 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args)
146146
// TensorRT expects nbSpatialDims = 2 or 3
147147
filter_dim = util::unsqueezeDims(filter_dim, filter_dim.nbDims, 1, false);
148148
// Reshape input dimensions
149-
in = addPadding(ctx, n, in, 4);
149+
in = addPadding(ctx, n, in, 4, true, true, std::string(util::node_info(n) + "_input_shuffle"));
150150
LOG_DEBUG("Reshaping input dimensions to: " << in->getDimensions());
151-
kernel = addPadding(ctx, n, kernel, 4);
151+
kernel = addPadding(ctx, n, kernel, 4, true, true, std::string(util::node_info(n) + "_kernel_shuffle"));
152152
LOG_DEBUG("Reshaping kernel dimensions to: " << kernel->getDimensions());
153153
if (transposed) {
154154
num_output_maps = kernel_dims.d[1];
@@ -194,7 +194,7 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args)
194194
nvinfer1::IConvolutionLayer* convLayer =
195195
ctx->net->addConvolutionNd(*in, num_output_maps, filter_dim, kernel_weights, bias.data);
196196
convLayer->setStrideNd(stride);
197-
convLayer->setPaddingMode(nvinfer1::PaddingMode::kCAFFE_ROUND_DOWN);
197+
convLayer->setPaddingMode(nvinfer1::PaddingMode::kEXPLICIT_ROUND_DOWN);
198198
convLayer->setPaddingNd(padding);
199199
convLayer->setPostPadding(out_padding);
200200
convLayer->setDilationNd(dilation);
@@ -291,11 +291,9 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args)
291291
// shape of convolution's weight: [out, in/groups, ...]
292292
auto conv = ctx->net->addConvolutionNd(*in, w.shape.d[0], w.kernel_shape, w.data, bias.data);
293293
TORCHTRT_CHECK(conv, "Unable to create convolution layer from node: " << *n);
294-
295294
conv->setStrideNd(stride);
296-
conv->setPaddingMode(nvinfer1::PaddingMode::kCAFFE_ROUND_DOWN);
295+
conv->setPaddingMode(nvinfer1::PaddingMode::kEXPLICIT_ROUND_DOWN);
297296
conv->setPaddingNd(padding);
298-
conv->setPostPadding(out_padding);
299297
conv->setDilationNd(dilation);
300298
conv->setNbGroups(groups);
301299
new_layer = conv;

core/conversion/converters/impl/cumsum.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ auto cumsum_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pat
3636
torch::Tensor axis = torch::tensor(input_dims.d[dim], torch::kInt32);
3737
tripLimit = tensor_to_const(ctx, axis);
3838
} else {
39-
nvinfer1::ITensor* inpShape = ctx->net->addShape(*in)->getOutput(0);
39+
nvinfer1::ITensor* inpShape = getShapeOutput(ctx, in);
4040
torch::Tensor dimValue = torch::tensor(dim, torch::kInt32);
4141
nvinfer1::ITensor* axis = tensor_to_const(ctx, dimValue);
4242
tripLimit = ctx->net->addGather(*inpShape, *axis, 0)->getOutput(0);

core/conversion/converters/impl/expand.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,11 @@ nvinfer1::ITensor* concat(int max_rank, int old_rank, ConversionCtx* ctx, nvinfe
1919
if (max_rank - old_rank > 0) {
2020
torch::Tensor thOne = torch::tensor(std::vector<int32_t>(max_rank - old_rank, 1), torch::kInt32);
2121
auto one_tensor = tensor_to_const(ctx, thOne);
22-
auto in_shape_tensor = ctx->net->addShape(*tensor)->getOutput(0);
22+
auto in_shape_tensor = getShapeOutput(ctx, tensor);
2323
nvinfer1::ITensor* const args[2] = {one_tensor, in_shape_tensor};
2424
return ctx->net->addConcatenation(args, 2)->getOutput(0);
2525
} else { // max_rank - old_rank == 0
26-
return ctx->net->addShape(*tensor)->getOutput(0);
26+
return getShapeOutput(ctx, tensor);
2727
}
2828
}
2929

@@ -221,8 +221,7 @@ auto expand_registrations TORCHTRT_UNUSED =
221221
auto targetDims = targetTensor->getDimensions();
222222
LOG_DEBUG("(expand_as layer) Expand input from " << input_dims << " to " << targetDims);
223223
if (ctx->input_is_dynamic) {
224-
return add_expand_dynamic(
225-
ctx, n, in, ctx->net->addShape(*targetTensor)->getOutput(0), targetDims, false);
224+
return add_expand_dynamic(ctx, n, in, getShapeOutput(ctx, targetTensor), targetDims, false);
226225
} else {
227226
return add_expand(ctx, n, in, targetDims);
228227
}
@@ -357,7 +356,7 @@ auto expand_registrations TORCHTRT_UNUSED =
357356
if (ctx->input_is_dynamic) {
358357
auto start_tensor = tensor_to_const(ctx, torch::tensor(start_vec, torch::kInt32));
359358

360-
auto expand_output_shape = ctx->net->addShape(*expand->getOutput(0))->getOutput(0);
359+
auto expand_output_shape = getShapeOutput(ctx, expand->getOutput(0));
361360
std::vector<int64_t> repeat_const_vec(repeat_shape_dims.nbDims, 1);
362361
repeat_const_vec[dim + 1] = repeats;
363362
auto repeat_const = tensor_to_const(ctx, torch::tensor(repeat_const_vec, torch::kInt32));

0 commit comments

Comments
 (0)