Skip to content

Commit fa4cbcd

Browse files
authored
[TensorRT EP] Add new provider option to exclude nodes from running on TRT (#22681)
Add new provider option `trt_op_types_to_exclude`: - User can provide op type list to be excluded from running on TRT - e.g. `trt_op_types_to_exclude="MaxPool"` There is a known performance issue with the DDS ops (NonMaxSuppression, NonZero and RoiAlign) from TRT versions 10.0 to 10.7. TRT EP excludes DDS ops from running on TRT by default, user can override default value with empty string to include all ops.
1 parent 3adcf4d commit fa4cbcd

File tree

9 files changed

+178
-39
lines changed

9 files changed

+178
-39
lines changed

include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -71,21 +71,23 @@ struct OrtTensorRTProviderOptionsV2 {
7171
* directory by means of the "trt_onnx_model_folder_path" option.
7272
*
7373
*/
74-
int trt_dump_ep_context_model{0}; // Dump EP context node model
75-
const char* trt_ep_context_file_path{nullptr}; // Specify file name to dump EP context node model. Can be a path or a file name or a file name with path.
76-
int trt_ep_context_embed_mode{0}; // Specify EP context embed mode. Default 0 = context is engine cache path, 1 = context is engine binary data
77-
int trt_weight_stripped_engine_enable{0}; // Enable weight-stripped engine build. Default 0 = false,
78-
// nonzero = true
79-
const char* trt_onnx_model_folder_path{nullptr}; // Folder path relative to the current working directory for
80-
// the ONNX model containing the weights (applicable only when
81-
// the "trt_weight_stripped_engine_enable" option is enabled)
82-
const void* trt_onnx_bytestream{nullptr}; // The byte stream of th original ONNX model containing the weights
83-
// (applicable only when the "trt_weight_stripped_engine_enable"
84-
// option is enabled)
85-
// can be updated using: UpdateTensorRTProviderOptionsWithValue
86-
size_t trt_onnx_bytestream_size{0}; // size of the byte stream provided as "trt_onnx_bytestream"
87-
// can be updated using: UpdateTensorRTProviderOptionsWithValue
88-
89-
const char* trt_engine_cache_prefix{nullptr}; // specify engine cache prefix
90-
int trt_engine_hw_compatible{0}; // Enable hardware compatibility. Default 0 = false, nonzero = true
74+
int trt_dump_ep_context_model{0}; // Dump EP context node model
75+
const char* trt_ep_context_file_path{nullptr}; // Specify file name to dump EP context node model. Can be a path or a file name or a file name with path.
76+
int trt_ep_context_embed_mode{0}; // Specify EP context embed mode. Default 0 = context is engine cache path, 1 = context is engine binary data
77+
int trt_weight_stripped_engine_enable{0}; // Enable weight-stripped engine build. Default 0 = false,
78+
// nonzero = true
79+
const char* trt_onnx_model_folder_path{nullptr}; // Folder path relative to the current working directory for
80+
// the ONNX model containing the weights (applicable only when
81+
// the "trt_weight_stripped_engine_enable" option is enabled)
82+
const void* trt_onnx_bytestream{nullptr};  // The byte stream of the original ONNX model containing the weights
83+
// (applicable only when the "trt_weight_stripped_engine_enable"
84+
// option is enabled)
85+
// can be updated using: UpdateTensorRTProviderOptionsWithValue
86+
size_t trt_onnx_bytestream_size{0}; // size of the byte stream provided as "trt_onnx_bytestream"
87+
// can be updated using: UpdateTensorRTProviderOptionsWithValue
88+
const char* trt_engine_cache_prefix{nullptr}; // specify engine cache prefix
89+
int trt_engine_hw_compatible{0}; // Enable hardware compatibility. Default 0 = false, nonzero = true
90+
const char* trt_op_types_to_exclude{"NonMaxSuppression,NonZero,RoiAlign"}; // Exclude specific ops from running on TRT.
91+
// There is a known performance issue with the DDS ops (NonMaxSuppression, NonZero and RoiAlign) from TRT versions 10.0 to 10.7.
92+
// TRT EP excludes DDS ops from running on TRT by default, user can override default value with empty string to include all ops.
9193
};

onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc

Lines changed: 74 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1379,6 +1379,8 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
13791379
profile_opt_shapes = info.profile_opt_shapes;
13801380
cuda_graph_enable_ = info.cuda_graph_enable;
13811381
engine_hw_compatible_ = info.engine_hw_compatible;
1382+
op_types_to_exclude_ = info.op_types_to_exclude;
1383+
13821384
} else {
13831385
try {
13841386
const std::string max_partition_iterations_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kMaxPartitionIterations);
@@ -1565,6 +1567,11 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
15651567
cuda_graph_enable_ = (std::stoi(cuda_graph_enable_env) == 0 ? false : true);
15661568
}
15671569

