Adds support for serialization and deserialization of compiled TorchScript modules #74

Merged: 17 commits, May 31, 2020

Commits (all by narendasan):
3381073  refactor(//core/execution): Embed engines in TorchScript modules, now
4f349a1  feat(//cpp/trtorchc): Adding a new CLI application for TRTorch which
670cf21  refactor(//cpp/trtorchexec): Demonstrate serialization in trtorchexec
eb636ec  chore(//third_party/args): Check in argparser library
736e914  feat(//py): register trtorch with torch op library to support
bf651dd  fix(aten::batchnorm|aten::view): Fix converter implementation for
863e0ce  refactor(//core/conversion/tensorcontainer): Set a better namespace for
ff81ebc  refactor(//core/conversion/conversionctx): Document if inputs are
31fd53d  refactor(//core/conversion/converters): add TRTORCH_UNUSED to all
aac8da4  refactor(//tests/modules/hub): Code style fixes
f3370c4  test: Test serialization to make sure it works
d647447  feat(//:libtrtorch): Ship trtorchc with the tarball
e9cef84  refactor(//core/execution): Remove redundant functions, stub out call
1709128  docs: Regenerated docs covering serialization
8b5465d  docs: Update README with serialization instructions
05bf696  docs: Clarification that the library must still be present to run
b763332  refactor(//tests/modules/test_serialization): make the code more clear
12 changes: 12 additions & 0 deletions BUILD
@@ -8,6 +8,8 @@ pkg_tar(
"//core/conversion:include",
"//core/conversion/conversionctx:include",
"//core/conversion/converters:include",
"//core/conversion/var:include",
"//core/conversion/tensorcontainer:include",
"//core/conversion/evaluators:include",
"//core/execution:include",
"//core/lowering:include",
@@ -35,6 +37,15 @@
)


pkg_tar(
name = "bin",
package_dir = "bin/",
srcs = [
"//cpp/trtorchc:trtorchc",
],
mode = "0755",
)



pkg_tar(
@@ -46,6 +57,7 @@
],
deps = [
":lib",
":bin",
":include",
":include_core",
],
3 changes: 3 additions & 0 deletions README.md
@@ -23,6 +23,8 @@ compile_settings.op_precision = torch::kFloat;
auto trt_mod = trtorch::CompileGraph(ts_mod, compile_settings);
// Run like normal
auto results = trt_mod.forward({in_tensor});
// Save module for later
trt_mod.save("trt_torchscript_module.ts");
...
```
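
Since the compiled module is saved as ordinary TorchScript, it can presumably be reloaded with the standard TorchScript API; per the docs commits in this PR, the only extra requirement is that the TRTorch library is still present at runtime so the `trt::execute_engine` op is registered. A minimal load-and-run sketch (the input shape is an assumption for illustration):

```c++
#include "torch/script.h" // torch::jit::load and tensor factories

int main() {
  // No TRTorch-specific API is needed to deserialize, but libtrtorch must
  // be linked (or otherwise loaded) so the trt::execute_engine op exists.
  torch::jit::script::Module trt_mod = torch::jit::load("trt_torchscript_module.ts");

  auto in_tensor = torch::randn({1, 3, 224, 224}, torch::kCUDA); // assumed shape
  auto results = trt_mod.forward({in_tensor});
  return 0;
}
```

The Python side should mirror this with `torch.jit.load("trt_torchscript_module.ts")`, again with `import trtorch` first so the runtime op is available.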

@@ -46,6 +48,7 @@ trt_ts_module = trtorch.compile(torch_script_module, compile_settings)

input_data = input_data.half()
result = trt_ts_module(input_data)
torch.jit.save(trt_ts_module, "trt_torchscript_module.ts")
```

> Notes on running in lower precisions:
68 changes: 54 additions & 14 deletions core/compiler.cpp
@@ -6,7 +6,9 @@
#include "NvInfer.h"

#include "ATen/core/function_schema.h"
#include "ATen/core/jit_type.h"

#include "torch/custom_class.h"
#include "torch/csrc/jit/frontend/function_schema_parser.h"
#include "torch/csrc/jit/ir/ir.h"
#include "torch/csrc/jit/passes/pass_manager.h"
@@ -40,32 +42,70 @@ c10::FunctionSchema GenerateGraphSchema(torch::jit::script::Module mod, std::str


void AddEngineToGraph(torch::jit::script::Module mod, std::shared_ptr<torch::jit::Graph>& g, std::string& serialized_engine) {
execution::EngineID uid = execution::RegisterEngineFromSerializedEngine(serialized_engine);
auto num_io = execution::GetEngineIO(uid);

auto self = g->addInput("self.1");
auto engine = execution::TRTEngine(mod._ivalue()->name(), serialized_engine);
// Get the required metadata out of the engine
auto num_io = engine.num_io;
auto name = engine.name;

// Add the engine as an attribute of the module; this lets the engine be serialized and deserialized
auto engine_ptr = c10::make_intrusive<execution::TRTEngine>(engine);
mod.register_attribute(
name,
c10::getCustomClassType<c10::intrusive_ptr<execution::TRTEngine>>(),
c10::IValue(std::move(engine_ptr)),
false
);

// Add the module as an input into the graph
auto self = g->addInput("self_1");
self->setType(mod.type());

auto id_val = g->insertConstant(uid);
// Start by retrieving the engine from the module attribute list
auto engine_node = g->createGetAttr(self, name);
g->block()->appendNode(engine_node);

// Add inputs to the graph corresponding to the number of input tensors expected by the engine
// Also store those inputs in a vector so that they can be coalesced into a single list at runtime
std::vector<torch::jit::Value*> engine_inputs;
engine_inputs.push_back(id_val);

for (uint64_t i = 0; i < num_io.first; i++) {
auto in_val = g->addInput("");
auto in_val = g->addInput(std::string("input_") + std::to_string(i));
in_val->setType(c10::TensorType::get());
engine_inputs.push_back(in_val);
}

auto engine_node = g->create(c10::Symbol::fromQualString("trt::execute_engine"), torch::jit::ArrayRef<torch::jit::Value*>(engine_inputs), num_io.second);
g->block()->appendNode(engine_node);

if (engine_node->outputs().size() > 1) {
auto return_tuple_node = g->createTuple(engine_node->outputs());
// Create a node that will merge all of the input tensors into a single list argument to the trt::execute_engine op
// Creates: prim::ListConstruct(<input tensors>)
auto input_list_node = g->createList(c10::TensorType::get(), torch::jit::ArrayRef<torch::jit::Value*>(engine_inputs));
g->block()->appendNode(input_list_node);

// Make a list of inputs to the actual trt::execute_engine op
// Note: the list goes before the engine so the engine, which carries all the metadata
// needed for execution, can be popped off first
std::vector<torch::jit::Value*> execute_node_inputs;
execute_node_inputs.push_back(input_list_node->outputs()[0]);
execute_node_inputs.push_back(engine_node->outputs()[0]);

// Create the actual execution node trt::execute_engine using the assembled inputs
auto execute_node = g->create(c10::Symbol::fromQualString("trt::execute_engine"), torch::jit::ArrayRef<torch::jit::Value*>(execute_node_inputs), 1);
g->block()->appendNode(execute_node);
execute_node->outputs()[0]->setType(c10::ListType::ofTensors());

// Create a node to unpack the list into separate tensors. If there is only one tensor it is returned directly,
// otherwise the tensors are returned as a tuple.
// Creates: prim::ListUnpack(<engine output>)
auto unpack_node = g->createListUnpack(execute_node->outputs()[0], num_io.second);
g->block()->appendNode(unpack_node);

// If there are multiple output tensors from TensorRT we wrap them in a tuple to return
if (unpack_node->outputs().size() > 1) {
// Creates prim::TupleConstruct(<output tensors>) using outputs of the unpack node
auto return_tuple_node = g->createTuple(unpack_node->outputs());
g->block()->appendNode(return_tuple_node);
// Set the output as the produced tuple
g->registerOutput(return_tuple_node->outputs()[0]);
} else {
g->registerOutput(engine_node->outputs()[0]);
// Set the output as the sole output tensor
g->registerOutput(unpack_node->outputs()[0]);
}

LOG_DEBUG(*g << "(AddEngineToGraph)\n");
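
Embedding the engine as a module attribute only survives `torch.jit.save`/`torch::jit::load` because `TRTEngine` is a TorchBind custom class that knows how to pickle itself. The actual registration lives in `core/execution`, which this diff does not show; the sketch below is a guess at the mechanism, with the namespace, class layout, and pickled state all assumed rather than taken from the PR:

```c++
#include <string>
#include <vector>
#include "torch/custom_class.h"

// Hypothetical stand-in for core/execution's TRTEngine.
struct TRTEngine : torch::CustomClassHolder {
  std::string name;
  std::string serialized_engine;
  TRTEngine(std::string n, std::string e)
      : name(std::move(n)), serialized_engine(std::move(e)) {}
};

// def_pickle supplies the __getstate__/__setstate__ pair TorchScript uses
// to write the attribute into the saved module and rebuild it on load.
static auto trt_engine_class =
    torch::class_<TRTEngine>("tensorrt", "Engine")
        .def(torch::init<std::string, std::string>())
        .def_pickle(
            [](const c10::intrusive_ptr<TRTEngine>& self)
                -> std::vector<std::string> {
              return {self->name, self->serialized_engine};
            },
            [](std::vector<std::string> state) -> c10::intrusive_ptr<TRTEngine> {
              return c10::make_intrusive<TRTEngine>(state[0], state[1]);
            });
```

On load, `__setstate__` rebuilds the engine from the serialized blob, which is consistent with the docs commit noting that the library must still be present to run the module.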
3 changes: 2 additions & 1 deletion core/conversion/InterfaceTypes.cpp
@@ -34,7 +34,7 @@ InputRange::InputRange(std::vector<int64_t> d) {
min = util::toDims(d);
max = util::toDims(d);
input_shape = util::toDims(d);

input_is_dynamic = false;
}


@@ -67,6 +67,7 @@ InputRange::InputRange(std::vector<int64_t> min_shape, std::vector<int64_t> opt_
dim.insert(max_shape[i]);
if (dim.size() != 1) {
dyn_shape.push_back(-1);
input_is_dynamic = true;
} else {
dyn_shape.push_back(opt_shape[i]);
}
4 changes: 4 additions & 0 deletions core/conversion/conversion.cpp
@@ -155,6 +155,10 @@ void AddInputs(ConversionCtx* ctx,
profile->setDimensions(trt_in->getName(), nvinfer1::OptProfileSelector::kOPT, dims.opt);
profile->setDimensions(trt_in->getName(), nvinfer1::OptProfileSelector::kMAX, dims.max);

if (dims.input_is_dynamic) {
ctx->input_is_dynamic = true;
}

ctx->value_tensor_map[in] = trt_in;
}

1 change: 1 addition & 0 deletions core/conversion/conversion.h
@@ -15,6 +15,7 @@ struct InputRange {
nvinfer1::Dims max;
nvinfer1::Dims opt;
nvinfer1::Dims input_shape;
bool input_is_dynamic = false;
// Should we restrict to unsigned?
InputRange(std::vector<int64_t> d);
InputRange(std::vector<int64_t> min_shape,
1 change: 1 addition & 0 deletions core/conversion/conversionctx/ConversionCtx.h
@@ -42,6 +42,7 @@ struct ConversionCtx {

~ConversionCtx();

bool input_is_dynamic = false;
nvinfer1::IBuilder* builder;
nvinfer1::INetworkDefinition* net;
nvinfer1::IBuilderConfig* cfg;
20 changes: 16 additions & 4 deletions core/conversion/converters/impl/batch_norm.cpp
@@ -19,12 +19,24 @@ auto batch_norm_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
auto orig_shape = input->getDimensions();
auto shape = util::toVec(orig_shape);
auto options = torch::TensorOptions().dtype(torch::kFloat32);
auto gamma = args[1].unwrapToTensor(at::full({shape}, 1, {options}));
auto beta = args[2].unwrapToTensor(at::full({shape}, 1, {options}));
auto mean = args[3].unwrapToTensor(at::full({shape}, 0, {options}));
auto var = args[4].unwrapToTensor(at::full({shape}, 0, {options}));

torch::Tensor gamma, beta, mean, var;

if (ctx->input_is_dynamic) {
gamma = args[1].unwrapToTensor();
beta = args[2].unwrapToTensor();
mean = args[3].unwrapToTensor();
var = args[4].unwrapToTensor();
} else {
gamma = args[1].unwrapToTensor(at::full({shape}, 1, {options}));
beta = args[2].unwrapToTensor(at::full({shape}, 1, {options}));
mean = args[3].unwrapToTensor(at::full({shape}, 0, {options}));
var = args[4].unwrapToTensor(at::full({shape}, 0, {options}));
}

auto eps = args[7].unwrapToDouble(1e-5f);


LOG_DEBUG("momentum disregarded");
LOG_DEBUG("training disregarded");
LOG_DEBUG("cudnn disregarded");
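
The dynamic/static split exists because, with dynamic input ranges, the conversion-time shape contains -1 wildcards (see the InterfaceTypes.cpp change above), and a default tensor cannot be allocated with a negative extent; fallback weights can only be synthesized when the shape is fully known. A small standalone illustration (shapes invented for the example):

```c++
#include <iostream>
#include "torch/torch.h"

int main() {
  auto options = torch::TensorOptions().dtype(torch::kFloat32);

  // Static case: every extent is known at conversion time, so defaulted
  // gamma/beta/mean/var tensors can be materialized with at::full.
  std::vector<int64_t> static_shape = {1, 16, 32, 32};
  auto gamma = at::full(static_shape, 1, options);
  std::cout << gamma.sizes() << std::endl; // [1, 16, 32, 32]

  // Dynamic case: wildcard dims are -1 at conversion time, and at::full
  // throws on a negative extent, so the converter instead unwraps the
  // real tensors with no synthesized fallback.
  std::vector<int64_t> dynamic_shape = {-1, 16, -1, -1};
  // at::full(dynamic_shape, 1, options); // would throw
  return 0;
}
```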
2 changes: 1 addition & 1 deletion core/conversion/converters/impl/concat.cpp
@@ -8,7 +8,7 @@ namespace conversion {
namespace converters {
namespace impl {
namespace {
auto cat_registrations = RegisterNodeConversionPatterns()
auto cat_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
.pattern({
"aten::cat(Tensor[] tensors, int dim=0) -> Tensor",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
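
`TRTORCH_UNUSED` marks these file-local registration objects, which exist only for their constructor's side effect, so the compiler does not warn that they are never read. The real definition lives in TRTorch's headers and is not part of this diff; a plausible sketch:

```c++
// Hypothetical definition: tell GCC/Clang the static registration
// object is intentionally unused; expand to nothing elsewhere.
#if defined(__GNUC__) || defined(__clang__)
#define TRTORCH_UNUSED __attribute__((__unused__))
#else
#define TRTORCH_UNUSED
#endif
```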
2 changes: 1 addition & 1 deletion core/conversion/converters/impl/constant.cpp
@@ -7,7 +7,7 @@ namespace conversion {
namespace converters {
namespace impl {
namespace {
auto constant_registrations = RegisterNodeConversionPatterns()
auto constant_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
.pattern({
"trt::const(Tensor self) -> Tensor",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
2 changes: 1 addition & 1 deletion core/conversion/converters/impl/conv_deconv.cpp
@@ -9,7 +9,7 @@ namespace conversion {
namespace converters {
namespace impl {
namespace {
auto conv_registrations = RegisterNodeConversionPatterns()
auto conv_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
.pattern({
R"SIG(aten::_convolution(Tensor input, Tensor weight,
Tensor? bias, int[] stride, int[] padding,
2 changes: 1 addition & 1 deletion core/conversion/converters/impl/element_wise.cpp
@@ -68,7 +68,7 @@ nvinfer1::ILayer* add_elementwise(ConversionCtx* ctx, nvinfer1::ElementWiseOpera

}

auto element_wise_registrations = RegisterNodeConversionPatterns()
auto element_wise_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
.pattern({
"aten::add.Tensor(Tensor self, Tensor other, Scalar alpha=1) -> Tensor",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
2 changes: 1 addition & 1 deletion core/conversion/converters/impl/linear.cpp
@@ -8,7 +8,7 @@ namespace converters {
namespace impl {
namespace {

auto linear_registrations = RegisterNodeConversionPatterns()
auto linear_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
.pattern({
"aten::linear(Tensor input, Tensor weight, Tensor? bias = None) -> (Tensor)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
2 changes: 1 addition & 1 deletion core/conversion/converters/impl/matrix_multiply.cpp
@@ -8,7 +8,7 @@ namespace converters {
namespace impl {
namespace {

auto mm_registrations = RegisterNodeConversionPatterns()
auto mm_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
.pattern({
"aten::matmul(Tensor self, Tensor other) -> (Tensor)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
2 changes: 1 addition & 1 deletion core/conversion/converters/impl/pooling.cpp
@@ -8,7 +8,7 @@ namespace converters {
namespace impl {
namespace {

auto pooling_registrations = RegisterNodeConversionPatterns()
auto pooling_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
.pattern({
"aten::max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=[0, 0], int[2] dilation=[1, 1], bool ceil_mode=False) -> (Tensor)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
2 changes: 1 addition & 1 deletion core/conversion/converters/impl/reduce.cpp
@@ -11,7 +11,7 @@ namespace {



auto reduce_registrations = RegisterNodeConversionPatterns()
auto reduce_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
.pattern({
"aten::mean(Tensor self, *, ScalarType? dtype=None) -> (Tensor)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
2 changes: 1 addition & 1 deletion core/conversion/converters/impl/shape.cpp
@@ -9,7 +9,7 @@ namespace converters {
namespace impl {
namespace {

static auto shape_registrations = RegisterNodeConversionPatterns()
static auto shape_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
.pattern({
// To use in static input size cases (explicit batch)
"aten::size.int(Tensor self, int dim) -> (Tensor)",
6 changes: 2 additions & 4 deletions core/conversion/converters/impl/shuffle.cpp
@@ -9,7 +9,7 @@ namespace converters {
namespace impl {
namespace {

static auto shuffle_registrations = RegisterNodeConversionPatterns()
static auto shuffle_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
.pattern({
"aten::flatten.using_ints(Tensor self, int start_dim=0, int end_dim=-1) -> (Tensor)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
@@ -50,12 +50,10 @@ static auto shuffle_registrations = RegisterNodeConversionPatterns()
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
auto in = args[0].ITensor();
auto in_shape = util::toVec(in->getDimensions());
auto ex_tensor = torch::rand(in_shape);
auto new_shape = ex_tensor.view(args[1].unwrapToIntList().vec()).sizes();

auto shuffle = ctx->net->addShuffle(*in);
TRTORCH_CHECK(shuffle, "Unable to create shuffle layer from node: " << *n);
shuffle->setReshapeDimensions(util::toDims(new_shape));
shuffle->setReshapeDimensions(util::toDims(args[1].unwrapToIntList().vec()));
shuffle->setName(util::node_info(n).c_str());

auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle->getOutput(0));
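
The old code built a random CPU tensor just so `aten::view` could resolve the requested sizes; the rewrite passes them straight through, relying on TensorRT's `IShuffleLayer` resolving a single -1 wildcard from the input volume the same way `view` does. A sketch of the dims conversion this leans on (roughly what `util::toDims` presumably does here, with an invented name):

```c++
#include <vector>
#include "NvInfer.h"

// Map aten::view sizes onto TensorRT reshape dims. A -1 entry passes
// through unchanged; IShuffleLayer::setReshapeDimensions infers that
// extent from the remaining volume, matching view's semantics.
nvinfer1::Dims toReshapeDims(const std::vector<int64_t>& sizes) {
  nvinfer1::Dims d;
  d.nbDims = static_cast<int>(sizes.size());
  for (size_t i = 0; i < sizes.size(); i++) {
    d.d[i] = static_cast<int>(sizes[i]);
  }
  return d;
}

// e.g. an input of dims (2, 3, 4) reshaped with sizes (6, -1) comes out
// as (6, 4), with no example tensor ever allocated.
```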
2 changes: 1 addition & 1 deletion core/conversion/converters/impl/softmax.cpp
@@ -7,7 +7,7 @@ namespace converters {
namespace impl {
namespace {

static auto softmax_registrations = RegisterNodeConversionPatterns()
static auto softmax_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
.pattern({
"aten::softmax.int(Tensor self, int dim, int? dtype=None) -> (Tensor)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
2 changes: 1 addition & 1 deletion core/conversion/tensorcontainer/TensorContainer.cpp
@@ -6,7 +6,7 @@ namespace conversion {
namespace {

static auto tensor_container =
torch::class_<TensorContainer>("_eval_ivalue_types", "TensorContainer")
torch::class_<TensorContainer>("_trtorch_eval_ivalue_types", "TensorContainer")
.def(torch::init<>());
} // namespace
} // conversion
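
The rename matters because `torch::class_<T>("ns", "Name")` exposes the type to TorchScript under the qualified name `__torch__.torch.classes.ns.Name` in a process-wide registry, so a generic namespace like `_eval_ivalue_types` could collide with another extension registering the same pair. A minimal sketch of the registration pattern (the holder's contents are elided in the diff and assumed empty here):

```c++
#include "torch/custom_class.h"

// Stand-in for the real TensorContainer (contents assumed).
struct TensorContainer : torch::CustomClassHolder {};

static auto container =
    torch::class_<TensorContainer>("_trtorch_eval_ivalue_types", "TensorContainer")
        .def(torch::init<>());
```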
2 changes: 1 addition & 1 deletion core/conversion/var/BUILD
@@ -30,7 +30,7 @@ load("@rules_pkg//:pkg.bzl", "pkg_tar")

pkg_tar(
name = "include",
package_dir = "core/conversion/arg/",
package_dir = "core/conversion/var/",
srcs = [
"Var.h",
"Var_inl.h"
1 change: 0 additions & 1 deletion core/execution/BUILD
@@ -14,7 +14,6 @@ cc_library(
],
srcs = [
"TRTEngine.cpp",
"TRTEngineManager.cpp",
"register_trt_op.cpp",
],
deps = [