Skip to content

Commit 293db8b

Browse files
authored
Merge pull request #1219 from pytorch/anuragd/torchtrtc-custom-plugins
feat: Added support for custom torch operators and converters in torchtrtc
2 parents 48a7f28 + 74f4475 commit 293db8b

File tree

5 files changed

+121
-2
lines changed

5 files changed

+121
-2
lines changed

cpp/bin/torchtrtc/BUILD

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ cc_binary(
1919
"parser_util.h",
2020
"parser_util.cpp"
2121
],
22+
linkopts = [
23+
"-l:libdl.so"
24+
],
2225
deps = [
2326
"//third_party/args",
2427
"//cpp:torch_tensorrt",

cpp/bin/torchtrtc/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ add_executable(${executable_name}
1010
if (MSVC)
1111
target_link_libraries(${executable_name} PRIVATE torch torchtrt)
1212
else()
13-
target_link_libraries(${executable_name} PRIVATE torch "-Wl,--no-as-needed" torchtrt "-Wl,--as-needed")
13+
target_link_libraries(${executable_name} PRIVATE torch "-Wl,--no-as-needed -ldl" torchtrt "-Wl,--as-needed")
1414
set_target_properties(
1515
${executable_name}
1616
PROPERTIES INSTALL_RPATH_USE_LINK_PATH FALSE #

cpp/bin/torchtrtc/README.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,8 @@ torchtrtc [input_file_path] [output_file_path]
108108
TorchScript program, save the created
109109
engine to the path specified as the
110110
output path
111+
--custom-torch-ops=[lib] (repeatable) Shared object/DLL containing custom torch operators
112+
--custom-converters=[lib] (repeatable) Shared object/DLL containing custom converters
111113
input_file_path Path to input TorchScript file
112114
output_file_path Path for compiled TorchScript (or
113115
TensorRT engine) file
@@ -131,3 +133,14 @@ e.g.
131133
```
132134
torchtrtc tests/modules/ssd_traced.jit.pt ssd_trt.ts "[(1,3,300,300); (1,3,512,512); (1, 3, 1024, 1024)]@fp16%contiguous" -p f16
133135
```
136+
137+
138+
To run with custom torch operators
139+
```
140+
torchtrtc tests/modules/ssd_traced.jit.pt ssd_trt.ts --custom-torch-ops=<path to custom library> "[(1,3,300,300); (1,3,512,512); (1, 3, 1024, 1024)]@fp16%contiguous" -p f16
141+
```
142+
143+
To run with custom converters
144+
```
145+
torchtrtc tests/modules/ssd_traced.jit.pt ssd_trt.ts --custom-converters=<path to custom library> "[(1,3,300,300); (1,3,512,512); (1, 3, 1024, 1024)]@fp16%contiguous" -p f16
146+
```

cpp/bin/torchtrtc/main.cpp

Lines changed: 92 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,33 @@
1515
#include "luts.h"
1616
#include "parser_util.h"
1717

18+
#if defined(_WIN32)
19+
#include <windows.h>
20+
#else
21+
#include <dlfcn.h>
22+
#endif
23+
24+
void* load_library(std::string& custom_lib) {
25+
void* handle = {nullptr};
26+
#if defined(_WIN32)
27+
handle = LoadLibrary(custom_lib.c_str());
28+
#else
29+
handle = dlopen(custom_lib.c_str(), RTLD_LAZY);
30+
#endif
31+
return handle;
32+
}
33+
34+
bool unload_library(void* custom_lib) {
35+
bool success = false;
36+
#if defined(_WIN32)
37+
// Returns status non-zero for success
38+
success = FreeLibrary(custom_lib) ? true : false;
39+
#else
40+
success = dlclose(custom_lib) ? false : true;
41+
#endif
42+
return success;
43+
}
44+
1845
int main(int argc, char** argv) {
1946
torchtrt::logging::set_is_colored_output_on(true);
2047
torchtrt::logging::set_reportable_log_level(torchtrt::logging::Level::kWARNING);
@@ -146,6 +173,18 @@ int main(int argc, char** argv) {
146173
"save_engine",
147174
"Instead of compiling a full a TorchScript program, save the created engine to the path specified as the output path",
148175
{"save-engine"});
176+
args::ValueFlagList<std::string> custom_torch_ops(
177+
parser,
178+
"custom-torch-ops",
179+
"(repeatable) Shared object/DLL containing custom torch operators",
180+
{"custom-torch-ops"});
181+
182+
args::ValueFlagList<std::string> custom_converters(
183+
parser,
184+
"custom-converters",
185+
"(repeatable) Shared object/DLL containing custom converters",
186+
{"custom-converters"});
187+
149188
args::Positional<std::string> input_path(parser, "input_file_path", "Path to input TorchScript file");
150189
args::Positional<std::string> output_path(
151190
parser, "output_file_path", "Path for compiled TorchScript (or TensorRT engine) file");
@@ -173,6 +212,34 @@ int main(int argc, char** argv) {
173212
torchtrt::logging::set_reportable_log_level(torchtrt::logging::Level::kERROR);
174213
}
175214

215+
std::vector<std::pair<std::string, void*>> custom_torch_op, custom_converter_op;
216+
if (custom_torch_ops) {
217+
for (auto& op : args::get(custom_torch_ops)) {
218+
auto* handle = load_library(op);
219+
if (handle == nullptr) {
220+
torchtrt::logging::log(
221+
torchtrt::logging::Level::kERROR, std::string("Could not load custom_torch_ops library " + op));
222+
} else {
223+
torchtrt::logging::log(torchtrt::logging::Level::kINFO, std::string("Loaded custom_torch_ops library " + op));
224+
225+
custom_torch_op.push_back({op, handle});
226+
}
227+
}
228+
}
229+
230+
if (custom_converters) {
231+
for (auto& op : args::get(custom_converters)) {
232+
auto* handle = load_library(op);
233+
if (handle == nullptr) {
234+
torchtrt::logging::log(
235+
torchtrt::logging::Level::kERROR, std::string("Could not load custom_converter library " + op));
236+
} else {
237+
torchtrt::logging::log(torchtrt::logging::Level::kINFO, std::string("Loaded custom_converter library " + op));
238+
custom_converter_op.push_back({op, handle});
239+
}
240+
}
241+
}
242+
176243
auto real_input_path = torchtrtc::fileio::resolve_path(args::get(input_path));
177244

178245
if (check_method_op_support) {
@@ -188,7 +255,7 @@ int main(int argc, char** argv) {
188255
auto method = args::get(check_method_op_support);
189256
auto result = torchtrt::ts::check_method_operator_support(mod, method);
190257
if (result) {
191-
std::cout << "The method is supported end to end by Torch-TensorRT" << std::endl;
258+
torchtrt::logging::log(torchtrt::logging::Level::kINFO, "The method is supported end to end by Torch-TensorRT");
192259
return 0;
193260
} else {
194261
torchtrt::logging::log(torchtrt::logging::Level::kERROR, "Method is not currently supported by Torch-TensorRT");
@@ -476,5 +543,29 @@ int main(int argc, char** argv) {
476543
trt_mod.save(real_output_path);
477544
}
478545

546+
if (custom_torch_ops) {
547+
for (auto& p : custom_torch_op) {
548+
auto status = unload_library(p.second);
549+
if (status) {
550+
torchtrt::logging::log(torchtrt::logging::Level::kINFO, std::string("Unloaded custom library " + p.first));
551+
} else {
552+
torchtrt::logging::log(
553+
torchtrt::logging::Level::kERROR, std::string("Could not unload custom library " + p.first));
554+
}
555+
}
556+
}
557+
558+
if (custom_converters) {
559+
for (auto& p : custom_converter_op) {
560+
auto status = unload_library(p.second);
561+
if (status) {
562+
torchtrt::logging::log(torchtrt::logging::Level::kINFO, std::string("Unloaded custom library " + p.first));
563+
} else {
564+
torchtrt::logging::log(
565+
torchtrt::logging::Level::kERROR, std::string("Could not unload custom library " + p.first));
566+
}
567+
}
568+
}
569+
479570
return 0;
480571
}

docsrc/tutorials/torchtrtc.rst

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,8 @@ to standard TorchScript. Load with ``torch.jit.load()`` and run like you would r
111111
TorchScript program, save the created
112112
engine to the path specified as the
113113
output path
114+
--custom-torch-ops (repeatable) Shared object/DLL containing custom torch operators
115+
--custom-converters (repeatable) Shared object/DLL containing custom converters
114116
input_file_path Path to input TorchScript file
115117
output_file_path Path for compiled TorchScript (or
116118
TensorRT engine) file
@@ -132,3 +134,13 @@ e.g.
132134
.. code-block:: shell
133135
134136
torchtrtc tests/modules/ssd_traced.jit.pt ssd_trt.ts "[(1,3,300,300); (1,3,512,512); (1, 3, 1024, 1024)]@f16%contiguous" -p f16
137+
138+
139+
To run with custom torch operators
140+
.. code-block:: shell
141+
torchtrtc tests/modules/ssd_traced.jit.pt ssd_trt.ts --custom-torch-ops=<path to custom library> "[(1,3,300,300); (1,3,512,512); (1, 3, 1024, 1024)]@fp16%contiguous" -p f16
142+
143+
144+
To run with custom converters
145+
.. code-block:: shell
146+
torchtrtc tests/modules/ssd_traced.jit.pt ssd_trt.ts --custom-converters=<path to custom library> "[(1,3,300,300); (1,3,512,512); (1, 3, 1024, 1024)]@fp16%contiguous" -p f16

0 commit comments

Comments
 (0)