1570+
const std::string op_types_to_exclude_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kOpTypesToExclude);
1571+
if (!op_types_to_exclude_env.empty()) {
1572+
op_types_to_exclude_ = op_types_to_exclude_env;
1573+
}
1574+
15681575
} catch (const std::invalid_argument& ex) {
15691576
LOGS_DEFAULT(WARNING) << "[TensorRT EP] Invalid Argument (from environment variables): " << ex.what();
15701577
} catch (const std::out_of_range& ex) {
@@ -1725,6 +1732,10 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
17251732
runtime_ = std::unique_ptr<nvinfer1::IRuntime>(nvinfer1::createInferRuntime(GetTensorrtLogger(detailed_build_log_)));
17261733
}
17271734

1735+
trt_version_ = getInferLibVersion();
1736+
1737+
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] TensorRT version is " << trt_version_;
1738+
17281739
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] TensorRT provider options: "
17291740
<< "device_id: " << device_id_
17301741
<< ", trt_max_partition_iterations: " << max_partition_iterations_
@@ -1762,7 +1773,8 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
17621773
<< ", trt_ep_context_embed_mode: " << ep_context_embed_mode_
17631774
<< ", trt_cache_prefix: " << cache_prefix_
17641775
<< ", trt_engine_hw_compatible: " << engine_hw_compatible_
1765-
<< ", trt_onnx_model_bytestream_size_: " << onnx_model_bytestream_size_;
1776+
<< ", trt_onnx_model_bytestream_size_: " << onnx_model_bytestream_size_
1777+
<< ", trt_op_types_to_exclude: " << op_types_to_exclude_;
17661778
}
17671779

