Skip to content

Commit 2c50b75

Browse files
authored
Functions Ahead Of Time inlininng (#17764)
### Description Inline functions in an EP aware fashion. The result of this PR is that models that are having been inlined by ONNX inliner and optimized and models that have been AOT inlined appear to be visually identical. For tests I used two models. The only difference is the resulting size because ONNX inliner removes local function definitions and AOT does not. Difference in sizes for `HF Mobile` model was 2.5 MB, and for `HF Bart` it was ~500K. It seems that the resuling model size affects the load time more than the actual optimizations. In general, the inlined models grow in size very fast and can easily exceed 2Gb limit. Q. Should we make AOT optional? `If` costant folding and the removal of local inlined models will be coming in other PRs. Some stats: ![image](https://github.com/microsoft/onnxruntime/assets/11303988/fcb4c815-2e06-4574-8d96-5a0a727d1ecf)
1 parent f3cfe08 commit 2c50b75

11 files changed

Lines changed: 450 additions & 138 deletions

File tree

include/onnxruntime/core/graph/graph.h

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#pragma once
55

6+
#include <functional>
67
#include <limits>
78
#include <memory>
89
#include <string>
@@ -83,10 +84,10 @@ class Node {
8384
gsl::span<NodeArg* const> output_args,
8485
const NodeAttributes* attributes,
8586
std::string_view domain) {
86-
Init(std::string{name}, std::string{op_type}, std::string{description},
87-
std::vector<NodeArg*>{input_args.begin(), input_args.end()},
88-
std::vector<NodeArg*>{output_args.begin(), output_args.end()},
89-
attributes, std::string{domain});
87+
Init(name, op_type, description,
88+
input_args,
89+
output_args,
90+
attributes, domain);
9091
}
9192
#endif
9293

@@ -563,13 +564,13 @@ class Node {
563564
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Node);
564565

565566
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
566-
void Init(const std::string& name,
567-
const std::string& op_type,
568-
const std::string& description,
569-
const std::vector<NodeArg*>& input_args,
570-
const std::vector<NodeArg*>& output_args,
567+
void Init(std::string_view name,
568+
std::string_view op_type,
569+
std::string_view description,
570+
gsl::span<NodeArg* const> input_args,
571+
gsl::span<NodeArg* const> output_args,
571572
const NodeAttributes* attributes,
572-
const std::string& domain);
573+
std::string_view domain);
573574
#endif
574575

575576
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
@@ -1141,8 +1142,22 @@ class Graph {
11411142
*/
11421143
Status InlineFunction(Node& node);
11431144

1145+
/**
1146+
Directly insert the nodes in the function proto provided into the graph.
1147+
The function converts Constant nodes into the initializers in the graph.
1148+
It then creates a node in the graph for each of the function nodes.
1149+
All of the names are expected to be specialized, and, therefore unique.
1150+
See function_utils::Specialize().
1151+
1152+
The Graph needs to be Resolve()d after this call.
1153+
@param func_to_inline
1154+
@returns Status indicating success or providing an error message.
1155+
*/
1156+
1157+
Status InlineFunctionProto(const ONNX_NAMESPACE::FunctionProto& func_to_inline);
1158+
11441159
/** Mark a NodeArg name as coming from the outer scope when programmatically constructing a Graph that will
1145-
be used as a GraphProto attribute in another Node..
1160+
be used as a GraphProto attribute in another Node.
11461161
e.g. when creating a Graph instance that will be used as a subgraph in a control flow operator, it is necessary to
11471162
define placeholder NodeArgs for outer scope values. This prevents these values from becoming explicit graph inputs
11481163
when the Graph is resolved.
@@ -1391,6 +1406,13 @@ class Graph {
13911406
Node& AddNode(const ONNX_NAMESPACE::NodeProto& node_proto,
13921407
const ArgNameToTypeMap& name_to_type);
13931408

1409+
/** Helper that converts and adds constant node proto to an initializer in the graph.
1410+
@param constant_node_proto Constant node to convert
1411+
@param new_name use the new name for the initializer.
1412+
*/
1413+
Status AddConstantProtoAsInitializer(const ONNX_NAMESPACE::NodeProto& constant_node_proto,
1414+
std::optional<std::string_view> new_name);
1415+
13941416
#endif
13951417

13961418
Version IrVersion() const noexcept {

include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,16 @@ static const char* const kOrtSessionOptionsEnableQuantQDQCleanup = "session.enab
6767
// GeluApproximation has side effects which may change the inference results. It is disabled by default due to this.
6868
static const char* const kOrtSessionOptionsEnableGeluApproximation = "optimization.enable_gelu_approximation";
6969

70+
// This setting controls whether to enable AheadOfTime function inlining.
71+
// AOT function inlining examines the graph and attempts to inline as many locally defined functions in the model
72+
// as possible with the help of enabled execution providers.
73+
// This can reduce the number of function calls and improve performance because it is done before
74+
// Level1 optimizers and constant folding. However, under some circumstances, when the EPs are not available,
75+
// one can disable the AOT inlining, produce an optimized model and postpone AOT until run time.
76+
// "0": enable; "1": disable.
77+
// Its default value is "0".
78+
static const char* const kOrtSessionOptionsDisableAheadOfTimeFunctionInlining = "session.disable_aot_function_inlining";
79+
7080
#ifdef ENABLE_TRAINING
7181
// Specifies a list of op types for memory footprint reduction.
7282
// The value should be a ","-delimited list of pair of

onnxruntime/core/framework/graph_partitioner.cc

Lines changed: 146 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313
#include "core/framework/kernel_registry_manager.h"
1414
#include "core/framework/kernel_registry.h"
1515
#include "core/graph/function.h"
16+
#include "core/graph/function_utils.h"
1617
#include "core/graph/graph_viewer.h"
18+
#include "core/graph/model.h"
1719

1820
// uncomment this line to count non-CUDA ops in ONNX domain
1921
// #define COUNT_NON_CUDA_OPS
@@ -129,6 +131,21 @@ struct GetCapabilityForEPParams {
129131
std::reference_wrapper<const layout_transformation::DebugGraphFn> debug_graph_fn;
130132
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
131133
};
134+
135+
auto get_capabilities = [](const IExecutionProvider& ep,
136+
const GraphViewer& graph_viewer,
137+
const IExecutionProvider::IKernelLookup& kernel_lookup) {
138+
auto capabilities = ep.GetCapability(graph_viewer, kernel_lookup);
139+
140+
// In theory an EP could return an empty capability. Remove those.
141+
capabilities.erase(std::remove_if(capabilities.begin(), capabilities.end(),
142+
[](const std::unique_ptr<ComputeCapability>& capability) {
143+
return !capability || !capability->sub_graph;
144+
}),
145+
capabilities.end());
146+
147+
return capabilities;
148+
};
132149
} // namespace
133150

134151
static Status GetCapabilityForEP(const GetCapabilityForEPParams& params) {
@@ -143,21 +160,6 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params) {
143160
}
144161
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
145162

146-
auto get_capabilities = [](const IExecutionProvider& ep,
147-
const GraphViewer& graph_viewer,
148-
const IExecutionProvider::IKernelLookup& kernel_lookup) {
149-
auto capabilities = ep.GetCapability(graph_viewer, kernel_lookup);
150-
151-
// In theory an EP could return an empty capability. Remove those.
152-
capabilities.erase(std::remove_if(capabilities.begin(), capabilities.end(),
153-
[](const std::unique_ptr<ComputeCapability>& capability) {
154-
return !capability || !capability->sub_graph;
155-
}),
156-
capabilities.end());
157-
158-
return capabilities;
159-
};
160-
161163
const auto& kernel_registry_mgr = params.kernel_registry_mgr.get();
162164
const auto kernel_registries_for_ep = kernel_registry_mgr.GetKernelRegistriesByProviderType(ep_type);
163165
const KernelLookup kernel_lookup{ep_type,
@@ -239,6 +241,26 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params) {
239241
}
240242

241243
#if !defined(ORT_MINIMAL_BUILD)
244+
245+
// This function queries the capabilities for a given EP, but it does not assign the nodes.
246+
// It also does not perform layout transformation. This will be done during normal partitioning.
247+
static Status GetCapabilityForEPForAotInlining(const GraphViewer& graph_viewer,
248+
const KernelRegistryManager& kernel_registry_mgr,
249+
const IExecutionProvider& current_ep,
250+
std::vector<std::unique_ptr<ComputeCapability>>& capabilities) {
251+
const auto& ep_type = current_ep.Type();
252+
253+
const auto kernel_registries_for_ep = kernel_registry_mgr.GetKernelRegistriesByProviderType(ep_type);
254+
const KernelLookup kernel_lookup{ep_type,
255+
kernel_registries_for_ep,
256+
kernel_registry_mgr.GetKernelTypeStrResolver()};
257+
258+
// TODO: Provide EP with a capability to look inside the functions.
259+
capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup);
260+
261+
return Status::OK();
262+
}
263+
242264
/**
243265
* Check if a node can be placed on a specific provider.
244266
* Do nothing if the node is already assigned
@@ -518,7 +540,7 @@ static Status InlineNodes(Graph& graph, bool& modified_graph) {
518540
// successfully inlined, we re-run the partitioner on the modified graph.
519541
// NOTE: Inlining the function will change the nodes in the Graph instance, so we can't do that while iterating
520542
// using graph.Nodes().
521-
std::vector<Node*> nodes_to_inline;
543+
InlinedVector<Node*> nodes_to_inline;
522544
for (auto& node : graph.Nodes()) {
523545
if (node.GetExecutionProviderType().empty() && node.CanBeInlined()) {
524546
nodes_to_inline.push_back(&node);
@@ -533,6 +555,85 @@ static Status InlineNodes(Graph& graph, bool& modified_graph) {
533555
return Status::OK();
534556
}
535557

558+
static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_providers,
559+
const KernelRegistryManager& kernel_registry_mgr,
560+
Graph& graph,
561+
InlinedHashSet<std::string>& not_inlined,
562+
size_t& inlined_count) {
563+
// handle testing edge case where optimizers or constant lifting results in graph with no nodes.
564+
// doing it here saves all providers checking for this in GetCapability
565+
if (graph.NumberOfNodes() == 0) {
566+
return Status::OK();
567+
}
568+
569+
for (auto& node : graph.Nodes()) {
570+
for (auto& entry : node.GetAttributeNameToMutableSubgraphMap()) {
571+
Graph* subgraph = entry.second;
572+
// we pass through the FuncManager from the top level graph
573+
ORT_RETURN_IF_ERROR(InlineFunctionsAOTImpl(execution_providers,
574+
kernel_registry_mgr,
575+
*subgraph,
576+
not_inlined,
577+
inlined_count));
578+
}
579+
}
580+
581+
// Gather the candidates
582+
InlinedVector<NodeIndex> inline_candidates;
583+
for (auto& node : graph.Nodes()) {
584+
if (node.CanBeInlined()) {
585+
inline_candidates.push_back(node.Index());
586+
}
587+
}
588+
589+
if (inline_candidates.empty()) {
590+
return Status::OK();
591+
}
592+
593+
// Find out all the nodes that are already taken
594+
const GraphViewer graph_viewer(graph);
595+
596+
InlinedHashSet<NodeIndex> claimed_by_ep;
597+
for (const auto& ep : execution_providers) {
598+
std::vector<std::unique_ptr<ComputeCapability>> capabilities;
599+
ORT_RETURN_IF_ERROR(GetCapabilityForEPForAotInlining(graph_viewer, kernel_registry_mgr, *ep, capabilities));
600+
for (auto& capability : capabilities) {
601+
const auto& nodes = capability->sub_graph->nodes;
602+
if (nodes.size() == 1) {
603+
// Single node capability.
604+
ORT_IGNORE_RETURN_VALUE(claimed_by_ep.insert(nodes[0]));
605+
} else {
606+
// Make sure none is claimed by other EPs mirroring the logic in PartitionOnnxFormatModelImpl.
607+
if (std::all_of(nodes.cbegin(), nodes.cend(), [&claimed_by_ep](NodeIndex node_index) {
608+
return claimed_by_ep.count(node_index) == 0;
609+
})) {
610+
claimed_by_ep.insert(nodes.cbegin(), nodes.cend());
611+
}
612+
}
613+
}
614+
}
615+
616+
// TODO: Insert version check. We need to collect all the versions
617+
// that imported by the model. If the version is not supported by
618+
// the model, we can not inline it.
619+
620+
for (auto node_index : inline_candidates) {
621+
auto* node = graph.GetNode(node_index);
622+
if (node != nullptr) {
623+
if (claimed_by_ep.count(node_index) == 0) {
624+
ORT_RETURN_IF_ERROR(graph.InlineFunction(*node));
625+
++inlined_count;
626+
} else {
627+
// OpType is the same as function name.
628+
auto function_id = function_utils::GetFunctionIdentifier(node->Domain(), node->OpType());
629+
ORT_IGNORE_RETURN_VALUE(not_inlined.insert(std::move(function_id)));
630+
}
631+
}
632+
}
633+
634+
return Status::OK();
635+
}
636+
536637
static Status PartitionOnnxFormatModel(const PartitionParams& partition_params, GraphPartitioner::Mode mode,
537638
const ExecutionProviders& execution_providers,
538639
KernelRegistryManager& kernel_registry_manager) {
@@ -693,6 +794,35 @@ static Status PartitionOrtFormatModel(const PartitionParams& partition_params,
693794
return Status::OK();
694795
}
695796

797+
#ifndef ORT_MINIMAL_BUILD
798+
799+
Status GraphPartitioner::InlineFunctionsAOT(Model& model,
800+
const ExecutionProviders& execution_providers,
801+
const KernelRegistryManager& kernel_registry_manager) const {
802+
auto& graph = model.MainGraph();
803+
InlinedHashSet<std::string> not_inlined;
804+
do {
805+
size_t inlined_count = 0;
806+
ORT_RETURN_IF_ERROR(InlineFunctionsAOTImpl(execution_providers,
807+
kernel_registry_manager,
808+
graph,
809+
not_inlined,
810+
inlined_count));
811+
812+
if (inlined_count == 0) {
813+
break;
814+
}
815+
816+
ORT_RETURN_IF_ERROR(graph.Resolve());
817+
} while (true);
818+
819+
model.RemoveLocalFunctionsProtos(not_inlined);
820+
821+
return Status::OK();
822+
}
823+
824+
#endif
825+
696826
Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr,
697827
const layout_transformation::TransformLayoutFunction& transform_layout_function,
698828
Mode mode,

onnxruntime/core/framework/graph_partitioner.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ namespace onnxruntime {
1212

1313
class ExecutionProviders;
1414
class KernelRegistryManager;
15+
class Model;
1516

1617
class GraphPartitioner {
1718
public:
@@ -33,6 +34,26 @@ class GraphPartitioner {
3334
Mode mode = Mode::kNormal,
3435
const layout_transformation::DebugGraphFn& debug_graph_fn = {}) const;
3536

37+
#ifndef ORT_MINIMAL_BUILD
38+
/// <summary>
39+
// Ahead of Time Function inlining. The main purpose of the function is to inline as many
40+
// functions as possible and delete locally defined functions to reduce the size of the model.
41+
// This would make other optimizations to be more effective.
42+
//
43+
// This function performs GetCapability on the graph and its subgraphs bottom up
44+
// and inlines any functions that are not claimed by any of the execution providers.
45+
// This function does not attempt to run layout transformation, and it does not assign EPs.
46+
// The latter will be done by graph partitioning after Level1 optimizations are done.
47+
/// </summary>
48+
/// <param name="model">model instance</param>
49+
/// <param name="execution_providers">execution providers considered</param>
50+
/// <param name="kernel_registry_manager">registry manager</param>
51+
/// <returns></returns>
52+
Status InlineFunctionsAOT(Model& model,
53+
const ExecutionProviders& execution_providers,
54+
const KernelRegistryManager& kernel_registry_manager) const;
55+
#endif
56+
3657
private:
3758
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(GraphPartitioner);
3859

onnxruntime/core/graph/function_utils.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -373,7 +373,8 @@ class Inliner {
373373
// Replace given name with a unique version of the name, and cache the
374374
// renaming-binding in current scope.
375375
void make_unique(std::string& name) {
376-
auto new_name = prefix_ + name;
376+
auto new_name{prefix_};
377+
new_name.append("_").append(name);
377378
auto& current_scope = rename_scopes_.back();
378379
current_scope[name] = new_name;
379380
name = std::move(new_name);
@@ -410,7 +411,7 @@ class Inliner {
410411
std::string rename_as = actuals.Get(i);
411412
if constexpr (isOutput) {
412413
if (rename_as.empty())
413-
rename_as.assign(prefix_).append(formal);
414+
rename_as.assign(prefix_).append("_").append(formal);
414415
}
415416
current_scope[formal] = rename_as;
416417
if (!rename_as.empty())
@@ -420,7 +421,7 @@ class Inliner {
420421
std::string& formal = *formals.Mutable(i);
421422
std::string rename_as;
422423
if constexpr (isOutput) {
423-
rename_as.assign(prefix_).append(formal);
424+
rename_as.assign(prefix_).append("_").append(formal);
424425
}
425426
current_scope[formal] = rename_as;
426427
if (!rename_as.empty())

0 commit comments

Comments
 (0)