Skip to content

Commit 4ffd022

Browse files
authored
[TensorRT EP] Refactor of TRT plugins support (#17946)
Make sure the "trt.plugins" custom op domain is registered only once. The bottom line is that the "trt.plugins" custom op domain needs to be registered before the model is loaded. `CreateTensorRTCustomOpDomainList()` is the TRT EP's function for creating the "trt.plugins" custom op domain. The following are the places where this function is called. (This function only fetches all the TRT plugins from the TRT plugin registry; it does not yet register them with the ORT custom op registry. The actual registration happens in `AddCustomOpDomains()`.) C/C++ APIs: - `OrtApis::SessionOptionsAppendExecutionProvider_TensorRT_XX`: This function makes the session options object contain the "trt.plugins" custom op domain for ORT to register, so that the session creation API can later register the custom op domain accordingly and won't complain about an invalid ONNX node. - `InferenceSession::RegisterExecutionProvider`: In some cases, users might create the session object first and call `session_object.RegisterExecutionProvider()` later. This function calls `p_exec_provider->GetCustomOpDomainList()`, which returns the "trt.plugins" custom op domain. Otherwise, `session_object.Load(model)` will complain. Python APIs: - `RegisterTensorRTPluginsAsCustomOps`: This function needs to be called so that the session options object contains the "trt.plugins" custom op domain for ORT to register. Different language bindings have slightly different workflows for initializing the session. This might result in a duplicate custom op domain in `session_option.custom_op_domains_`, or in `CreateTensorRTCustomOpDomainList()` being called more than once, but checks are in place to make sure the EP's custom op domain won't be registered twice.
1 parent 2c50b75 commit 4ffd022

7 files changed

Lines changed: 88 additions & 32 deletions

File tree

onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1210,6 +1210,12 @@ Status TensorrtExecutionProvider::OnRunEnd(bool sync_stream) {
12101210
}
12111211

12121212
// Copies this provider's custom op domain list (info_.custom_op_domain_list)
// into `custom_op_domain_list`, overwriting whatever the caller passed in.
//
// If the stored list is still empty, it is first populated by calling
// CreateTensorRTCustomOpDomainList(info_). A failure of that call is only
// logged as a warning; the (possibly still empty) stored list is returned
// to the caller regardless.
//
// NOTE: this method is const but passes info_ to a function that fills it in;
// this relies on info_ being declared `mutable` in the class definition.
void TensorrtExecutionProvider::GetCustomOpDomainList(std::vector<OrtCustomOpDomain*>& custom_op_domain_list) const {
1213+
// Only build the list when nothing has been stored yet, so repeated calls do not re-create it.
if (info_.custom_op_domain_list.empty()) {
1214+
common::Status status = CreateTensorRTCustomOpDomainList(info_);
1215+
// Best effort: a failure here is reported but does not abort the call.
if (!status.IsOK()) {
1216+
LOGS_DEFAULT(WARNING) << "[TensorRT EP] Failed to get TRT plugins from TRT plugin registration.";
1217+
}
1218+
}
12131219
custom_op_domain_list = info_.custom_op_domain_list;
12141220
}
12151221

onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ class TensorrtExecutionProvider : public IExecutionProvider {
197197
Status ReplayGraph() override;
198198

199199
private:
200-
TensorrtExecutionProviderInfo info_;
200+
mutable TensorrtExecutionProviderInfo info_;
201201
bool external_stream_ = false;
202202
cudaStream_t stream_ = nullptr;
203203
int max_partition_iterations_ = 1000;

onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -75,11 +75,6 @@ struct Tensorrt_Provider : Provider {
7575
info.device_id = device_id;
7676
info.has_trt_options = false;
7777

78-
common::Status status = CreateTensorRTCustomOpDomainList(info);
79-
if (!status.IsOK()) {
80-
LOGS_DEFAULT(WARNING) << "[TensorRT EP] Failed to get TRT plugins from TRT plugin registration.";
81-
}
82-
8378
return std::make_shared<TensorrtProviderFactory>(info);
8479
}
8580

@@ -121,11 +116,6 @@ struct Tensorrt_Provider : Provider {
121116
info.profile_opt_shapes = options.trt_profile_opt_shapes == nullptr ? "" : options.trt_profile_opt_shapes;
122117
info.cuda_graph_enable = options.trt_cuda_graph_enable != 0;
123118

124-
common::Status status = CreateTensorRTCustomOpDomainList(info);
125-
if (!status.IsOK()) {
126-
LOGS_DEFAULT(WARNING) << "[TensorRT EP] Failed to get TRT plugins from TRT plugin registration.";
127-
}
128-
129119
return std::make_shared<TensorrtProviderFactory>(info);
130120
}
131121

onnxruntime/core/session/inference_session.cc

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -613,9 +613,35 @@ common::Status InferenceSession::RegisterExecutionProvider(const std::shared_ptr
613613
}
614614

615615
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
616-
// Create Custom Op if EP requests it
616+
// Register Custom Op if EP requests it
617617
std::vector<OrtCustomOpDomain*> custom_op_domains;
618-
p_exec_provider->GetCustomOpDomainList(custom_op_domains);
618+
std::vector<OrtCustomOpDomain*> candidate_custom_op_domains;
619+
p_exec_provider->GetCustomOpDomainList(candidate_custom_op_domains);
620+
621+
auto registry_kernels = kernel_registry_manager_.GetKernelRegistriesByProviderType(p_exec_provider->Type());
622+
623+
// Register the custom op domain only if it has not been registered before
624+
if (registry_kernels.empty()) {
625+
custom_op_domains = candidate_custom_op_domains;
626+
} else {
627+
for (auto candidate_custom_op_domain : candidate_custom_op_domains) {
628+
for (auto registry_kernel : registry_kernels) {
629+
const auto& kernel_map = registry_kernel->GetKernelCreateMap();
630+
bool need_register = true;
631+
// If the kernel registry is the ep's custom op registry, we only need to check the first kernel,
632+
// because all kernels in one kernel registry should have the same domain name.
633+
for (auto iter = kernel_map.begin(); iter != kernel_map.end(); iter++) {
634+
if (iter->second.kernel_def->Domain() == candidate_custom_op_domain->domain_) {
635+
need_register = false;
636+
break;
637+
}
638+
}
639+
if (need_register) {
640+
custom_op_domains.push_back(candidate_custom_op_domain);
641+
}
642+
}
643+
}
644+
}
619645

620646
if (!custom_op_domains.empty()) {
621647
if (AddCustomOpDomains(custom_op_domains) != Status::OK()) {

onnxruntime/core/session/provider_bridge_ort.cc

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1625,6 +1625,28 @@ ProviderOptions GetProviderInfo_Cuda(const OrtCUDAProviderOptionsV2* provider_op
16251625

16261626
} // namespace onnxruntime
16271627

1628+
// Appends the TensorRT custom op domains to options->custom_op_domains_,
// skipping any domain whose name is already present there.
//
// options:                session options object whose custom_op_domains_ vector is extended.
// extra_plugin_lib_paths: forwarded unchanged to
//                         ProviderInfo_TensorRT::GetTensorRTCustomOpDomainList().
//
// A domain that is skipped because its name is already in the session options
// is reported with a warning log; it is not treated as an error.
void AddTensorRTCustomOpDomainToSessionOption(OrtSessionOptions* options, std::string extra_plugin_lib_paths) {
1629+
// Returns true if any entry in `domains` has a domain_ equal to `domain_name` (linear scan).
auto is_already_in_domains = [&](std::string& domain_name, std::vector<OrtCustomOpDomain*>& domains) {
1630+
for (auto ptr : domains) {
1631+
if (domain_name == ptr->domain_) {
1632+
return true;
1633+
}
1634+
}
1635+
return false;
1636+
};
1637+
1638+
std::vector<OrtCustomOpDomain*> custom_op_domains;
1639+
onnxruntime::ProviderInfo_TensorRT& provider_info = onnxruntime::GetProviderInfo_TensorRT();
1640+
// Ask the TensorRT provider for its custom op domains; the result is received in custom_op_domains.
provider_info.GetTensorRTCustomOpDomainList(custom_op_domains, extra_plugin_lib_paths);
1641+
// Add each returned domain unless a domain with the same name is already in the session options.
for (auto ptr : custom_op_domains) {
1642+
if (!is_already_in_domains(ptr->domain_, options->custom_op_domains_)) {
1643+
options->custom_op_domains_.push_back(ptr);
1644+
} else {
1645+
LOGS_DEFAULT(WARNING) << "The custom op domain name " << ptr->domain_ << " is already in session option.";
1646+
}
1647+
}
1648+
}
1649+
16281650
ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_Dnnl, _In_ OrtSessionOptions* options, int use_arena) {
16291651
API_IMPL_BEGIN
16301652
auto factory = onnxruntime::DnnlProviderFactoryCreator::Create(use_arena);
@@ -1646,13 +1668,8 @@ ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_Tensorrt, _In_ OrtS
16461668

16471669
options->provider_factories.push_back(factory);
16481670

1649-
std::vector<OrtCustomOpDomain*> custom_op_domains;
16501671
std::string extra_plugin_lib_paths = onnxruntime::Env::Default().GetEnvironmentVar("trt_extra_plugin_lib_paths");
1651-
onnxruntime::ProviderInfo_TensorRT& provider_info = onnxruntime::GetProviderInfo_TensorRT();
1652-
provider_info.GetTensorRTCustomOpDomainList(custom_op_domains, extra_plugin_lib_paths);
1653-
for (auto ptr : custom_op_domains) {
1654-
options->custom_op_domains_.push_back(ptr);
1655-
}
1672+
AddTensorRTCustomOpDomainToSessionOption(options, extra_plugin_lib_paths);
16561673

16571674
return nullptr;
16581675
API_IMPL_END
@@ -1679,12 +1696,7 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_TensorRT, _In
16791696

16801697
options->provider_factories.push_back(factory);
16811698

1682-
std::vector<OrtCustomOpDomain*> custom_op_domains;
1683-
onnxruntime::ProviderInfo_TensorRT& provider_info = onnxruntime::GetProviderInfo_TensorRT();
1684-
provider_info.GetTensorRTCustomOpDomainList(custom_op_domains, "");
1685-
for (auto ptr : custom_op_domains) {
1686-
options->custom_op_domains_.push_back(ptr);
1687-
}
1699+
AddTensorRTCustomOpDomainToSessionOption(options, "");
16881700

16891701
return nullptr;
16901702
API_IMPL_END
@@ -1788,13 +1800,8 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_TensorRT_V2,
17881800

17891801
options->provider_factories.push_back(factory);
17901802

1791-
std::vector<OrtCustomOpDomain*> custom_op_domains;
17921803
std::string extra_plugin_lib_paths = (tensorrt_options == nullptr || tensorrt_options->trt_extra_plugin_lib_paths == nullptr) ? "" : tensorrt_options->trt_extra_plugin_lib_paths;
1793-
onnxruntime::ProviderInfo_TensorRT& provider_info = onnxruntime::GetProviderInfo_TensorRT();
1794-
provider_info.GetTensorRTCustomOpDomainList(custom_op_domains, extra_plugin_lib_paths);
1795-
for (auto ptr : custom_op_domains) {
1796-
options->custom_op_domains_.push_back(ptr);
1797-
}
1804+
AddTensorRTCustomOpDomainToSessionOption(options, extra_plugin_lib_paths);
17981805

17991806
return nullptr;
18001807
API_IMPL_END

onnxruntime/python/onnxruntime_pybind_state.cc

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,15 @@ const ROCMExecutionProviderInfo GetRocmExecutionProviderInfo(ProviderInfo_ROCM*
433433
#ifdef USE_TENSORRT
434434
void RegisterTensorRTPluginsAsCustomOps(PySessionOptions& so, const ProviderOptions& options) {
435435
if (auto* tensorrt_provider_info = TryGetProviderInfo_TensorRT()) {
436+
auto is_already_in_domains = [&](std::string& domain_name, std::vector<OrtCustomOpDomain*>& domains) {
437+
for (auto ptr : domains) {
438+
if (domain_name == ptr->domain_) {
439+
return true;
440+
}
441+
}
442+
return false;
443+
};
444+
436445
std::string trt_extra_plugin_lib_paths = "";
437446
const auto it = options.find("trt_extra_plugin_lib_paths");
438447
if (it != options.end()) {
@@ -441,7 +450,11 @@ void RegisterTensorRTPluginsAsCustomOps(PySessionOptions& so, const ProviderOpti
441450
std::vector<OrtCustomOpDomain*> domain_list;
442451
tensorrt_provider_info->GetTensorRTCustomOpDomainList(domain_list, trt_extra_plugin_lib_paths);
443452
for (auto ptr : domain_list) {
444-
so.custom_op_domains_.push_back(ptr);
453+
if (!is_already_in_domains(ptr->domain_, so.custom_op_domains_)) {
454+
so.custom_op_domains_.push_back(ptr);
455+
} else {
456+
LOGS_DEFAULT(WARNING) << "The custom op domain name " << ptr->domain_ << " is already in session option.";
457+
}
445458
}
446459
} else {
447460
ORT_THROW("Please install TensorRT libraries as mentioned in the GPU requirements page, make sure they're in the PATH or LD_LIBRARY_PATH, and that your GPU is supported.");

onnxruntime/test/python/onnxruntime_test_python.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,20 @@ def test_set_providers_with_options(self):
298298
self.assertEqual(option["trt_engine_cache_path"], str(engine_cache_path))
299299
self.assertEqual(option["trt_force_sequential_engine_build"], "1")
300300

301+
from onnxruntime.capi import _pybind_state as C
302+
303+
session_options = C.get_default_session_options()
304+
305+
# TRT plugins registered as custom op domain should only be added once in session option regardless of number of session creation
306+
sess1 = onnxrt.InferenceSession(
307+
get_name("mul_1.onnx"), session_options, providers=["TensorrtExecutionProvider"]
308+
)
309+
sess2 = onnxrt.InferenceSession(
310+
get_name("mul_1.onnx"), session_options, providers=["TensorrtExecutionProvider"]
311+
)
312+
self.assertIn("TensorrtExecutionProvider", sess1.get_providers())
313+
self.assertIn("TensorrtExecutionProvider", sess2.get_providers())
314+
301315
# We currently disable following test code since that not all test machines/GPUs have nvidia int8 capability
302316

303317
"""

0 commit comments

Comments
 (0)