17681780
TensorrtExecutionProvider::~TensorrtExecutionProvider() {
@@ -2430,6 +2442,18 @@ bool TensorrtExecutionProvider::DetectTensorRTGraphCycles(SubGraphCollection_t&
24302442
return cycle_detected;
24312443
}
24322444

2445+
// Parse a comma-separated list of op type names (e.g. "NonMaxSuppression,NonZero,RoiAlign")
// into a set used for exclusion lookups in GetCapability().
// An empty input yields an empty set, i.e. no op types are excluded.
// Note: names are taken verbatim between commas — no whitespace trimming is performed,
// so "A, B" produces {"A", " B"}.
std::set<std::string> GetExcludedNodeSet(const std::string& node_list_to_exclude) {
  // Take the list by const reference: the previous by-value parameter copied the
  // string on every call (clang-tidy performance-unnecessary-value-param).
  std::set<std::string> excluded_ops;
  if (!node_list_to_exclude.empty()) {
    std::stringstream node_list(node_list_to_exclude);
    std::string node;
    while (std::getline(node_list, node, ',')) {
      excluded_ops.insert(node);
    }
  }
  return excluded_ops;
}
2456+
24332457
std::vector<std::unique_ptr<ComputeCapability>>
24342458
TensorrtExecutionProvider::GetCapability(const GraphViewer& graph,
24352459
const IKernelLookup& /*kernel_lookup*/) const {
@@ -2462,10 +2486,27 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph,
24622486
std::vector<size_t> nodes_vector(number_of_ort_nodes);
24632487
std::iota(std::begin(nodes_vector), std::end(nodes_vector), 0);
24642488

2465-
std::vector<size_t> filtered_nodes_vector;
2489+
std::set<std::string> exclude_set = GetExcludedNodeSet(op_types_to_exclude_);
2490+
2491+
// Print excluded nodes, if any.
2492+
std::set<std::string>::iterator it;
2493+
for (it = exclude_set.begin(); it != exclude_set.end(); ++it) {
2494+
std::string op = *it;
2495+
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Exclude \"" << op << "\" from running on TRT, if any.";
2496+
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Remove \"" << op << "\" from trt_op_types_to_exclude or specify trt_op_types_to_exclude with empty string to include the op in the input to TRT parser. However, it still depends on TRT parser to determine the eligibility of this op for TRT.";
2497+
}
2498+
2499+
SubGraphCollection_t parser_nodes_vector, supported_nodes_vector;
24662500
const std::vector<NodeIndex>& node_index = graph.GetNodesInTopologicalOrder(1 /*priority-based topological sort*/);
2501+
bool new_subgraph = true;
2502+
2503+
/* Iterate all the nodes and exclude the node if:
2504+
* 1. It's a control flow op and its subgraph(s) is not fully TRT eligible.
2505+
* 2. It's in the excluded set which is specified by trt_op_types_to_exclude.
2506+
*/
24672507
for (const auto& index : nodes_vector) {
24682508
const auto& node = graph.GetNode(node_index[index]);
2509+
bool supported_node = true;
24692510

24702511
/* If current node is control flow op, we take different approach based on following four cases:
24712512
*
@@ -2477,29 +2518,43 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph,
24772518
* For cases 2, 3, 4, even though the control flow op is not assigned to TRT, any portion of its subgraphs that can run in TRT will be still fused and assigned to TRT EP.
24782519
*/
24792520
if (control_flow_op_set_.find(node->OpType()) != control_flow_op_set_.end()) {
2480-
auto sub_graphs = node->GetSubgraphs();
2481-
if (sub_graphs.size() != 0) {
2482-
bool all_subgraphs_are_supported = true;
2483-
for (auto sub_graph : sub_graphs) {
2484-
// TRT EP should consider the empty subgraph is fully supported by TRT.
2485-
if (sub_graph->CreateGraphViewer()->NumberOfNodes() == 0) {
2486-
continue;
2487-
}
2488-
if (!AllNodesAssignedToSpecificEP(*(sub_graph->CreateGraphViewer()), kTensorrtExecutionProvider)) {
2489-
all_subgraphs_are_supported = false;
2490-
break;
2521+
auto supported_control_flow_op = [&](const Node* node) {
2522+
auto sub_graphs = node->GetSubgraphs();
2523+
if (sub_graphs.size() != 0) {
2524+
for (auto sub_graph : sub_graphs) {
2525+
// TRT EP should consider the empty subgraph is fully supported by TRT.
2526+
if (sub_graph->CreateGraphViewer()->NumberOfNodes() == 0) {
2527+
continue;
2528+
}
2529+
if (!AllNodesAssignedToSpecificEP(*(sub_graph->CreateGraphViewer()), kTensorrtExecutionProvider)) {
2530+
// if not all its subgraphs are supported, we need to exclude this control flow op
2531+
return false;
2532+
}
24912533
}
24922534
}
2493-
if (!all_subgraphs_are_supported) {
2494-
// if not all its subgraphs are supported, we need to exclude this control flow op
2495-
continue;
2496-
}
2535+
return true;
2536+
};
2537+
supported_node = supported_control_flow_op(node);
2538+
}
2539+
2540+
// Exclude any ops, if applicable
2541+
if (exclude_set.find(node->OpType()) != exclude_set.end()) {
2542+
supported_node = false;
2543+
}
2544+
2545+
if (supported_node) {
2546+
if (new_subgraph) {
2547+
parser_nodes_vector.emplace_back();
2548+
// Mark all new graphs as "UnKnown" which will later be parsed by TRT parser
2549+
parser_nodes_vector.back().second = false;
2550+
new_subgraph = false;
24972551
}
2552+
parser_nodes_vector.back().first.emplace_back(index);
2553+
} else {
2554+
new_subgraph = true;
24982555
}
2499-
filtered_nodes_vector.push_back(index);
25002556
}
25012557

2502-
SubGraphCollection_t supported_nodes_vector, parser_nodes_vector = {{filtered_nodes_vector, false}};
25032558
bool early_termination = false;
25042559
supported_nodes_vector = GetSupportedList(parser_nodes_vector, 0, max_partition_iterations_, graph, &early_termination);
25052560
if (early_termination) {

onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ static const std::string kDumpEpContextModel = "ORT_DUMP_EP_CONTEXT_MODEL";
5757
static const std::string kEpContextEmbedMode = "ORT_EP_CONTEXT_EMBED_MODE";
5858
static const std::string kEpContextComputeCapabilityEnable = "ORT_EP_CONTEXT_COMPUTE_CAPABILITY_ENABLE";
5959
static const std::string kEngineCachePrefix = "ORT_TENSORRT_CACHE_PREFIX";
60+
static const std::string kOpTypesToExclude = "ORT_TENSORRT_OP_TYPES_TO_EXCLUDE";
6061
// Old env variable for backward compatibility
6162
static const std::string kEngineCachePath = "ORT_TENSORRT_ENGINE_CACHE_PATH";
6263
} // namespace tensorrt_env_vars
@@ -329,6 +330,10 @@ class TensorrtExecutionProvider : public IExecutionProvider {
329330
bool cuda_graph_enable_ = false;
330331
std::string cache_prefix_;
331332
bool engine_hw_compatible_ = false;
333+
std::string op_types_to_exclude_;
334+
335+
// The format is as for TENSORRT_VERSION: (MAJOR * 100 + MINOR) * 100 + PATCH
336+
int32_t trt_version_;
332337

333338
// The OrtAllocator object will be get during ep compute time
334339
// and should be kept for the lifetime of TRT EP object.

onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ constexpr const char* kDumpEpContextModel = "trt_dump_ep_context_model";
5656
constexpr const char* kEngineHwCompatible = "trt_engine_hw_compatible";
5757
constexpr const char* kONNXBytestream = "trt_onnx_bytestream";
5858
constexpr const char* kONNXBytestreamSize = "trt_onnx_bytestream_size";
59+
constexpr const char* kOpTypesToExclude = "trt_op_types_to_exclude";
5960

6061
} // namespace provider_option_names
6162
} // namespace tensorrt
@@ -134,6 +135,7 @@ TensorrtExecutionProviderInfo TensorrtExecutionProviderInfo::FromProviderOptions
134135
return Status::OK();
135136
})
136137
.AddAssignmentToReference(tensorrt::provider_option_names::kONNXBytestreamSize, info.onnx_bytestream_size)
138+
.AddAssignmentToReference(tensorrt::provider_option_names::kOpTypesToExclude, info.op_types_to_exclude)
137139
.Parse(options)); // add new provider option here.
138140

139141
info.user_compute_stream = user_compute_stream;
@@ -188,6 +190,7 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const TensorrtE
188190
{tensorrt::provider_option_names::kEngineHwCompatible, MakeStringWithClassicLocale(info.engine_hw_compatible)},
189191
{tensorrt::provider_option_names::kONNXBytestream, MakeStringWithClassicLocale(info.onnx_bytestream)},
190192
{tensorrt::provider_option_names::kONNXBytestreamSize, MakeStringWithClassicLocale(info.onnx_bytestream_size)},
193+
{tensorrt::provider_option_names::kOpTypesToExclude, MakeStringWithClassicLocale(info.op_types_to_exclude)},
191194
};
192195
return options;
193196
}
@@ -206,6 +209,7 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const OrtTensor
206209
const std::string kProfilesOptShapes_ = empty_if_null(info.trt_profile_opt_shapes);
207210
const std::string kEpContextFilePath_ = empty_if_null(info.trt_ep_context_file_path);
208211
const std::string kOnnxModelFolderPath_ = empty_if_null(info.trt_onnx_model_folder_path);
212+
const std::string kOpTypesToExclude_ = empty_if_null(info.trt_op_types_to_exclude);
209213

