Skip to content

Commit 67d3ba0

Browse files
[EP ABI] Check if nodes specified in GetCapability() have already been assigned (#26156)
### Description Fixes segfault in `PluginExecutionProvider::GetCapability()` when the underlying `OrtEp` tries to claim nodes that have already been assigned to another EP. ### Motivation and Context Should log a warning (instead of crashing or throwing an exception) when a plugin EP tries to claim a node that is already assigned to another EP. --------- Co-authored-by: Edward Chen <[email protected]>
1 parent 2109547 commit 67d3ba0

File tree

2 files changed

+285
-11
lines changed

2 files changed

+285
-11
lines changed

onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc

Lines changed: 54 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#include "core/session/plugin_ep/ep_plugin_provider_interfaces.h"
55

6+
#include <gsl/gsl>
67
#include <memory>
78
#include <string>
89
#include <unordered_set>
@@ -117,6 +118,17 @@ static OrtDevice GetOrtDeviceForPluginEp(gsl::span<const OrtEpDevice* const> ep_
117118
return device_memory_info != nullptr ? device_memory_info->device : OrtDevice();
118119
}
119120

121+
static const Node* FindFirstNodeAssignedToOtherEP(const std::string& ep_type,
122+
gsl::span<const EpNode* const> ep_nodes) {
123+
auto node_iter = std::find_if(ep_nodes.begin(), ep_nodes.end(),
124+
[&ep_type](const EpNode* node) -> bool {
125+
const auto& node_ep_type = node->GetInternalNode().GetExecutionProviderType();
126+
return !node_ep_type.empty() && node_ep_type != ep_type;
127+
});
128+
129+
return node_iter != ep_nodes.end() ? &(*node_iter)->GetInternalNode() : nullptr;
130+
}
131+
120132
PluginExecutionProvider::PluginExecutionProvider(UniqueOrtEp ep, const OrtSessionOptions& session_options,
121133
OrtEpFactory& ep_factory,
122134
gsl::span<const OrtEpDevice* const> ep_devices,
@@ -158,17 +170,19 @@ PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie
158170
ORT_UNUSED_PARAMETER(resource_accountant); // TODO: Add support? Not used by prioritized EPs
159171
ORT_UNUSED_PARAMETER(kernel_lookup); // TODO: Add support? Not used by prioritized EPs, so probably not needed?
160172

173+
const logging::Logger& logger = GetLogger() != nullptr ? *GetLogger() : logging::LoggingManager::DefaultLogger();
174+
161175
std::unique_ptr<EpGraph> ep_graph = nullptr;
162176
if (Status status = EpGraph::Create(graph_viewer, ep_graph); !status.IsOK()) {
163-
LOGS_DEFAULT(ERROR) << "Failed to create OrtGraph: " << status.ToString();
177+
LOGS(logger, ERROR) << "Failed to create OrtGraph for " << Type() << ": " << status.ToString();
164178
return {};
165179
}
166180

167181
OrtEpGraphSupportInfo api_graph_support_info(*ep_graph);
168182
Status status = ToStatusAndRelease(ort_ep_->GetCapability(ort_ep_.get(), ep_graph->ToExternal(), &api_graph_support_info));
169183

170184
if (!status.IsOK()) {
171-
LOGS_DEFAULT(ERROR) << "OrtEp::GetCapability() failed with error: " << status.ToString();
185+
LOGS(logger, ERROR) << "OrtEp::GetCapability() for " << Type() << " failed with error: " << status.ToString();
172186
return {};
173187
}
174188

@@ -182,12 +196,39 @@ PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie
182196

183197
// Create ComputeCapability instances from OrtEpGraphSupportInfo::NodeGrouping instances.
184198
for (const OrtEpGraphSupportInfo::NodeGrouping& node_grouping : api_graph_support_info.node_groupings) {
199+
// Skip this node grouping if any node has already been assigned to another EP.
200+
if (const Node* node_for_other_ep = FindFirstNodeAssignedToOtherEP(Type(), node_grouping.nodes);
201+
node_for_other_ep != nullptr) {
202+
LOGS(logger, WARNING) << "OrtEp::GetCapability() specified nodes that cannot be assigned to " << Type() << ". "
203+
<< "Found one or more nodes that were already assigned to a different EP named '"
204+
<< node_for_other_ep->GetExecutionProviderType() << "'. Ex: "
205+
<< node_for_other_ep->OpType() << " node with name '"
206+
<< node_for_other_ep->Name() << "'.";
207+
continue;
208+
}
209+
185210
if (node_grouping.kind == OrtEpGraphSupportInfo::NodeGroupingKind::kSingleAssignedNode) {
211+
if (node_grouping.nodes.size() != 1) {
212+
// The EpGraphSupportInfo_AddSingleNode() C API should already return an error if the EP tries to provide
213+
// an invalid node. However, we check here too just in case this changes.
214+
LOGS(logger, ERROR) << "OrtEp::GetCapability() for " << Type() << " did not specify exactly one valid node "
215+
<< "when calling EpGraphSupportInfo_AddSingleNode().";
216+
return {};
217+
}
218+
186219
auto indexed_sub_graph = std::make_unique<IndexedSubGraph>();
187220

188221
indexed_sub_graph->nodes.push_back(node_grouping.nodes[0]->GetInternalNode().Index());
189222
result.push_back(std::make_unique<ComputeCapability>(std::move(indexed_sub_graph)));
190223
} else if (node_grouping.kind == OrtEpGraphSupportInfo::NodeGroupingKind::kFusedNode) {
224+
if (node_grouping.nodes.empty()) {
225+
// The EpGraphSupportInfo_AddNodesToFuse() C API should already return an error if the EP tries to provide
226+
// an empty array of nodes from OrtEp::GetCapability(). However, we check here too just in case this changes.
227+
LOGS(logger, ERROR) << "OrtEp::GetCapability() for " << Type() << " set an empty array of nodes "
228+
<< "when specifying supported nodes.";
229+
return {};
230+
}
231+
191232
std::unordered_set<const Node*> node_set;
192233
node_set.reserve(node_grouping.nodes.size());
193234

@@ -207,27 +248,29 @@ PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie
207248
this->Type(), this->Type(), /*node_unit_map*/ nullptr,
208249
node_grouping.fusion_options.drop_constant_initializers);
209250

210-
if (capabilities.size() > 1) {
211-
LOGS_DEFAULT(ERROR) << "OrtEp::GetCapability() set nodes that cannot be fused together. "
212-
<< "Please ensure that the nodes provided to EpGraphSupportInfo_AddFusedNodes() do not "
251+
if (capabilities.size() != 1) {
252+
LOGS(logger, ERROR) << "OrtEp::GetCapability() for " << Type() << " set nodes that cannot be fused together. "
253+
<< "Please ensure that the nodes provided to EpGraphSupportInfo_AddNodesToFuse() do not "
213254
<< "have an unsupported node in any path between two of the supported nodes.";
214255
return {};
215256
}
216257

217-
// Enforce that the nodes in node_set match the nodes in capabilities[0]
258+
// Log an error if the nodes in node_set do not match the nodes in capabilities[0]. We expect this to always
259+
// be true because we've already checked that the EP did not try to claim nodes already assigned to another EP.
218260
// TODO(adrianlizarraga): This check can be removed when we stop using utils::CreateSupportedPartitions() above.
219261
std::vector<NodeIndex>& capability_node_indices = capabilities[0]->sub_graph->nodes;
220262
std::unordered_set<NodeIndex> capability_node_indices_set(capability_node_indices.begin(),
221263
capability_node_indices.end());
222264

223-
ORT_ENFORCE(node_set.size() == capability_node_indices_set.size());
224-
ORT_ENFORCE(std::all_of(node_set.begin(), node_set.end(), [&capability_node_indices_set](const Node* node) {
225-
return capability_node_indices_set.count(node->Index()) != 0;
226-
}));
265+
if (node_set.size() != capability_node_indices_set.size()) {
266+
LOGS(logger, ERROR) << "OrtEp::GetCapability() for " << Type()
267+
<< " set nodes that cannot all be fused together.";
268+
return {};
269+
}
227270

228271
result.push_back(std::move(capabilities[0]));
229272
} else {
230-
LOGS_DEFAULT(ERROR) << "PluginExecutionProvider::GetCapability() has invalid NodeGroupingKind: "
273+
LOGS(logger, ERROR) << "PluginExecutionProvider::GetCapability() has invalid NodeGroupingKind: "
231274
<< static_cast<int>(node_grouping.kind);
232275
return {};
233276
}

onnxruntime/test/framework/ep_plugin_provider_test.cc

Lines changed: 231 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,14 @@
33

44
#include "core/session/plugin_ep/ep_plugin_provider_interfaces.h"
55

6+
#include <filesystem>
67
#include "gsl/gsl"
78
#include "gtest/gtest.h"
89

10+
#include "core/common/logging/sinks/file_sink.h"
11+
#include "core/graph/graph_viewer.h"
12+
#include "core/graph/model.h"
13+
#include "core/optimizer/graph_optimizer_registry.h"
914
#include "core/session/abi_devices.h"
1015
#include "core/session/onnxruntime_cxx_api.h"
1116
#include "test/util/include/asserts.h"
@@ -23,6 +28,14 @@ struct ApiPtrs {
2328
const gsl::not_null<const ::OrtEpApi*> ep_api;
2429
};
2530

31+
static void CheckStringInFile(const PathString& filename, const std::string& look_for) {
32+
std::ifstream ifs{filename};
33+
std::string content(std::istreambuf_iterator<char>{ifs},
34+
std::istreambuf_iterator<char>{});
35+
36+
EXPECT_NE(content.find(look_for), std::string::npos);
37+
}
38+
2639
// Normally, a plugin EP would be implemented in a separate library.
2740
// The `test_plugin_ep` namespace contains a local implementation intended for unit testing.
2841
namespace test_plugin_ep {
@@ -114,6 +127,10 @@ MakeTestOrtEpResult MakeTestOrtEp(std::vector<const OrtEpDevice*> ep_devices = {
114127
return result;
115128
}
116129

130+
class MockKernelLookup : public IExecutionProvider::IKernelLookup {
131+
const KernelCreateInfo* LookUpKernel(const Node& /*node*/) const override { return nullptr; }
132+
};
133+
117134
} // namespace test_plugin_ep
118135

119136
TEST(PluginExecutionProviderTest, GetPreferredLayout) {
@@ -317,4 +334,218 @@ TEST(PluginExecutionProviderTest, InferOrtDeviceFromDeviceMemoryInfo) {
317334
#endif // !defined(ORT_NO_EXCEPTIONS)
318335
}
319336

337+
static void LoadModelAndAssignNodesToEp(const ORTCHAR_T* model_path,
338+
const char* ep_name,
339+
const std::unordered_set<std::string>& ep_node_names,
340+
/*out*/ std::shared_ptr<Model>& model) {
341+
ASSERT_STATUS_OK(Model::Load(model_path, model, nullptr,
342+
DefaultLoggingManager().DefaultLogger()));
343+
344+
Graph& graph = model->MainGraph();
345+
346+
for (Node& node : graph.Nodes()) {
347+
if (ep_node_names.count(node.Name()) > 0) {
348+
node.SetExecutionProviderType(ep_name);
349+
}
350+
}
351+
}
352+
353+
static OrtStatus* ORT_API_CALL GetCapabilityTakeAllNodesOneGroup(OrtEp* this_ptr, const OrtGraph* graph,
354+
OrtEpGraphSupportInfo* graph_support_info) noexcept {
355+
auto* this_ep = static_cast<test_plugin_ep::TestOrtEp*>(this_ptr);
356+
357+
size_t num_nodes = 0;
358+
if (OrtStatus* st = this_ep->ort_api->Graph_GetNumNodes(graph, &num_nodes); st != nullptr) {
359+
return st;
360+
}
361+
362+
std::vector<const OrtNode*> nodes(num_nodes);
363+
if (OrtStatus* st = this_ep->ort_api->Graph_GetNodes(graph, nodes.data(), nodes.size()); st != nullptr) {
364+
return st;
365+
}
366+
367+
if (OrtStatus* st = this_ep->ep_api->EpGraphSupportInfo_AddNodesToFuse(graph_support_info,
368+
nodes.data(), nodes.size(), nullptr);
369+
st != nullptr) {
370+
return st;
371+
}
372+
373+
return nullptr;
374+
}
375+
376+
static OrtStatus* ORT_API_CALL GetCapabilityTakeAllNodesTwoGroups(OrtEp* this_ptr, const OrtGraph* graph,
377+
OrtEpGraphSupportInfo* graph_support_info) noexcept {
378+
auto* this_ep = static_cast<test_plugin_ep::TestOrtEp*>(this_ptr);
379+
380+
size_t num_nodes = 0;
381+
if (OrtStatus* st = this_ep->ort_api->Graph_GetNumNodes(graph, &num_nodes); st != nullptr) {
382+
return st;
383+
}
384+
385+
std::vector<const OrtNode*> nodes(num_nodes);
386+
if (OrtStatus* st = this_ep->ort_api->Graph_GetNodes(graph, nodes.data(), nodes.size()); st != nullptr) {
387+
return st;
388+
}
389+
390+
// Expect at least 2 nodes. If not, this is really a testing/setup error.
391+
if (num_nodes < 2) {
392+
return this_ep->ort_api->CreateStatus(OrtErrorCode::ORT_FAIL,
393+
"Expected at least two nodes in call to GetCapability");
394+
}
395+
396+
std::vector<const OrtNode*> node_group1;
397+
std::vector<const OrtNode*> node_group2;
398+
399+
for (size_t i = 0; i < num_nodes; i++) {
400+
if (i < num_nodes / 2) {
401+
node_group1.push_back(nodes[i]);
402+
} else {
403+
node_group2.push_back(nodes[i]);
404+
}
405+
}
406+
407+
if (OrtStatus* st = this_ep->ep_api->EpGraphSupportInfo_AddNodesToFuse(graph_support_info,
408+
node_group1.data(), node_group1.size(),
409+
nullptr);
410+
st != nullptr) {
411+
return st;
412+
}
413+
414+
if (OrtStatus* st = this_ep->ep_api->EpGraphSupportInfo_AddNodesToFuse(graph_support_info,
415+
node_group2.data(), node_group2.size(),
416+
nullptr);
417+
st != nullptr) {
418+
return st;
419+
}
420+
421+
return nullptr;
422+
}
423+
424+
static OrtStatus* ORT_API_CALL GetCapabilityTakeSingleNode(OrtEp* this_ptr, const OrtGraph* graph,
425+
OrtEpGraphSupportInfo* graph_support_info) noexcept {
426+
auto* this_ep = static_cast<test_plugin_ep::TestOrtEp*>(this_ptr);
427+
428+
size_t num_nodes = 0;
429+
if (OrtStatus* st = this_ep->ort_api->Graph_GetNumNodes(graph, &num_nodes); st != nullptr) {
430+
return st;
431+
}
432+
433+
std::vector<const OrtNode*> nodes(num_nodes);
434+
if (OrtStatus* st = this_ep->ort_api->Graph_GetNodes(graph, nodes.data(), nodes.size()); st != nullptr) {
435+
return st;
436+
}
437+
438+
// Take only the first node using EpGraphSupportInfo_AddSingleNode().
439+
if (OrtStatus* st = this_ep->ep_api->EpGraphSupportInfo_AddSingleNode(graph_support_info, nodes[0]);
440+
st != nullptr) {
441+
return st;
442+
}
443+
444+
return nullptr;
445+
}
446+
447+
// Tests that GetCapability() doesn't crash if a plugin EP tries to claim a mix of unassigned nodes and
448+
// nodes that are already assigned to another EP.
449+
TEST(PluginExecutionProviderTest, GetCapability_ClaimNodesAssignedToOtherEP) {
450+
std::filesystem::path log_file = ORT_TSTR("log_get_capability.txt");
451+
452+
// Helper function that loads a model (Add -> Mul -> Add) and assigns some or all of the nodes to another EP.
453+
// Then, IExecutionProvider::GetCapability() is called to test the expected behavior.
454+
auto run_test = [&log_file](IExecutionProvider& ep,
455+
const std::unordered_set<std::string>& nodes_for_other_ep,
456+
const std::unordered_set<std::string>& nodes_for_this_ep,
457+
const char* expected_log_string) {
458+
std::shared_ptr<Model> model;
459+
ASSERT_NO_FATAL_FAILURE(LoadModelAndAssignNodesToEp(ORT_TSTR("testdata/add_mul_add.onnx"),
460+
"OtherEp", nodes_for_other_ep, model));
461+
462+
std::filesystem::remove(log_file);
463+
464+
// Call IExecutionProvider::GetCapability and check results + logs.
465+
{
466+
logging::LoggingManager log_manager{std::make_unique<logging::FileSink>(log_file, false, false),
467+
logging::Severity::kWARNING, false,
468+
logging::LoggingManager::InstanceType::Temporal};
469+
auto file_logger = log_manager.CreateLogger("FileLogger");
470+
ep.SetLogger(file_logger.get()); // Make EP log to a file.
471+
472+
GraphViewer graph_viewer(model->MainGraph());
473+
auto compute_capabilities = ep.GetCapability(graph_viewer,
474+
test_plugin_ep::MockKernelLookup{},
475+
GraphOptimizerRegistry(nullptr, nullptr, file_logger.get()),
476+
nullptr);
477+
478+
ASSERT_EQ(compute_capabilities.size(), nodes_for_this_ep.empty() ? 0 : 1);
479+
480+
if (compute_capabilities.size() == 1) {
481+
ASSERT_EQ(compute_capabilities[0]->sub_graph->nodes.size(), nodes_for_this_ep.size());
482+
483+
for (NodeIndex node_index : compute_capabilities[0]->sub_graph->nodes) {
484+
const Node* node = graph_viewer.GetNode(node_index);
485+
ASSERT_NE(node, nullptr);
486+
EXPECT_EQ(nodes_for_this_ep.count(node->Name()), 1);
487+
}
488+
}
489+
}
490+
491+
ASSERT_TRUE(std::filesystem::exists(log_file));
492+
EXPECT_NO_FATAL_FAILURE(CheckStringInFile(log_file, expected_log_string));
493+
};
494+
495+
constexpr std::array<const char*, 3> node_names = {"add_0", "mul_0", "add_1"};
496+
497+
auto [ep, ort_ep] = test_plugin_ep::MakeTestOrtEp();
498+
499+
// Load a model and assign all of its nodes to another EP named 'OtherEp'.
500+
// The plugin EP tries to claim all nodes in a single group via EpGraphSupportInfo_AddNodesToFuse.
501+
// IExecutionProvider::GetCapability() should return an empty result and log a warning.
502+
ort_ep->GetCapability = GetCapabilityTakeAllNodesOneGroup;
503+
std::unordered_set<std::string> nodes_for_other_ep = {"add_0", "mul_0", "add_1"};
504+
std::unordered_set<std::string> nodes_for_this_ep;
505+
run_test(*ep, nodes_for_other_ep, nodes_for_this_ep,
506+
"Found one or more nodes that were already assigned to a different EP named 'OtherEp'");
507+
508+
// Load a model and assign only one node to another EP named 'OtherEp'.
509+
// The plugin EP tries to claim all nodes in a single group.
510+
// IExecutionProvider::GetCapability() should return an empty result and log a warning.
511+
ort_ep->GetCapability = GetCapabilityTakeAllNodesOneGroup;
512+
for (const char* node_name : node_names) {
513+
nodes_for_other_ep = std::unordered_set<std::string>{node_name};
514+
nodes_for_this_ep = std::unordered_set<std::string>{};
515+
run_test(*ep, nodes_for_other_ep, nodes_for_this_ep,
516+
"Found one or more nodes that were already assigned to a different EP named 'OtherEp'");
517+
}
518+
519+
// Load a model and assign only the last Add node to another EP named 'OtherEp'.
520+
// The plugin EP tries to claim all nodes in the following 2 groups: (add_0), (mul_0, add_1).
521+
// IExecutionProvider::GetCapability() will only return (add_0) because the second group has a node
522+
// that was assigned to 'OtherEp'.
523+
ort_ep->GetCapability = GetCapabilityTakeAllNodesTwoGroups;
524+
nodes_for_other_ep = std::unordered_set<std::string>{"add_1"};
525+
nodes_for_this_ep = std::unordered_set<std::string>{"add_0"};
526+
run_test(*ep, nodes_for_other_ep, nodes_for_this_ep,
527+
"Found one or more nodes that were already assigned to a different EP named 'OtherEp'");
528+
529+
// Load a model and assign only the first Add node to another EP named 'OtherEp'.
530+
// The plugin EP tries to claim all nodes in the following 2 groups: (add_0), (mul_0, add_1).
531+
// IExecutionProvider::GetCapability() will only return (mul_0, add_1) because the first group has a node
532+
// that was assigned to 'OtherEp'.
533+
ort_ep->GetCapability = GetCapabilityTakeAllNodesTwoGroups;
534+
nodes_for_other_ep = std::unordered_set<std::string>{"add_0"};
535+
nodes_for_this_ep = std::unordered_set<std::string>{"mul_0", "add_1"};
536+
run_test(*ep, nodes_for_other_ep, nodes_for_this_ep,
537+
"Found one or more nodes that were already assigned to a different EP named 'OtherEp'");
538+
539+
// Load a model and assign the first Add node to another EP named 'OtherEp'.
540+
// The plugin EP will try to take only the first Add node with a single call to EpGraphSupportInfo_AddSingleNode.
541+
// IExecutionProvider::GetCapability() will return an empty result and log a warning.
542+
ort_ep->GetCapability = GetCapabilityTakeSingleNode;
543+
nodes_for_other_ep = std::unordered_set<std::string>{"add_0"};
544+
nodes_for_this_ep = std::unordered_set<std::string>{};
545+
run_test(*ep, nodes_for_other_ep, nodes_for_this_ep,
546+
"Found one or more nodes that were already assigned to a different EP named 'OtherEp'");
547+
548+
std::filesystem::remove(log_file);
549+
}
550+
320551
} // namespace onnxruntime::test

0 commit comments

Comments
 (0)