15
15
#include " luts.h"
16
16
#include " parser_util.h"
17
17
18
+ #if defined(_WIN32)
19
+ #include < windows.h>
20
+ #else
21
+ #include < dlfcn.h>
22
+ #endif
23
+
24
+ void * load_library (std::string& custom_lib) {
25
+ void * handle = {nullptr };
26
+ #if defined(_WIN32)
27
+ handle = LoadLibrary (custom_lib.c_str ());
28
+ #else
29
+ handle = dlopen (custom_lib.c_str (), RTLD_LAZY);
30
+ #endif
31
+ return handle;
32
+ }
33
+
34
+ bool unload_library (void * custom_lib) {
35
+ bool success = false ;
36
+ #if defined(_WIN32)
37
+ // Returns status non-zero for success
38
+ success = FreeLibrary (custom_lib) ? true : false ;
39
+ #else
40
+ success = dlclose (custom_lib) ? false : true ;
41
+ #endif
42
+ return success;
43
+ }
44
+
18
45
int main (int argc, char ** argv) {
19
46
torchtrt::logging::set_is_colored_output_on (true );
20
47
torchtrt::logging::set_reportable_log_level (torchtrt::logging::Level::kWARNING );
@@ -146,6 +173,18 @@ int main(int argc, char** argv) {
146
173
" save_engine" ,
147
174
" Instead of compiling a full a TorchScript program, save the created engine to the path specified as the output path" ,
148
175
{" save-engine" });
176
+ args::ValueFlagList<std::string> custom_torch_ops (
177
+ parser,
178
+ " custom-torch-ops" ,
179
+ " (repeatable) Shared object/DLL containing custom torch operators" ,
180
+ {" custom-torch-ops" });
181
+
182
+ args::ValueFlagList<std::string> custom_converters (
183
+ parser,
184
+ " custom-converters" ,
185
+ " (repeatable) Shared object/DLL containing custom converters" ,
186
+ {" custom-converters" });
187
+
149
188
args::Positional<std::string> input_path (parser, " input_file_path" , " Path to input TorchScript file" );
150
189
args::Positional<std::string> output_path (
151
190
parser, " output_file_path" , " Path for compiled TorchScript (or TensorRT engine) file" );
@@ -173,6 +212,34 @@ int main(int argc, char** argv) {
173
212
torchtrt::logging::set_reportable_log_level (torchtrt::logging::Level::kERROR );
174
213
}
175
214
215
+ std::vector<std::pair<std::string, void *>> custom_torch_op, custom_converter_op;
216
+ if (custom_torch_ops) {
217
+ for (auto & op : args::get (custom_torch_ops)) {
218
+ auto * handle = load_library (op);
219
+ if (handle == nullptr ) {
220
+ torchtrt::logging::log (
221
+ torchtrt::logging::Level::kERROR , std::string (" Could not load custom_torch_ops library " + op));
222
+ } else {
223
+ torchtrt::logging::log (torchtrt::logging::Level::kINFO , std::string (" Loaded custom_torch_ops library " + op));
224
+
225
+ custom_torch_op.push_back ({op, handle});
226
+ }
227
+ }
228
+ }
229
+
230
+ if (custom_converters) {
231
+ for (auto & op : args::get (custom_converters)) {
232
+ auto * handle = load_library (op);
233
+ if (handle == nullptr ) {
234
+ torchtrt::logging::log (
235
+ torchtrt::logging::Level::kERROR , std::string (" Could not load custom_converter library " + op));
236
+ } else {
237
+ torchtrt::logging::log (torchtrt::logging::Level::kINFO , std::string (" Loaded custom_converter library " + op));
238
+ custom_converter_op.push_back ({op, handle});
239
+ }
240
+ }
241
+ }
242
+
176
243
auto real_input_path = torchtrtc::fileio::resolve_path (args::get (input_path));
177
244
178
245
if (check_method_op_support) {
@@ -188,7 +255,7 @@ int main(int argc, char** argv) {
188
255
auto method = args::get (check_method_op_support);
189
256
auto result = torchtrt::ts::check_method_operator_support (mod, method);
190
257
if (result) {
191
- std::cout << " The method is supported end to end by Torch-TensorRT" << std::endl ;
258
+ torchtrt::logging::log (torchtrt::logging::Level:: kINFO , " The method is supported end to end by Torch-TensorRT" ) ;
192
259
return 0 ;
193
260
} else {
194
261
torchtrt::logging::log (torchtrt::logging::Level::kERROR , " Method is not currently supported by Torch-TensorRT" );
@@ -476,5 +543,29 @@ int main(int argc, char** argv) {
476
543
trt_mod.save (real_output_path);
477
544
}
478
545
546
+ if (custom_torch_ops) {
547
+ for (auto & p : custom_torch_op) {
548
+ auto status = unload_library (p.second );
549
+ if (status) {
550
+ torchtrt::logging::log (torchtrt::logging::Level::kINFO , std::string (" Unloaded custom library " + p.first ));
551
+ } else {
552
+ torchtrt::logging::log (
553
+ torchtrt::logging::Level::kERROR , std::string (" Could not unload custom library " + p.first ));
554
+ }
555
+ }
556
+ }
557
+
558
+ if (custom_converters) {
559
+ for (auto & p : custom_converter_op) {
560
+ auto status = unload_library (p.second );
561
+ if (status) {
562
+ torchtrt::logging::log (torchtrt::logging::Level::kINFO , std::string (" Unloaded custom library " + p.first ));
563
+ } else {
564
+ torchtrt::logging::log (
565
+ torchtrt::logging::Level::kERROR , std::string (" Could not unload custom library " + p.first ));
566
+ }
567
+ }
568
+ }
569
+
479
570
return 0 ;
480
571
}
0 commit comments