210214
const ProviderOptions options{
211215
{tensorrt::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)},
@@ -251,6 +255,7 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const OrtTensor
251255
{tensorrt::provider_option_names::kEngineHwCompatible, MakeStringWithClassicLocale(info.trt_engine_hw_compatible)},
252256
{tensorrt::provider_option_names::kONNXBytestream, MakeStringWithClassicLocale(reinterpret_cast<size_t>(info.trt_onnx_bytestream))},
253257
{tensorrt::provider_option_names::kONNXBytestreamSize, MakeStringWithClassicLocale(info.trt_onnx_bytestream_size)},
258+
{tensorrt::provider_option_names::kOpTypesToExclude, kOpTypesToExclude_},
254259
};
255260
return options;
256261
}
@@ -355,5 +360,6 @@ void TensorrtExecutionProviderInfo::UpdateProviderOptions(void* provider_options
355360
trt_provider_options_v2.trt_engine_hw_compatible = internal_options.engine_hw_compatible;
356361
trt_provider_options_v2.trt_onnx_bytestream = internal_options.onnx_bytestream;
357362
trt_provider_options_v2.trt_onnx_bytestream_size = internal_options.onnx_bytestream_size;
363+
trt_provider_options_v2.trt_op_types_to_exclude = copy_string_if_needed(internal_options.op_types_to_exclude);
358364
}
359365
} // namespace onnxruntime

onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@ struct TensorrtExecutionProviderInfo {
6060
int ep_context_embed_mode{0};
6161
std::string engine_cache_prefix{""};
6262
bool engine_hw_compatible{false};
63+
// There is a known performance issue with the DDS ops (NonMaxSuppression, NonZero and RoiAlign) from TRT versions 10.0 to 10.7.
64+
// TRT EP excludes DDS ops from running on TRT by default, user can override default value of trt_op_types_to_exclude with empty string to include all ops.
65+
std::string op_types_to_exclude{"NonMaxSuppression,NonZero,RoiAlign"};
6366

6467
static TensorrtExecutionProviderInfo FromProviderOptions(const ProviderOptions& options);
6568
static ProviderOptions ToProviderOptions(const TensorrtExecutionProviderInfo& info);

onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ struct Tensorrt_Provider : Provider {
118118
info.engine_hw_compatible = options.trt_engine_hw_compatible != 0;
119119
info.onnx_bytestream = options.trt_onnx_bytestream;
120120
info.onnx_bytestream_size = options.trt_onnx_bytestream_size;
121+
info.op_types_to_exclude = options.trt_op_types_to_exclude == nullptr ? "" : options.trt_op_types_to_exclude;
121122

122123
return std::make_shared<TensorrtProviderFactory>(info);
123124
}

onnxruntime/core/session/provider_bridge_ort.cc

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2294,8 +2294,11 @@ ORT_API_STATUS_IMPL(OrtApis::UpdateTensorRTProviderOptions,
22942294
#ifdef USE_TENSORRT
22952295
onnxruntime::ProviderOptions provider_options_map;
22962296
for (size_t i = 0; i != num_keys; ++i) {
2297-
if (provider_options_keys[i] == nullptr || provider_options_keys[i][0] == '\0' ||
2298-
provider_options_values[i] == nullptr || provider_options_values[i][0] == '\0') {
2297+
// Don't allow key and value to be empty except the value of trt_op_types_to_exclude
2298+
if (provider_options_keys[i] == nullptr ||
2299+
provider_options_keys[i][0] == '\0' ||
2300+
(provider_options_values[i] == nullptr && strcmp("trt_op_types_to_exclude", provider_options_keys[i])) ||
2301+
(provider_options_values[i][0] == '\0' && strcmp("trt_op_types_to_exclude", provider_options_keys[i]))) {
22992302
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "key/value cannot be empty");
23002303
}
23012304

@@ -2410,6 +2413,7 @@ ORT_API(void, OrtApis::ReleaseTensorRTProviderOptions, _Frees_ptr_opt_ OrtTensor
24102413
delete[] ptr->trt_profile_opt_shapes;
24112414
delete[] ptr->trt_ep_context_file_path;
24122415
delete[] ptr->trt_onnx_model_folder_path;
2416+
if (!ptr->trt_op_types_to_exclude) delete[] ptr->trt_op_types_to_exclude;
24132417
}
24142418

24152419
std::unique_ptr<OrtTensorRTProviderOptionsV2> p(ptr);

onnxruntime/python/onnxruntime_pybind_state.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -526,7 +526,7 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
526526
// and TRT EP instance, so it won't be released.)
527527
std::string calibration_table, cache_path, cache_prefix, timing_cache_path, lib_path, trt_tactic_sources,
528528
trt_extra_plugin_lib_paths, min_profile, max_profile, opt_profile, ep_context_file_path,
529-
onnx_model_folder_path;
529+
onnx_model_folder_path, trt_op_types_to_exclude{"NonMaxSuppression,NonZero,RoiAlign"};
530530
auto it = provider_options_map.find(type);
531531
if (it != provider_options_map.end()) {
532532
OrtTensorRTProviderOptionsV2 params;
@@ -824,6 +824,9 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
824824
} else {
825825
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_engine_hw_compatible' should be 'True' or 'False'. Default value is 'False'.\n");
826826
}
827+
} else if (option.first == "trt_op_types_to_exclude") {
828+
trt_op_types_to_exclude = option.second;
829+
params.trt_op_types_to_exclude = trt_op_types_to_exclude.c_str();
827830
} else {
828831
ORT_THROW("Invalid TensorRT EP option: ", option.first);
829832
}

0 commit comments

Comments
 (0)