Skip to content

Commit bcc7b9c

Browse files
[DML EP] Add dynamic graph compilation (#17876)
Historically, DML was only able to fuse partitions when all sizes are known in advance or when we were overriding them at session creation time. But in practice, it should be possible to compile partitions at compute time if the caller knows that the dimensions won't be changed for every inference (e.g. resizing a webcam window, or padding the input to powers of 2). This graph will be cached and reused until the sizes change. This is an opt-in option gated under the `enable_dynamic_graph_fusion` option, which means that it will only be enabled when the caller requests it since they have more context on how their model will be called between inferences. This PR also adds the option to disable metacommands from the python API, which is an option for the C API but was lacking for python.
1 parent 749bcc7 commit bcc7b9c

26 files changed

Lines changed: 1126 additions & 139 deletions

onnxruntime/core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33

44
#pragma once
55
interface IMLOperatorRegistry;
6+
interface IDMLDevice;
7+
interface ID3D12CommandQueue;
8+
interface ID3D12Resource;
69

710
#include "core/common/status.h"
811
#include "core/framework/data_transfer.h"
@@ -28,7 +31,8 @@ namespace Dml
2831
std::unique_ptr<onnxruntime::IExecutionProvider> CreateExecutionProvider(
2932
IDMLDevice* dmlDevice,
3033
ID3D12CommandQueue* commandQueue,
31-
bool enableMetacommands = true);
34+
bool enableMetacommands,
35+
bool enableDynamicGraphFusion);
3236

3337
ID3D12Resource* GetD3D12ResourceFromAllocation(onnxruntime::IAllocator* allocator, void* ptr);
3438
void FlushContext(onnxruntime::IExecutionProvider* provider);

onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,14 @@
77
#include <functional>
88
#include <variant>
99
#include <optional>
10+
#include <wrl/client.h>
1011

1112
#include "core/framework/op_kernel.h"
13+
#include "core/providers/dml/DmlExecutionProvider/src/DmlEdgeShapes.h"
1214

1315
struct AbstractOperatorDesc;
1416
interface IMLOperatorTensor;
17+
interface IDMLOperator;
1518
struct DML_INPUT_GRAPH_EDGE_DESC;
1619
struct DML_OUTPUT_GRAPH_EDGE_DESC;
1720
struct DML_INTERMEDIATE_GRAPH_EDGE_DESC;
@@ -92,6 +95,8 @@ namespace Windows::AI::MachineLearning::Adapter
9295
const onnxruntime::Node& node,
9396
MLOperatorTensorGetter& constantInputGetter,
9497
const void* executionHandle,
98+
const EdgeShapes* inputShapesOverrides,
99+
/*out*/ EdgeShapes* outputShapes,
95100
/*out*/ DmlGraphNodeCreateInfo* graphNodeCreateInfo
96101
)>;
97102

onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -491,22 +491,24 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel(
491491
const onnxruntime::Node& node,
492492
MLOperatorTensorGetter& constantInputGetter,
493493
const void* executionHandle,
494+
const EdgeShapes* inputShapesOverrides,
495+
/*out*/ EdgeShapes* outputShapes,
494496
/*out*/ DmlGraphNodeCreateInfo* graphNodeCreateInfo
495497
)
496498
{
497499
onnxruntime::ProtoHelperNodeContext nodeContext(node);
498500
onnxruntime::OpNodeProtoHelper<onnxruntime::ProtoHelperNodeContext> protoHelper(&nodeContext);
499501

500502
// Use the same list of required constant inputs for the shape inferrer and the kernel.
501-
EdgeShapes outputShapes;
502-
InferAndVerifyOutputSizes(node, &defaultAttributesCapture, shapeInferrerCapture.Get(), constantCpuInputCapture, constantInputGetter, nullptr, outputShapes);
503+
InferAndVerifyOutputSizes(node, &defaultAttributesCapture, shapeInferrerCapture.Get(), constantCpuInputCapture, constantInputGetter, inputShapesOverrides, *outputShapes);
503504

504505
// Create the kernel while allowing input shape and output shape queries according to options
505506
ComPtr<DmlGraphOpKernelInfoWrapper> kernelInfoWrapper = wil::MakeOrThrow<DmlGraphOpKernelInfoWrapper>(
506507
&protoHelper,
507508
executionHandle,
508509
true,
509-
&outputShapes,
510+
inputShapesOverrides,
511+
outputShapes,
510512
&defaultAttributesCapture,
511513
graphNodeCreateInfo,
512514
constantCpuInputCapture,
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
6+
namespace Windows::AI::MachineLearning::Adapter
7+
{
8+
// edges and unused edges have an empty array of dimensions.
9+
class EdgeShapes
10+
{
11+
public:
12+
EdgeShapes() = default;
13+
14+
EdgeShapes(size_t count) : m_shapes(count) {}
15+
16+
const std::vector<uint32_t>& GetShape(size_t edgeIndex) const
17+
{
18+
return m_shapes[edgeIndex];
19+
}
20+
21+
std::vector<uint32_t>& GetMutableShape(size_t edgeIndex)
22+
{
23+
return m_shapes[edgeIndex];
24+
}
25+
26+
size_t EdgeCount() const { return m_shapes.size(); }
27+
28+
void Reset(size_t edge_count)
29+
{
30+
m_shapes.clear();
31+
m_shapes.resize(edge_count);
32+
}
33+
34+
bool operator!=(const EdgeShapes& other) const noexcept
35+
{
36+
return (m_shapes != other.m_shapes);
37+
}
38+
39+
private:
40+
std::vector<std::vector<uint32_t>> m_shapes;
41+
};
42+
}

onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp

Lines changed: 167 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#pragma once
22

33
#include "DmlGraphFusionHelper.h"
4-
4+
#include "DmlRuntimeFusedGraphKernel.h"
55

66
namespace Dml
77
{
@@ -501,5 +501,171 @@ namespace DmlGraphFusionHelper
501501

502502
graph.FinalizeFuseSubGraph(indexedSubGraph, fusedNode);
503503
}
504+
505+
void RegisterDynamicKernel(
506+
onnxruntime::Graph& graph,
507+
onnxruntime::KernelRegistry* registryForPartitionKernels,
508+
const ExecutionProviderImpl* providerImpl,
509+
std::unordered_map<const onnxruntime::Node*, GraphNodeProperties> graphNodePropertyMap,
510+
const std::unordered_set<std::string>& dynamicCpuInputMap,
511+
std::shared_ptr<const onnxruntime::IndexedSubGraph> indexedSubGraph,
512+
std::unordered_map<std::string, std::pair<const ONNX_NAMESPACE::TensorProto*, bool>>&& isInitializerTransferable)
513+
{
514+
struct NodeInfo
515+
{
516+
std::string name;
517+
std::string opType;
518+
std::string description;
519+
std::string domain;
520+
onnxruntime::NodeAttributes attributes;
521+
std::vector<onnxruntime::NodeArg*> inputDefPointers;
522+
std::vector<onnxruntime::NodeArg*> outputDefPointers;
523+
};
524+
525+
auto partitionNodePropsMap = DmlGraphFusionHelper::CreatePartitionNodePropsMap(
526+
graph,
527+
*indexedSubGraph,
528+
std::move(graphNodePropertyMap));
529+
530+
auto modelPath = graph.ModelPath();
531+
532+
const gsl::span<const std::string> subGraphInputArgNames = indexedSubGraph->GetMetaDef()->inputs;
533+
const gsl::span<const std::string> subGraphOutputArgNames = indexedSubGraph->GetMetaDef()->outputs;
534+
535+
std::vector<NodeInfo> nodesInfo;
536+
nodesInfo.reserve(indexedSubGraph->nodes.size());
537+
538+
std::vector<const onnxruntime::NodeArg*> subgraphInputs;
539+
subgraphInputs.reserve(subGraphInputArgNames.size());
540+
541+
std::vector<const onnxruntime::NodeArg*> subgraphOutputs;
542+
subgraphOutputs.reserve(subGraphOutputArgNames.size());
543+
544+
std::vector<onnxruntime::NodeAttributes> nodeAttributes;
545+
nodeAttributes.reserve(indexedSubGraph->nodes.size());
546+
547+
std::vector<std::shared_ptr<onnxruntime::NodeArg>> intermediateNodeArgs;
548+
549+
for (size_t sortedNodeIndex : indexedSubGraph->nodes)
550+
{
551+
auto node = graph.GetNode(sortedNodeIndex);
552+
553+
nodeAttributes.push_back(node->GetAttributes());
554+
555+
NodeInfo nodeInfo{};
556+
nodeInfo.name = node->Name();
557+
nodeInfo.opType = node->OpType();
558+
nodeInfo.description = node->Description();
559+
nodeInfo.domain = node->Domain();
560+
nodeInfo.attributes = node->GetAttributes();
561+
nodeInfo.inputDefPointers.reserve(node->InputDefs().size());
562+
nodeInfo.outputDefPointers.reserve(node->OutputDefs().size());
563+
564+
for (const onnxruntime::NodeArg* inputDef : node->InputDefs())
565+
{
566+
intermediateNodeArgs.emplace_back(std::make_shared<onnxruntime::NodeArg>(inputDef->Name(), inputDef->TypeAsProto()));
567+
nodeInfo.inputDefPointers.push_back(intermediateNodeArgs.back().get());
568+
}
569+
570+
for (const onnxruntime::NodeArg* outputDef : node->OutputDefs())
571+
{
572+
intermediateNodeArgs.emplace_back(std::make_shared<onnxruntime::NodeArg>(outputDef->Name(), outputDef->TypeAsProto()));
573+
nodeInfo.outputDefPointers.push_back(intermediateNodeArgs.back().get());
574+
}
575+
576+
nodesInfo.push_back(std::move(nodeInfo));
577+
}
578+
579+
for (const std::string& graphInputName : subGraphInputArgNames)
580+
{
581+
subgraphInputs.push_back(graph.GetNodeArg(graphInputName));
582+
}
583+
584+
for (const std::string& graphOutputName : subGraphOutputArgNames)
585+
{
586+
subgraphOutputs.push_back(graph.GetNodeArg(graphOutputName));
587+
}
588+
589+
// We need to keep the initializers alive since they will be freed once the nodes are removed from the graph
590+
std::vector<ONNX_NAMESPACE::TensorProto> ownedInitializers;
591+
ownedInitializers.reserve(isInitializerTransferable.size());
592+
593+
for (auto& kvp : isInitializerTransferable)
594+
{
595+
ONNX_NAMESPACE::TensorProto tensorProto;
596+
tensorProto.set_data_type(kvp.second.first->data_type());
597+
tensorProto.set_raw_data(kvp.second.first->raw_data());
598+
tensorProto.set_name(kvp.second.first->name());
599+
600+
for (int i = 0; i < kvp.second.first->dims_size(); ++i)
601+
{
602+
tensorProto.add_dims(kvp.second.first->dims(i));
603+
}
604+
ownedInitializers.push_back(std::move(tensorProto));
605+
kvp.second.first = &ownedInitializers.back();
606+
}
607+
608+
// lamda captures for the kernel registration
609+
auto fused_kernel_func = [
610+
indexedSubGraph,
611+
&modelPath,
612+
nodesInfo = std::move(nodesInfo),
613+
intermediateNodeArgs = std::move(intermediateNodeArgs),
614+
subgraphInputs = std::move(subgraphInputs),
615+
subgraphOutputs = std::move(subgraphOutputs),
616+
partitionNodePropsMap = std::move(partitionNodePropsMap),
617+
ownedInitializers = std::move(ownedInitializers)] (onnxruntime::FuncManager& func_mgr, const onnxruntime::OpKernelInfo& info, std::unique_ptr<onnxruntime::OpKernel>& out) mutable ->onnxruntime::Status
618+
{
619+
std::vector<std::shared_ptr<onnxruntime::Node>> subgraphNodes;
620+
subgraphNodes.reserve(nodesInfo.size());
621+
622+
for (const NodeInfo& nodeInfo : nodesInfo)
623+
{
624+
subgraphNodes.emplace_back(std::make_shared<onnxruntime::Node>(
625+
nodeInfo.name,
626+
nodeInfo.opType,
627+
nodeInfo.description,
628+
nodeInfo.inputDefPointers,
629+
nodeInfo.outputDefPointers,
630+
&nodeInfo.attributes,
631+
nodeInfo.domain));
632+
}
633+
634+
out.reset(CreateRuntimeFusedGraphKernel(
635+
info,
636+
indexedSubGraph,
637+
modelPath,
638+
std::move(subgraphNodes),
639+
std::move(subgraphInputs),
640+
std::move(subgraphOutputs),
641+
std::move(intermediateNodeArgs),
642+
std::move(partitionNodePropsMap),
643+
std::move(ownedInitializers)));
644+
return Status::OK();
645+
};
646+
647+
// build the kernel definition on the fly, and register it to the fused_kernel_regisitry.
648+
onnxruntime::KernelDefBuilder builder;
649+
builder.SetName(indexedSubGraph->GetMetaDef()->name)
650+
.SetDomain(indexedSubGraph->GetMetaDef()->domain)
651+
.SinceVersion(indexedSubGraph->GetMetaDef()->since_version)
652+
.Provider(onnxruntime::kDmlExecutionProvider);
653+
654+
// Force the CPU inputs to be allocated on the CPU
655+
for (int i = 0; i < subGraphInputArgNames.size(); ++i)
656+
{
657+
if (dynamicCpuInputMap.find(subGraphInputArgNames[i]) != dynamicCpuInputMap.end())
658+
{
659+
builder.InputMemoryType(OrtMemTypeCPUInput, i);
660+
}
661+
}
662+
663+
ORT_THROW_IF_ERROR(registryForPartitionKernels->Register(builder, fused_kernel_func));
664+
665+
auto& fusedNode = graph.BeginFuseSubGraph(*indexedSubGraph, indexedSubGraph->GetMetaDef()->name);
666+
fusedNode.SetExecutionProviderType(onnxruntime::kDmlExecutionProvider);
667+
668+
graph.FinalizeFuseSubGraph(*indexedSubGraph, fusedNode);
669+
}
504670
}
505671
}

onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,5 +80,14 @@ namespace DmlGraphFusionHelper
8080
std::vector<uint8_t>&& isInputsUploadedByDmlEP,
8181
const GraphDescBuilder::GraphDesc& graphDesc,
8282
Microsoft::WRL::ComPtr<IDMLCompiledOperator> compiledExecutionPlanOperator);
83+
84+
void RegisterDynamicKernel(
85+
onnxruntime::Graph& graph,
86+
onnxruntime::KernelRegistry* registryForPartitionKernels,
87+
const ExecutionProviderImpl* providerImpl,
88+
std::unordered_map<const onnxruntime::Node*, GraphNodeProperties> graphNodePropertyMap,
89+
const std::unordered_set<std::string>& dynamicCpuInputMap,
90+
std::shared_ptr<const onnxruntime::IndexedSubGraph> indexedSubGraph,
91+
std::unordered_map<std::string, std::pair<const ONNX_NAMESPACE::TensorProto*, bool>>&& isInitializerTransferable);
8392
}
8493
}

0 commit comments

Comments
 (0)