1313#include " core/framework/kernel_registry_manager.h"
1414#include " core/framework/kernel_registry.h"
1515#include " core/graph/function.h"
16+ #include " core/graph/function_utils.h"
1617#include " core/graph/graph_viewer.h"
18+ #include " core/graph/model.h"
1719
1820// uncomment this line to count non-CUDA ops in ONNX domain
1921// #define COUNT_NON_CUDA_OPS
@@ -129,6 +131,21 @@ struct GetCapabilityForEPParams {
129131 std::reference_wrapper<const layout_transformation::DebugGraphFn> debug_graph_fn;
130132#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
131133};
134+
135+ auto get_capabilities = [](const IExecutionProvider& ep,
136+ const GraphViewer& graph_viewer,
137+ const IExecutionProvider::IKernelLookup& kernel_lookup) {
138+ auto capabilities = ep.GetCapability (graph_viewer, kernel_lookup);
139+
140+ // In theory an EP could return an empty capability. Remove those.
141+ capabilities.erase (std::remove_if (capabilities.begin (), capabilities.end (),
142+ [](const std::unique_ptr<ComputeCapability>& capability) {
143+ return !capability || !capability->sub_graph ;
144+ }),
145+ capabilities.end ());
146+
147+ return capabilities;
148+ };
132149} // namespace
133150
134151static Status GetCapabilityForEP (const GetCapabilityForEPParams& params) {
@@ -143,21 +160,6 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params) {
143160 }
144161#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
145162
146- auto get_capabilities = [](const IExecutionProvider& ep,
147- const GraphViewer& graph_viewer,
148- const IExecutionProvider::IKernelLookup& kernel_lookup) {
149- auto capabilities = ep.GetCapability (graph_viewer, kernel_lookup);
150-
151- // In theory an EP could return an empty capability. Remove those.
152- capabilities.erase (std::remove_if (capabilities.begin (), capabilities.end (),
153- [](const std::unique_ptr<ComputeCapability>& capability) {
154- return !capability || !capability->sub_graph ;
155- }),
156- capabilities.end ());
157-
158- return capabilities;
159- };
160-
161163 const auto & kernel_registry_mgr = params.kernel_registry_mgr .get ();
162164 const auto kernel_registries_for_ep = kernel_registry_mgr.GetKernelRegistriesByProviderType (ep_type);
163165 const KernelLookup kernel_lookup{ep_type,
@@ -239,6 +241,26 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params) {
239241}
240242
241243#if !defined(ORT_MINIMAL_BUILD)
244+
245+ // This function queries the capabilities for a given EP, but it does not assign the nodes.
246+ // It also does not perform layout transformation. This will be done during normal partitioning.
247+ static Status GetCapabilityForEPForAotInlining (const GraphViewer& graph_viewer,
248+ const KernelRegistryManager& kernel_registry_mgr,
249+ const IExecutionProvider& current_ep,
250+ std::vector<std::unique_ptr<ComputeCapability>>& capabilities) {
251+ const auto & ep_type = current_ep.Type ();
252+
253+ const auto kernel_registries_for_ep = kernel_registry_mgr.GetKernelRegistriesByProviderType (ep_type);
254+ const KernelLookup kernel_lookup{ep_type,
255+ kernel_registries_for_ep,
256+ kernel_registry_mgr.GetKernelTypeStrResolver ()};
257+
258+ // TODO: Provide EP with a capability to look inside the functions.
259+ capabilities = get_capabilities (current_ep, graph_viewer, kernel_lookup);
260+
261+ return Status::OK ();
262+ }
263+
242264/* *
243265 * Check if a node can be placed on a specific provider.
244266 * Do nothing if the node is already assigned
@@ -518,7 +540,7 @@ static Status InlineNodes(Graph& graph, bool& modified_graph) {
518540 // successfully inlined, we re-run the partitioner on the modified graph.
519541 // NOTE: Inlining the function will change the nodes in the Graph instance, so we can't do that while iterating
520542 // using graph.Nodes().
521- std::vector <Node*> nodes_to_inline;
543+ InlinedVector <Node*> nodes_to_inline;
522544 for (auto & node : graph.Nodes ()) {
523545 if (node.GetExecutionProviderType ().empty () && node.CanBeInlined ()) {
524546 nodes_to_inline.push_back (&node);
@@ -533,6 +555,85 @@ static Status InlineNodes(Graph& graph, bool& modified_graph) {
533555 return Status::OK ();
534556}
535557
558+ static Status InlineFunctionsAOTImpl (const ExecutionProviders& execution_providers,
559+ const KernelRegistryManager& kernel_registry_mgr,
560+ Graph& graph,
561+ InlinedHashSet<std::string>& not_inlined,
562+ size_t & inlined_count) {
563+ // handle testing edge case where optimizers or constant lifting results in graph with no nodes.
564+ // doing it here saves all providers checking for this in GetCapability
565+ if (graph.NumberOfNodes () == 0 ) {
566+ return Status::OK ();
567+ }
568+
569+ for (auto & node : graph.Nodes ()) {
570+ for (auto & entry : node.GetAttributeNameToMutableSubgraphMap ()) {
571+ Graph* subgraph = entry.second ;
572+ // we pass through the FuncManager from the top level graph
573+ ORT_RETURN_IF_ERROR (InlineFunctionsAOTImpl (execution_providers,
574+ kernel_registry_mgr,
575+ *subgraph,
576+ not_inlined,
577+ inlined_count));
578+ }
579+ }
580+
581+ // Gather the candidates
582+ InlinedVector<NodeIndex> inline_candidates;
583+ for (auto & node : graph.Nodes ()) {
584+ if (node.CanBeInlined ()) {
585+ inline_candidates.push_back (node.Index ());
586+ }
587+ }
588+
589+ if (inline_candidates.empty ()) {
590+ return Status::OK ();
591+ }
592+
593+ // Find out all the nodes that are already taken
594+ const GraphViewer graph_viewer (graph);
595+
596+ InlinedHashSet<NodeIndex> claimed_by_ep;
597+ for (const auto & ep : execution_providers) {
598+ std::vector<std::unique_ptr<ComputeCapability>> capabilities;
599+ ORT_RETURN_IF_ERROR (GetCapabilityForEPForAotInlining (graph_viewer, kernel_registry_mgr, *ep, capabilities));
600+ for (auto & capability : capabilities) {
601+ const auto & nodes = capability->sub_graph ->nodes ;
602+ if (nodes.size () == 1 ) {
603+ // Single node capability.
604+ ORT_IGNORE_RETURN_VALUE (claimed_by_ep.insert (nodes[0 ]));
605+ } else {
606+ // Make sure none is claimed by other EPs mirroring the logic in PartitionOnnxFormatModelImpl.
607+ if (std::all_of (nodes.cbegin (), nodes.cend (), [&claimed_by_ep](NodeIndex node_index) {
608+ return claimed_by_ep.count (node_index) == 0 ;
609+ })) {
610+ claimed_by_ep.insert (nodes.cbegin (), nodes.cend ());
611+ }
612+ }
613+ }
614+ }
615+
616+ // TODO: Insert version check. We need to collect all the versions
617+ // that imported by the model. If the version is not supported by
618+ // the model, we can not inline it.
619+
620+ for (auto node_index : inline_candidates) {
621+ auto * node = graph.GetNode (node_index);
622+ if (node != nullptr ) {
623+ if (claimed_by_ep.count (node_index) == 0 ) {
624+ ORT_RETURN_IF_ERROR (graph.InlineFunction (*node));
625+ ++inlined_count;
626+ } else {
627+ // OpType is the same as function name.
628+ auto function_id = function_utils::GetFunctionIdentifier (node->Domain (), node->OpType ());
629+ ORT_IGNORE_RETURN_VALUE (not_inlined.insert (std::move (function_id)));
630+ }
631+ }
632+ }
633+
634+ return Status::OK ();
635+ }
636+
536637static Status PartitionOnnxFormatModel (const PartitionParams& partition_params, GraphPartitioner::Mode mode,
537638 const ExecutionProviders& execution_providers,
538639 KernelRegistryManager& kernel_registry_manager) {
@@ -693,6 +794,35 @@ static Status PartitionOrtFormatModel(const PartitionParams& partition_params,
693794 return Status::OK ();
694795}
695796
797+ #ifndef ORT_MINIMAL_BUILD
798+
799+ Status GraphPartitioner::InlineFunctionsAOT (Model& model,
800+ const ExecutionProviders& execution_providers,
801+ const KernelRegistryManager& kernel_registry_manager) const {
802+ auto & graph = model.MainGraph ();
803+ InlinedHashSet<std::string> not_inlined;
804+ do {
805+ size_t inlined_count = 0 ;
806+ ORT_RETURN_IF_ERROR (InlineFunctionsAOTImpl (execution_providers,
807+ kernel_registry_manager,
808+ graph,
809+ not_inlined,
810+ inlined_count));
811+
812+ if (inlined_count == 0 ) {
813+ break ;
814+ }
815+
816+ ORT_RETURN_IF_ERROR (graph.Resolve ());
817+ } while (true );
818+
819+ model.RemoveLocalFunctionsProtos (not_inlined);
820+
821+ return Status::OK ();
822+ }
823+
824+ #endif
825+
696826Status GraphPartitioner::Partition (Graph& graph, FuncManager& func_mgr,
697827 const layout_transformation::TransformLayoutFunction& transform_layout_function,
698828 Mode mode,
0 commit comments