33
44#include " core/session/plugin_ep/ep_plugin_provider_interfaces.h"
55
6+ #include < filesystem>
67#include " gsl/gsl"
78#include " gtest/gtest.h"
89
10+ #include " core/common/logging/sinks/file_sink.h"
11+ #include " core/graph/graph_viewer.h"
12+ #include " core/graph/model.h"
13+ #include " core/optimizer/graph_optimizer_registry.h"
914#include " core/session/abi_devices.h"
1015#include " core/session/onnxruntime_cxx_api.h"
1116#include " test/util/include/asserts.h"
@@ -23,6 +28,14 @@ struct ApiPtrs {
2328 const gsl::not_null<const ::OrtEpApi*> ep_api;
2429};
2530
31+ static void CheckStringInFile (const PathString& filename, const std::string& look_for) {
32+ std::ifstream ifs{filename};
33+ std::string content (std::istreambuf_iterator<char >{ifs},
34+ std::istreambuf_iterator<char >{});
35+
36+ EXPECT_NE (content.find (look_for), std::string::npos);
37+ }
38+
2639// Normally, a plugin EP would be implemented in a separate library.
2740// The `test_plugin_ep` namespace contains a local implementation intended for unit testing.
2841namespace test_plugin_ep {
@@ -114,6 +127,10 @@ MakeTestOrtEpResult MakeTestOrtEp(std::vector<const OrtEpDevice*> ep_devices = {
114127 return result;
115128}
116129
130+ class MockKernelLookup : public IExecutionProvider ::IKernelLookup {
131+ const KernelCreateInfo* LookUpKernel (const Node& /* node*/ ) const override { return nullptr ; }
132+ };
133+
117134} // namespace test_plugin_ep
118135
119136TEST (PluginExecutionProviderTest, GetPreferredLayout) {
@@ -317,4 +334,218 @@ TEST(PluginExecutionProviderTest, InferOrtDeviceFromDeviceMemoryInfo) {
317334#endif // !defined(ORT_NO_EXCEPTIONS)
318335}
319336
337+ static void LoadModelAndAssignNodesToEp (const ORTCHAR_T* model_path,
338+ const char * ep_name,
339+ const std::unordered_set<std::string>& ep_node_names,
340+ /* out*/ std::shared_ptr<Model>& model) {
341+ ASSERT_STATUS_OK (Model::Load (model_path, model, nullptr ,
342+ DefaultLoggingManager ().DefaultLogger ()));
343+
344+ Graph& graph = model->MainGraph ();
345+
346+ for (Node& node : graph.Nodes ()) {
347+ if (ep_node_names.count (node.Name ()) > 0 ) {
348+ node.SetExecutionProviderType (ep_name);
349+ }
350+ }
351+ }
352+
353+ static OrtStatus* ORT_API_CALL GetCapabilityTakeAllNodesOneGroup (OrtEp* this_ptr, const OrtGraph* graph,
354+ OrtEpGraphSupportInfo* graph_support_info) noexcept {
355+ auto * this_ep = static_cast <test_plugin_ep::TestOrtEp*>(this_ptr);
356+
357+ size_t num_nodes = 0 ;
358+ if (OrtStatus* st = this_ep->ort_api ->Graph_GetNumNodes (graph, &num_nodes); st != nullptr ) {
359+ return st;
360+ }
361+
362+ std::vector<const OrtNode*> nodes (num_nodes);
363+ if (OrtStatus* st = this_ep->ort_api ->Graph_GetNodes (graph, nodes.data (), nodes.size ()); st != nullptr ) {
364+ return st;
365+ }
366+
367+ if (OrtStatus* st = this_ep->ep_api ->EpGraphSupportInfo_AddNodesToFuse (graph_support_info,
368+ nodes.data (), nodes.size (), nullptr );
369+ st != nullptr ) {
370+ return st;
371+ }
372+
373+ return nullptr ;
374+ }
375+
376+ static OrtStatus* ORT_API_CALL GetCapabilityTakeAllNodesTwoGroups (OrtEp* this_ptr, const OrtGraph* graph,
377+ OrtEpGraphSupportInfo* graph_support_info) noexcept {
378+ auto * this_ep = static_cast <test_plugin_ep::TestOrtEp*>(this_ptr);
379+
380+ size_t num_nodes = 0 ;
381+ if (OrtStatus* st = this_ep->ort_api ->Graph_GetNumNodes (graph, &num_nodes); st != nullptr ) {
382+ return st;
383+ }
384+
385+ std::vector<const OrtNode*> nodes (num_nodes);
386+ if (OrtStatus* st = this_ep->ort_api ->Graph_GetNodes (graph, nodes.data (), nodes.size ()); st != nullptr ) {
387+ return st;
388+ }
389+
390+ // Expect at least 2 nodes. If not, this is really a testing/setup error.
391+ if (num_nodes < 2 ) {
392+ return this_ep->ort_api ->CreateStatus (OrtErrorCode::ORT_FAIL,
393+ " Expected at least two nodes in call to GetCapability" );
394+ }
395+
396+ std::vector<const OrtNode*> node_group1;
397+ std::vector<const OrtNode*> node_group2;
398+
399+ for (size_t i = 0 ; i < num_nodes; i++) {
400+ if (i < num_nodes / 2 ) {
401+ node_group1.push_back (nodes[i]);
402+ } else {
403+ node_group2.push_back (nodes[i]);
404+ }
405+ }
406+
407+ if (OrtStatus* st = this_ep->ep_api ->EpGraphSupportInfo_AddNodesToFuse (graph_support_info,
408+ node_group1.data (), node_group1.size (),
409+ nullptr );
410+ st != nullptr ) {
411+ return st;
412+ }
413+
414+ if (OrtStatus* st = this_ep->ep_api ->EpGraphSupportInfo_AddNodesToFuse (graph_support_info,
415+ node_group2.data (), node_group2.size (),
416+ nullptr );
417+ st != nullptr ) {
418+ return st;
419+ }
420+
421+ return nullptr ;
422+ }
423+
424+ static OrtStatus* ORT_API_CALL GetCapabilityTakeSingleNode (OrtEp* this_ptr, const OrtGraph* graph,
425+ OrtEpGraphSupportInfo* graph_support_info) noexcept {
426+ auto * this_ep = static_cast <test_plugin_ep::TestOrtEp*>(this_ptr);
427+
428+ size_t num_nodes = 0 ;
429+ if (OrtStatus* st = this_ep->ort_api ->Graph_GetNumNodes (graph, &num_nodes); st != nullptr ) {
430+ return st;
431+ }
432+
433+ std::vector<const OrtNode*> nodes (num_nodes);
434+ if (OrtStatus* st = this_ep->ort_api ->Graph_GetNodes (graph, nodes.data (), nodes.size ()); st != nullptr ) {
435+ return st;
436+ }
437+
438+ // Take only the first node using EpGraphSupportInfo_AddSingleNode().
439+ if (OrtStatus* st = this_ep->ep_api ->EpGraphSupportInfo_AddSingleNode (graph_support_info, nodes[0 ]);
440+ st != nullptr ) {
441+ return st;
442+ }
443+
444+ return nullptr ;
445+ }
446+
447+ // Tests that GetCapability() doesn't crash if a plugin EP tries to claim a mix of unassigned nodes and
448+ // nodes that are already assigned to another EP.
449+ TEST (PluginExecutionProviderTest, GetCapability_ClaimNodesAssignedToOtherEP) {
450+ std::filesystem::path log_file = ORT_TSTR (" log_get_capability.txt" );
451+
452+ // Helper function that loads a model (Add -> Mul -> Add) and assigns some or all of the nodes to another EP.
453+ // Then, IExecutionProvider::GetCapability() is called to test the expected behavior.
454+ auto run_test = [&log_file](IExecutionProvider& ep,
455+ const std::unordered_set<std::string>& nodes_for_other_ep,
456+ const std::unordered_set<std::string>& nodes_for_this_ep,
457+ const char * expected_log_string) {
458+ std::shared_ptr<Model> model;
459+ ASSERT_NO_FATAL_FAILURE (LoadModelAndAssignNodesToEp (ORT_TSTR (" testdata/add_mul_add.onnx" ),
460+ " OtherEp" , nodes_for_other_ep, model));
461+
462+ std::filesystem::remove (log_file);
463+
464+ // Call IExecutionProvider::GetCapability and check results + logs.
465+ {
466+ logging::LoggingManager log_manager{std::make_unique<logging::FileSink>(log_file, false , false ),
467+ logging::Severity::kWARNING , false ,
468+ logging::LoggingManager::InstanceType::Temporal};
469+ auto file_logger = log_manager.CreateLogger (" FileLogger" );
470+ ep.SetLogger (file_logger.get ()); // Make EP log to a file.
471+
472+ GraphViewer graph_viewer (model->MainGraph ());
473+ auto compute_capabilities = ep.GetCapability (graph_viewer,
474+ test_plugin_ep::MockKernelLookup{},
475+ GraphOptimizerRegistry (nullptr , nullptr , file_logger.get ()),
476+ nullptr );
477+
478+ ASSERT_EQ (compute_capabilities.size (), nodes_for_this_ep.empty () ? 0 : 1 );
479+
480+ if (compute_capabilities.size () == 1 ) {
481+ ASSERT_EQ (compute_capabilities[0 ]->sub_graph ->nodes .size (), nodes_for_this_ep.size ());
482+
483+ for (NodeIndex node_index : compute_capabilities[0 ]->sub_graph ->nodes ) {
484+ const Node* node = graph_viewer.GetNode (node_index);
485+ ASSERT_NE (node, nullptr );
486+ EXPECT_EQ (nodes_for_this_ep.count (node->Name ()), 1 );
487+ }
488+ }
489+ }
490+
491+ ASSERT_TRUE (std::filesystem::exists (log_file));
492+ EXPECT_NO_FATAL_FAILURE (CheckStringInFile (log_file, expected_log_string));
493+ };
494+
495+ constexpr std::array<const char *, 3 > node_names = {" add_0" , " mul_0" , " add_1" };
496+
497+ auto [ep, ort_ep] = test_plugin_ep::MakeTestOrtEp ();
498+
499+ // Load a model and assign all of its nodes to another EP named 'OtherEp'.
500+ // The plugin EP tries to claim all nodes in a single group via EpGraphSupportInfo_AddNodesToFuse.
501+ // IExecutionProvider::GetCapability() should return an empty result and log a warning.
502+ ort_ep->GetCapability = GetCapabilityTakeAllNodesOneGroup;
503+ std::unordered_set<std::string> nodes_for_other_ep = {" add_0" , " mul_0" , " add_1" };
504+ std::unordered_set<std::string> nodes_for_this_ep;
505+ run_test (*ep, nodes_for_other_ep, nodes_for_this_ep,
506+ " Found one or more nodes that were already assigned to a different EP named 'OtherEp'" );
507+
508+ // Load a model and assign only one node to another EP named 'OtherEp'.
509+ // The plugin EP tries to claim all nodes in a single group.
510+ // IExecutionProvider::GetCapability() should return an empty result and log a warning.
511+ ort_ep->GetCapability = GetCapabilityTakeAllNodesOneGroup;
512+ for (const char * node_name : node_names) {
513+ nodes_for_other_ep = std::unordered_set<std::string>{node_name};
514+ nodes_for_this_ep = std::unordered_set<std::string>{};
515+ run_test (*ep, nodes_for_other_ep, nodes_for_this_ep,
516+ " Found one or more nodes that were already assigned to a different EP named 'OtherEp'" );
517+ }
518+
519+ // Load a model and assign only the last Add node to another EP named 'OtherEp'.
520+ // The plugin EP tries to claim all nodes in the following 2 groups: (add_0), (mul_0, add_1).
521+ // IExecutionProvider::GetCapability() will only return (add_0) because the second group has a node
522+ // that was assigned to 'OtherEp'.
523+ ort_ep->GetCapability = GetCapabilityTakeAllNodesTwoGroups;
524+ nodes_for_other_ep = std::unordered_set<std::string>{" add_1" };
525+ nodes_for_this_ep = std::unordered_set<std::string>{" add_0" };
526+ run_test (*ep, nodes_for_other_ep, nodes_for_this_ep,
527+ " Found one or more nodes that were already assigned to a different EP named 'OtherEp'" );
528+
529+ // Load a model and assign only the first Add node to another EP named 'OtherEp'.
530+ // The plugin EP tries to claim all nodes in the following 2 groups: (add_0), (mul_0, add_1).
531+ // IExecutionProvider::GetCapability() will only return (mul_0, add_1) because the first group has a node
532+ // that was assigned to 'OtherEp'.
533+ ort_ep->GetCapability = GetCapabilityTakeAllNodesTwoGroups;
534+ nodes_for_other_ep = std::unordered_set<std::string>{" add_0" };
535+ nodes_for_this_ep = std::unordered_set<std::string>{" mul_0" , " add_1" };
536+ run_test (*ep, nodes_for_other_ep, nodes_for_this_ep,
537+ " Found one or more nodes that were already assigned to a different EP named 'OtherEp'" );
538+
539+ // Load a model and assign the first Add node to another EP named 'OtherEp'.
540+ // The plugin EP will try to take only the first Add node with a single call to EpGraphSupportInfo_AddSingleNode.
541+ // IExecutionProvider::GetCapability() will return an empty result and log a warning.
542+ ort_ep->GetCapability = GetCapabilityTakeSingleNode;
543+ nodes_for_other_ep = std::unordered_set<std::string>{" add_0" };
544+ nodes_for_this_ep = std::unordered_set<std::string>{};
545+ run_test (*ep, nodes_for_other_ep, nodes_for_this_ep,
546+ " Found one or more nodes that were already assigned to a different EP named 'OtherEp'" );
547+
548+ std::filesystem::remove (log_file);
549+ }
550+
320551} // namespace onnxruntime::test
0 commit comments