Skip to content

Commit db20098

Browse files
authored
Merge pull request #74 from NVIDIA/seralization
Adds support for serialization and deserialization for compiled TorchScript modules
2 parents 40564c3 + b763332 commit db20098

File tree

119 files changed

+6989
-347
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

119 files changed

+6989
-347
lines changed

BUILD

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ pkg_tar(
88
"//core/conversion:include",
99
"//core/conversion/conversionctx:include",
1010
"//core/conversion/converters:include",
11+
"//core/conversion/var:include",
12+
"//core/conversion/tensorcontainer:include",
1113
"//core/conversion/evaluators:include",
1214
"//core/execution:include",
1315
"//core/lowering:include",
@@ -35,6 +37,15 @@ pkg_tar(
3537
)
3638

3739

40+
pkg_tar(
41+
name = "bin",
42+
package_dir = "bin/",
43+
srcs = [
44+
"//cpp/trtorchc:trtorchc",
45+
],
46+
mode = "0755",
47+
)
48+
3849

3950

4051
pkg_tar(
@@ -46,6 +57,7 @@ pkg_tar(
4657
],
4758
deps = [
4859
":lib",
60+
":bin",
4961
":include",
5062
":include_core",
5163
],

README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ compile_settings.op_precision = torch::kFloat;
2323
auto trt_mod = trtorch::CompileGraph(ts_mod, compile_settings);
2424
// Run like normal
2525
auto results = trt_mod.forward({in_tensor});
26+
// Save module for later
27+
trt_mod.save("trt_torchscript_module.ts");
2628
...
2729
```
2830
@@ -46,6 +48,7 @@ trt_ts_module = trtorch.compile(torch_script_module, compile_settings)
4648
4749
input_data = input_data.half()
4850
result = trt_ts_module(input_data)
51+
torch.jit.save(trt_ts_module, "trt_torchscript_module.ts")
4952
```
5053

5154
> Notes on running in lower precisions:

core/compiler.cpp

Lines changed: 54 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
#include "NvInfer.h"
77

88
#include "ATen/core/function_schema.h"
9+
#include "ATen/core/jit_type.h"
910

11+
#include "torch/custom_class.h"
1012
#include "torch/csrc/jit/frontend/function_schema_parser.h"
1113
#include "torch/csrc/jit/ir/ir.h"
1214
#include "torch/csrc/jit/passes/pass_manager.h"
@@ -40,32 +42,70 @@ c10::FunctionSchema GenerateGraphSchema(torch::jit::script::Module mod, std::str
4042

4143

4244
void AddEngineToGraph(torch::jit::script::Module mod, std::shared_ptr<torch::jit::Graph>& g, std::string& serialized_engine) {
43-
execution::EngineID uid = execution::RegisterEngineFromSerializedEngine(serialized_engine);
44-
auto num_io = execution::GetEngineIO(uid);
45-
46-
auto self = g->addInput("self.1");
45+
auto engine = execution::TRTEngine(mod._ivalue()->name(), serialized_engine);
46+
// Get required metadata about the engine out
47+
auto num_io = engine.num_io;
48+
auto name = engine.name;
49+
50+
// Add the engine as an attribute of the module, this will let the engine be serialized and deserialized
51+
auto engine_ptr = c10::make_intrusive<execution::TRTEngine>(engine);
52+
mod.register_attribute(
53+
name,
54+
c10::getCustomClassType<c10::intrusive_ptr<execution::TRTEngine>>(),
55+
c10::IValue(std::move(engine_ptr)),
56+
false
57+
);
58+
59+
// Add the module as an input into the graph
60+
auto self = g->addInput("self_1");
4761
self->setType(mod.type());
4862

49-
auto id_val = g->insertConstant(uid);
63+
// Start by retrieving the engine from the module attribute list
64+
auto engine_node = g->createGetAttr(self, name);
65+
g->block()->appendNode(engine_node);
5066

67+
// Add inputs to the graph corresponding to the number of input tensors expected by the engine
68+
// Also store those inputs in a vector so that they can be coalesced into a single list at runtime
5169
std::vector<torch::jit::Value*> engine_inputs;
52-
engine_inputs.push_back(id_val);
53-
5470
for (uint64_t i = 0; i < num_io.first; i++) {
55-
auto in_val = g->addInput("");
71+
auto in_val = g->addInput(std::string("input_") + std::to_string(i));
5672
in_val->setType(c10::TensorType::get());
5773
engine_inputs.push_back(in_val);
5874
}
5975

60-
auto engine_node = g->create(c10::Symbol::fromQualString("trt::execute_engine"), torch::jit::ArrayRef<torch::jit::Value*>(engine_inputs), num_io.second);
61-
g->block()->appendNode(engine_node);
62-
63-
if (engine_node->outputs().size() > 1) {
64-
auto return_tuple_node = g->createTuple(engine_node->outputs());
76+
// Create a node that will merge all of the input tensors into a single list argument to the trt::execute_engine op
77+
// Creates: prim::ListConstruct(<input tensors>)
78+
auto input_list_node = g->createList(c10::TensorType::get(), torch::jit::ArrayRef<torch::jit::Value*>(engine_inputs));
79+
g->block()->appendNode(input_list_node);
80+
81+
// Make a list of inputs to the actual trt::execute_engine op
82+
// Note: Ordering of list and then engine is because we can pop off the engine first which contains all the metadata
83+
// needed for execution
84+
std::vector<torch::jit::Value*> execute_node_inputs;
85+
execute_node_inputs.push_back(input_list_node->outputs()[0]);
86+
execute_node_inputs.push_back(engine_node->outputs()[0]);
87+
88+
// Create the actual execution node trt::execute_engine using the assembled inputs
89+
auto execute_node = g->create(c10::Symbol::fromQualString("trt::execute_engine"), torch::jit::ArrayRef<torch::jit::Value*>(execute_node_inputs), 1);
90+
g->block()->appendNode(execute_node);
91+
execute_node->outputs()[0]->setType(c10::ListType::ofTensors());
92+
93+
// Create a node to unpack the list into separate tensors; if there is only one tensor, that tensor is returned,
94+
// otherwise they are returned as a tuple of tensors.
95+
// Creates: prim::ListUnpack(<engine output>)
96+
auto unpack_node = g->createListUnpack(execute_node->outputs()[0], num_io.second);
97+
g->block()->appendNode(unpack_node);
98+
99+
// If there are multiple output tensors from TensorRT we wrap them in a tuple to return
100+
if (unpack_node->outputs().size() > 1) {
101+
// Creates prim::TupleConstruct(<output tensors>) using outputs of the unpack node
102+
auto return_tuple_node = g->createTuple(unpack_node->outputs());
65103
g->block()->appendNode(return_tuple_node);
104+
// Set the output as the produced tuple
66105
g->registerOutput(return_tuple_node->outputs()[0]);
67106
} else {
68-
g->registerOutput(engine_node->outputs()[0]);
107+
// Set the output as the sole output tensor
108+
g->registerOutput(unpack_node->outputs()[0]);
69109
}
70110

71111
LOG_DEBUG(*g << "(AddEngineToGraph)\n");

core/conversion/InterfaceTypes.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ InputRange::InputRange(std::vector<int64_t> d) {
3434
min = util::toDims(d);
3535
max = util::toDims(d);
3636
input_shape = util::toDims(d);
37-
37+
input_is_dynamic = false;
3838
}
3939

4040

@@ -67,6 +67,7 @@ InputRange::InputRange(std::vector<int64_t> min_shape, std::vector<int64_t> opt_
6767
dim.insert(max_shape[i]);
6868
if (dim.size() != 1) {
6969
dyn_shape.push_back(-1);
70+
input_is_dynamic = true;
7071
} else {
7172
dyn_shape.push_back(opt_shape[i]);
7273
}

core/conversion/conversion.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,10 @@ void AddInputs(ConversionCtx* ctx,
155155
profile->setDimensions(trt_in->getName(), nvinfer1::OptProfileSelector::kOPT, dims.opt);
156156
profile->setDimensions(trt_in->getName(), nvinfer1::OptProfileSelector::kMAX, dims.max);
157157

158+
if (dims.input_is_dynamic) {
159+
ctx->input_is_dynamic = true;
160+
}
161+
158162
ctx->value_tensor_map[in] = trt_in;
159163
}
160164

core/conversion/conversion.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ struct InputRange {
1515
nvinfer1::Dims max;
1616
nvinfer1::Dims opt;
1717
nvinfer1::Dims input_shape;
18+
bool input_is_dynamic = false;
1819
// Should we restrict to unsigned?
1920
InputRange(std::vector<int64_t> d);
2021
InputRange(std::vector<int64_t> min_shape,

core/conversion/conversionctx/ConversionCtx.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ struct ConversionCtx {
4242

4343
~ConversionCtx();
4444

45+
bool input_is_dynamic = false;
4546
nvinfer1::IBuilder* builder;
4647
nvinfer1::INetworkDefinition* net;
4748
nvinfer1::IBuilderConfig* cfg;

core/conversion/converters/impl/batch_norm.cpp

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,24 @@ auto batch_norm_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
1919
auto orig_shape = input->getDimensions();
2020
auto shape = util::toVec(orig_shape);
2121
auto options = torch::TensorOptions().dtype(torch::kFloat32);
22-
auto gamma = args[1].unwrapToTensor(at::full({shape}, 1, {options}));
23-
auto beta = args[2].unwrapToTensor(at::full({shape}, 1, {options}));
24-
auto mean = args[3].unwrapToTensor(at::full({shape}, 0, {options}));
25-
auto var = args[4].unwrapToTensor(at::full({shape}, 0, {options}));
22+
23+
torch::Tensor gamma, beta, mean, var;
24+
25+
if (ctx->input_is_dynamic) {
26+
gamma = args[1].unwrapToTensor();
27+
beta = args[2].unwrapToTensor();
28+
mean = args[3].unwrapToTensor();
29+
var = args[4].unwrapToTensor();
30+
} else {
31+
gamma = args[1].unwrapToTensor(at::full({shape}, 1, {options}));
32+
beta = args[2].unwrapToTensor(at::full({shape}, 1, {options}));
33+
mean = args[3].unwrapToTensor(at::full({shape}, 0, {options}));
34+
var = args[4].unwrapToTensor(at::full({shape}, 0, {options}));
35+
}
36+
2637
auto eps = args[7].unwrapToDouble(1e-5f);
2738

39+
2840
LOG_DEBUG("momentum disregarded");
2941
LOG_DEBUG("training disregarded");
3042
LOG_DEBUG("cudnn disregarded");

core/conversion/converters/impl/concat.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ namespace conversion {
88
namespace converters {
99
namespace impl {
1010
namespace {
11-
auto cat_registrations = RegisterNodeConversionPatterns()
11+
auto cat_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
1212
.pattern({
1313
"aten::cat(Tensor[] tensors, int dim=0) -> Tensor",
1414
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {

core/conversion/converters/impl/constant.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ namespace conversion {
77
namespace converters {
88
namespace impl {
99
namespace {
10-
auto constant_registrations = RegisterNodeConversionPatterns()
10+
auto constant_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
1111
.pattern({
1212
"trt::const(Tensor self) -> Tensor",
1313
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {

core/conversion/converters/impl/conv_deconv.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ namespace conversion {
99
namespace converters {
1010
namespace impl {
1111
namespace {
12-
auto conv_registrations = RegisterNodeConversionPatterns()
12+
auto conv_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
1313
.pattern({
1414
R"SIG(aten::_convolution(Tensor input, Tensor weight,
1515
Tensor? bias, int[] stride, int[] padding,

core/conversion/converters/impl/element_wise.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ nvinfer1::ILayer* add_elementwise(ConversionCtx* ctx, nvinfer1::ElementWiseOpera
6868

6969
}
7070

71-
auto element_wise_registrations = RegisterNodeConversionPatterns()
71+
auto element_wise_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
7272
.pattern({
7373
"aten::add.Tensor(Tensor self, Tensor other, Scalar alpha=1) -> Tensor",
7474
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {

core/conversion/converters/impl/linear.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ namespace converters {
88
namespace impl {
99
namespace {
1010

11-
auto linear_registrations = RegisterNodeConversionPatterns()
11+
auto linear_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
1212
.pattern({
1313
"aten::linear(Tensor input, Tensor weight, Tensor? bias = None) -> (Tensor)",
1414
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {

core/conversion/converters/impl/matrix_multiply.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ namespace converters {
88
namespace impl {
99
namespace {
1010

11-
auto mm_registrations = RegisterNodeConversionPatterns()
11+
auto mm_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
1212
.pattern({
1313
"aten::matmul(Tensor self, Tensor other) -> (Tensor)",
1414
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {

core/conversion/converters/impl/pooling.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ namespace converters {
88
namespace impl {
99
namespace {
1010

11-
auto pooling_registrations = RegisterNodeConversionPatterns()
11+
auto pooling_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
1212
.pattern({
1313
"aten::max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=[0, 0], int[2] dilation=[1, 1], bool ceil_mode=False) -> (Tensor)",
1414
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {

core/conversion/converters/impl/reduce.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ namespace {
1111

1212

1313

14-
auto reduce_registrations = RegisterNodeConversionPatterns()
14+
auto reduce_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
1515
.pattern({
1616
"aten::mean(Tensor self, *, ScalarType? dtype=None) -> (Tensor)",
1717
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {

core/conversion/converters/impl/shape.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ namespace converters {
99
namespace impl {
1010
namespace {
1111

12-
static auto shape_registrations = RegisterNodeConversionPatterns()
12+
static auto shape_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
1313
.pattern({
1414
// To use in static input size cases (explicit batch)
1515
"aten::size.int(Tensor self, int dim) -> (Tensor)",

core/conversion/converters/impl/shuffle.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ namespace converters {
99
namespace impl {
1010
namespace {
1111

12-
static auto shuffle_registrations = RegisterNodeConversionPatterns()
12+
static auto shuffle_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
1313
.pattern({
1414
"aten::flatten.using_ints(Tensor self, int start_dim=0, int end_dim=-1) -> (Tensor)",
1515
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
@@ -50,12 +50,10 @@ static auto shuffle_registrations = RegisterNodeConversionPatterns()
5050
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
5151
auto in = args[0].ITensor();
5252
auto in_shape = util::toVec(in->getDimensions());
53-
auto ex_tensor = torch::rand(in_shape);
54-
auto new_shape = ex_tensor.view(args[1].unwrapToIntList().vec()).sizes();
5553

5654
auto shuffle = ctx->net->addShuffle(*in);
5755
TRTORCH_CHECK(shuffle, "Unable to create shuffle layer from node: " << *n);
58-
shuffle->setReshapeDimensions(util::toDims(new_shape));
56+
shuffle->setReshapeDimensions(util::toDims(args[1].unwrapToIntList().vec()));
5957
shuffle->setName(util::node_info(n).c_str());
6058

6159
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle->getOutput(0));

core/conversion/converters/impl/softmax.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ namespace converters {
77
namespace impl {
88
namespace {
99

10-
static auto softmax_registrations = RegisterNodeConversionPatterns()
10+
static auto softmax_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
1111
.pattern({
1212
"aten::softmax.int(Tensor self, int dim, int? dtype=None) -> (Tensor)",
1313
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {

core/conversion/tensorcontainer/TensorContainer.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ namespace conversion {
66
namespace {
77

88
static auto tensor_container =
9-
torch::class_<TensorContainer>("_eval_ivalue_types", "TensorContainer")
9+
torch::class_<TensorContainer>("_trtorch_eval_ivalue_types", "TensorContainer")
1010
.def(torch::init<>());
1111
} // namespace
1212
} // conversion

core/conversion/var/BUILD

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ load("@rules_pkg//:pkg.bzl", "pkg_tar")
3030

3131
pkg_tar(
3232
name = "include",
33-
package_dir = "core/conversion/arg/",
33+
package_dir = "core/conversion/var/",
3434
srcs = [
3535
"Var.h",
3636
"Var_inl.h"

core/execution/BUILD

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ cc_library(
1414
],
1515
srcs = [
1616
"TRTEngine.cpp",
17-
"TRTEngineManager.cpp",
1817
"register_trt_op.cpp",
1918
],
2019
deps = [

0 commit comments

Comments
 (0)