Skip to content

feat: Added support for custom torch operators and converters in torchtrtc #1219

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Aug 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions cpp/bin/torchtrtc/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ cc_binary(
"parser_util.h",
"parser_util.cpp"
],
linkopts = [
"-l:libdl.so"
],
deps = [
"//third_party/args",
"//cpp:torch_tensorrt",
Expand Down
2 changes: 1 addition & 1 deletion cpp/bin/torchtrtc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ add_executable(${executable_name}
if (MSVC)
target_link_libraries(${executable_name} PRIVATE torch torchtrt)
else()
target_link_libraries(${executable_name} PRIVATE torch "-Wl,--no-as-needed" torchtrt "-Wl,--as-needed")
target_link_libraries(${executable_name} PRIVATE torch "-Wl,--no-as-needed -ldl" torchtrt "-Wl,--as-needed")
set_target_properties(
${executable_name}
PROPERTIES INSTALL_RPATH_USE_LINK_PATH FALSE #
Expand Down
13 changes: 13 additions & 0 deletions cpp/bin/torchtrtc/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ torchtrtc [input_file_path] [output_file_path]
TorchScript program, save the created
engine to the path specified as the
output path
--custom-torch-ops=[lib] (repeatable) Shared object/DLL containing custom torch operators
--custom-converters=[lib] (repeatable) Shared object/DLL containing custom converters
input_file_path Path to input TorchScript file
output_file_path Path for compiled TorchScript (or
TensorRT engine) file
Expand All @@ -131,3 +133,14 @@ e.g.
```
torchtrtc tests/modules/ssd_traced.jit.pt ssd_trt.ts "[(1,3,300,300); (1,3,512,512); (1, 3, 1024, 1024)]@fp16%contiguous" -p f16
```


To run with custom torch operators
```
torchtrtc tests/modules/ssd_traced.jit.pt ssd_trt.ts --custom-torch-ops=<path to custom library> "[(1,3,300,300); (1,3,512,512); (1, 3, 1024, 1024)]@fp16%contiguous" -p f16
```

To run with custom converters
```
torchtrtc tests/modules/ssd_traced.jit.pt ssd_trt.ts --custom-converters=<path to custom library> "[(1,3,300,300); (1,3,512,512); (1, 3, 1024, 1024)]@fp16%contiguous" -p f16
```
96 changes: 93 additions & 3 deletions cpp/bin/torchtrtc/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,33 @@
#include "luts.h"
#include "parser_util.h"

#if defined(_WIN32)
#include <windows.h>
#else
#include <dlfcn.h>
#endif

// Dynamically loads a shared library (DLL on Windows, shared object elsewhere)
// so that its static initializers run — this is how custom torch operators and
// custom converters register themselves with the runtime.
//
// \param custom_lib Path to the library to load. Taken by const reference:
//                   the path is never modified, so callers may pass temporaries.
// \return Opaque OS handle to the loaded library, or nullptr on failure.
//         Release the handle with unload_library().
void* load_library(const std::string& custom_lib) {
  void* handle = nullptr;
#if defined(_WIN32)
  handle = LoadLibrary(custom_lib.c_str());
#else
  // RTLD_LAZY defers symbol resolution until first use; loading alone is
  // enough to trigger the library's static registration code.
  handle = dlopen(custom_lib.c_str(), RTLD_LAZY);
#endif
  return handle;
}

// Unloads a library previously opened with load_library().
//
// \param custom_lib Handle returned by load_library(); must not be nullptr.
// \return true if the library was successfully unloaded, false otherwise.
bool unload_library(void* custom_lib) {
  bool success = false;
#if defined(_WIN32)
  // FreeLibrary returns non-zero on success. The cast is required: void* does
  // not implicitly convert to HMODULE in C++, so without it this branch does
  // not compile on Windows.
  success = FreeLibrary(static_cast<HMODULE>(custom_lib)) != 0;
#else
  // dlclose returns 0 on success.
  success = dlclose(custom_lib) == 0;
#endif
  return success;
}

int main(int argc, char** argv) {
torchtrt::logging::set_is_colored_output_on(true);
torchtrt::logging::set_reportable_log_level(torchtrt::logging::Level::kWARNING);
Expand Down Expand Up @@ -117,8 +144,7 @@ int main(int argc, char** argv) {
parser, "num_iters", "Number of averaging timing iterations used to select kernels", {"num-avg-timing-iters"});
args::ValueFlag<uint64_t> workspace_size(
parser, "workspace_size", "Maximum size of workspace given to TensorRT", {"workspace-size"});
args::ValueFlag<uint64_t> dla_sram_size(
parser, "dla_sram_size", "DLA managed SRAM size", {"dla-sram-size"});
args::ValueFlag<uint64_t> dla_sram_size(parser, "dla_sram_size", "DLA managed SRAM size", {"dla-sram-size"});
args::ValueFlag<uint64_t> dla_local_dram_size(
parser, "dla_local_dram_size", "DLA Local DRAM size", {"dla-local-dram-size"});
args::ValueFlag<uint64_t> dla_global_dram_size(
Expand Down Expand Up @@ -147,6 +173,18 @@ int main(int argc, char** argv) {
"save_engine",
"Instead of compiling a full a TorchScript program, save the created engine to the path specified as the output path",
{"save-engine"});
args::ValueFlagList<std::string> custom_torch_ops(
parser,
"custom-torch-ops",
"(repeatable) Shared object/DLL containing custom torch operators",
{"custom-torch-ops"});

args::ValueFlagList<std::string> custom_converters(
parser,
"custom-converters",
"(repeatable) Shared object/DLL containing custom converters",
{"custom-converters"});

args::Positional<std::string> input_path(parser, "input_file_path", "Path to input TorchScript file");
args::Positional<std::string> output_path(
parser, "output_file_path", "Path for compiled TorchScript (or TensorRT engine) file");
Expand Down Expand Up @@ -174,6 +212,34 @@ int main(int argc, char** argv) {
torchtrt::logging::set_reportable_log_level(torchtrt::logging::Level::kERROR);
}

std::vector<std::pair<std::string, void*>> custom_torch_op, custom_converter_op;
if (custom_torch_ops) {
for (auto& op : args::get(custom_torch_ops)) {
auto* handle = load_library(op);
if (handle == nullptr) {
torchtrt::logging::log(
torchtrt::logging::Level::kERROR, std::string("Could not load custom_torch_ops library " + op));
} else {
torchtrt::logging::log(torchtrt::logging::Level::kINFO, std::string("Loaded custom_torch_ops library " + op));

custom_torch_op.push_back({op, handle});
}
}
}

if (custom_converters) {
for (auto& op : args::get(custom_converters)) {
auto* handle = load_library(op);
if (handle == nullptr) {
torchtrt::logging::log(
torchtrt::logging::Level::kERROR, std::string("Could not load custom_converter library " + op));
} else {
torchtrt::logging::log(torchtrt::logging::Level::kINFO, std::string("Loaded custom_converter library " + op));
custom_converter_op.push_back({op, handle});
}
}
}

auto real_input_path = torchtrtc::fileio::resolve_path(args::get(input_path));

if (check_method_op_support) {
Expand All @@ -189,7 +255,7 @@ int main(int argc, char** argv) {
auto method = args::get(check_method_op_support);
auto result = torchtrt::ts::check_method_operator_support(mod, method);
if (result) {
std::cout << "The method is supported end to end by Torch-TensorRT" << std::endl;
torchtrt::logging::log(torchtrt::logging::Level::kINFO, "The method is supported end to end by Torch-TensorRT");
return 0;
} else {
torchtrt::logging::log(torchtrt::logging::Level::kERROR, "Method is not currently supported by Torch-TensorRT");
Expand Down Expand Up @@ -477,5 +543,29 @@ int main(int argc, char** argv) {
trt_mod.save(real_output_path);
}

if (custom_torch_ops) {
for (auto& p : custom_torch_op) {
auto status = unload_library(p.second);
if (status) {
torchtrt::logging::log(torchtrt::logging::Level::kINFO, std::string("Unloaded custom library " + p.first));
} else {
torchtrt::logging::log(
torchtrt::logging::Level::kERROR, std::string("Could not unload custom library " + p.first));
}
}
}

if (custom_converters) {
for (auto& p : custom_converter_op) {
auto status = unload_library(p.second);
if (status) {
torchtrt::logging::log(torchtrt::logging::Level::kINFO, std::string("Unloaded custom library " + p.first));
} else {
torchtrt::logging::log(
torchtrt::logging::Level::kERROR, std::string("Could not unload custom library " + p.first));
}
}
}

return 0;
}
12 changes: 12 additions & 0 deletions docsrc/tutorials/torchtrtc.rst
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ to standard TorchScript. Load with ``torch.jit.load()`` and run like you would r
TorchScript program, save the created
engine to the path specified as the
output path
--custom-torch-ops=[lib]    (repeatable) Shared object/DLL containing custom torch operators
--custom-converters=[lib]   (repeatable) Shared object/DLL containing custom converters
input_file_path Path to input TorchScript file
output_file_path Path for compiled TorchScript (or
TensorRT engine) file
Expand All @@ -132,3 +134,13 @@ e.g.
.. code-block:: shell

torchtrtc tests/modules/ssd_traced.jit.pt ssd_trt.ts "[(1,3,300,300); (1,3,512,512); (1, 3, 1024, 1024)]@f16%contiguous" -p f16


To run with custom torch operators
.. code-block:: shell

    torchtrtc tests/modules/ssd_traced.jit.pt ssd_trt.ts --custom-torch-ops=<path to custom library> "[(1,3,300,300); (1,3,512,512); (1, 3, 1024, 1024)]@fp16%contiguous" -p f16


To run with custom converters
.. code-block:: shell

    torchtrtc tests/modules/ssd_traced.jit.pt ssd_trt.ts --custom-converters=<path to custom library> "[(1,3,300,300); (1,3,512,512); (1, 3, 1024, 1024)]@fp16%contiguous" -p f16