
[DML EP] Add dynamic graph compilation #17876


Merged on Oct 26, 2023 (23 commits)

The changes shown below are from 21 of the 23 commits.

Commits
3dd6288  WIP (PatriceVignola, Oct 1, 2023)
a06548a  WIP (PatriceVignola, Oct 3, 2023)
eb33dc3  WIP (PatriceVignola, Oct 6, 2023)
32a2b95  Working implementation (PatriceVignola, Oct 6, 2023)
d69d399  Working version with good performance (no queue reuse) (PatriceVignola, Oct 6, 2023)
0252c13  Disable metacommands (PatriceVignola, Oct 7, 2023)
779197d  Add options to disable metacommands and enable dynamic graph fusions (PatriceVignola, Oct 7, 2023)
b809a5b  Remove unused variables (PatriceVignola, Oct 7, 2023)
bc322c1  Add check in case CPU inputs changed (PatriceVignola, Oct 7, 2023)
135d467  Merge branch 'main' of https://github.com/microsoft/onnxruntime into … (PatriceVignola, Oct 7, 2023)
f850a2d  Refactor (PatriceVignola, Oct 7, 2023)
339abde  More refactoring (PatriceVignola, Oct 7, 2023)
ebfe95f  Fix crash when empty shapes (PatriceVignola, Oct 10, 2023)
e5fa87e  Uncomment assert (PatriceVignola, Oct 10, 2023)
461ffd3  Merge branch 'main' of https://github.com/microsoft/onnxruntime into … (PatriceVignola, Oct 13, 2023)
5dc0aaf  Address PR comments (PatriceVignola, Oct 16, 2023)
2ac62ab  Address PR comments (PatriceVignola, Oct 19, 2023)
ba27e0d  Revert unneeded change (PatriceVignola, Oct 19, 2023)
b99184d  Merge branch 'main' of https://github.com/microsoft/onnxruntime into … (PatriceVignola, Oct 19, 2023)
95b4779  Small fix (PatriceVignola, Oct 19, 2023)
3ccfd06  Address PR comments (PatriceVignola, Oct 19, 2023)
1f8d9ea  Address PR comments (PatriceVignola, Oct 23, 2023)
c5c844c  Address PR comments (PatriceVignola, Oct 23, 2023)

@@ -128,7 +128,7 @@ struct OrtDmlApi {
/**
* SessionOptionsAppendExecutionProvider_DML2
* Creates a DirectML Execution Provider given the supplied device options that contain a performance preference
- * (high power, low power, or defult) and a device filter (None, GPU, or NPU).
+ * (high power, low power, or default) and a device filter (None, GPU, or NPU).
*/
ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_DML2, _In_ OrtSessionOptions* options, OrtDmlDeviceOptions* device_opts);
};
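For orientation, here is a minimal sketch of appending the DML EP through this API from the ORT C API surface. This is not code from the PR; the OrtDmlDeviceOptions field and enum names (Preference, Filter, Default, Gpu) are assumptions based on dml_provider_factory.h, and OrtStatus checking is omitted:

    #include <onnxruntime_c_api.h>
    #include <dml_provider_factory.h>

    void AppendDmlProvider(OrtSessionOptions* session_options)
    {
        const OrtApi* ort = OrtGetApiBase()->GetApi(ORT_API_VERSION);

        // Fetch the DML-specific API table from the core C API.
        const OrtDmlApi* dml_api = nullptr;
        ort->GetExecutionProviderApi("DML", ORT_API_VERSION,
                                     reinterpret_cast<const void**>(&dml_api));

        // Assumed field/enum names: a performance preference and a device filter.
        OrtDmlDeviceOptions device_opts{};
        device_opts.Preference = Default;
        device_opts.Filter = Gpu;
        dml_api->SessionOptionsAppendExecutionProvider_DML2(session_options, &device_opts);
    }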

@@ -3,6 +3,9 @@

#pragma once
interface IMLOperatorRegistry;
+interface IDMLDevice;
+interface ID3D12CommandQueue;
+interface ID3D12Resource;

#include "core/common/status.h"
#include "core/framework/data_transfer.h"

@@ -28,7 +31,8 @@ namespace Dml
std::unique_ptr<onnxruntime::IExecutionProvider> CreateExecutionProvider(
IDMLDevice* dmlDevice,
ID3D12CommandQueue* commandQueue,
- bool enableMetacommands = true);
+ bool enableMetacommands,
+ bool enableDynamicGraphFusion);

ID3D12Resource* GetD3D12ResourceFromAllocation(onnxruntime::IAllocator* allocator, void* ptr);
void FlushContext(onnxruntime::IExecutionProvider* provider);
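A sketch of a call site for the updated signature. The D3D12/DML device setup below is standard boilerplate rather than code from this PR; it assumes the internal header declaring Dml::CreateExecutionProvider is available, and error handling is omitted:

    #include <d3d12.h>
    #include <DirectML.h>
    #include <wrl/client.h>
    #include <memory>

    using Microsoft::WRL::ComPtr;

    std::unique_ptr<onnxruntime::IExecutionProvider> MakeDmlProvider()
    {
        ComPtr<ID3D12Device> d3dDevice;
        D3D12CreateDevice(nullptr, D3D_FEATURE_LEVEL_11_0, IID_PPV_ARGS(&d3dDevice));

        D3D12_COMMAND_QUEUE_DESC queueDesc = {};
        queueDesc.Type = D3D12_COMMAND_LIST_TYPE_DIRECT;
        ComPtr<ID3D12CommandQueue> queue;
        d3dDevice->CreateCommandQueue(&queueDesc, IID_PPV_ARGS(&queue));

        ComPtr<IDMLDevice> dmlDevice;
        DMLCreateDevice(d3dDevice.Get(), DML_CREATE_DEVICE_FLAG_NONE, IID_PPV_ARGS(&dmlDevice));

        // Both flags are now explicit; the second opts in to the dynamic graph
        // fusion added by this PR.
        return Dml::CreateExecutionProvider(
            dmlDevice.Get(),
            queue.Get(),
            /*enableMetacommands*/ true,
            /*enableDynamicGraphFusion*/ true);
    }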

@@ -7,11 +7,14 @@
#include <functional>
#include <variant>
#include <optional>
+#include <wrl/client.h>

#include "core/framework/op_kernel.h"
#include "core/providers/dml/DmlExecutionProvider/src/DmlEdgeShapes.h"

struct AbstractOperatorDesc;
interface IMLOperatorTensor;
+interface IDMLOperator;
struct DML_INPUT_GRAPH_EDGE_DESC;
struct DML_OUTPUT_GRAPH_EDGE_DESC;
struct DML_INTERMEDIATE_GRAPH_EDGE_DESC;

@@ -92,6 +95,8 @@ namespace Windows::AI::MachineLearning::Adapter
const onnxruntime::Node& node,
MLOperatorTensorGetter& constantInputGetter,
const void* executionHandle,
+ const EdgeShapes* inputShapesOverrides,
+ /*out*/ EdgeShapes* outputShapes,
/*out*/ DmlGraphNodeCreateInfo* graphNodeCreateInfo
)>;


@@ -491,22 +491,24 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel(
const onnxruntime::Node& node,
MLOperatorTensorGetter& constantInputGetter,
const void* executionHandle,
+ const EdgeShapes* inputShapesOverrides,
+ /*out*/ EdgeShapes* outputShapes,
/*out*/ DmlGraphNodeCreateInfo* graphNodeCreateInfo
)
{
onnxruntime::ProtoHelperNodeContext nodeContext(node);
onnxruntime::OpNodeProtoHelper<onnxruntime::ProtoHelperNodeContext> protoHelper(&nodeContext);

// Use the same list of required constant inputs for the shape inferrer and the kernel.
- EdgeShapes outputShapes;
- InferAndVerifyOutputSizes(node, &defaultAttributesCapture, shapeInferrerCapture.Get(), constantCpuInputCapture, constantInputGetter, nullptr, outputShapes);
+ InferAndVerifyOutputSizes(node, &defaultAttributesCapture, shapeInferrerCapture.Get(), constantCpuInputCapture, constantInputGetter, inputShapesOverrides, *outputShapes);

// Create the kernel while allowing input shape and output shape queries according to options
ComPtr<DmlGraphOpKernelInfoWrapper> kernelInfoWrapper = wil::MakeOrThrow<DmlGraphOpKernelInfoWrapper>(
&protoHelper,
executionHandle,
true,
- &outputShapes,
+ inputShapesOverrides,
+ outputShapes,
&defaultAttributesCapture,
graphNodeCreateInfo,
constantCpuInputCapture,
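A note on the effect of this change: output shapes are no longer inferred once at kernel-creation time from the graph's static shapes. The caller now supplies the input shapes actually observed at execution time (inputShapesOverrides) and receives the inferred shapes back through the outputShapes out-parameter, which is what allows a fused graph to be re-evaluated when dynamically shaped inputs change.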

New file: core/providers/dml/DmlExecutionProvider/src/DmlEdgeShapes.h

@@ -0,0 +1,42 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

namespace Windows::AI::MachineLearning::Adapter
{
// Encapsulation of shapes across different edges of an operator. Non-tensor
// edges and unused edges have an empty array of dimensions.
class EdgeShapes
{
public:
EdgeShapes() = default;

EdgeShapes(size_t count) : m_shapes(count) {}

const std::vector<uint32_t>& GetShape(size_t edgeIndex) const
{
return m_shapes[edgeIndex];
}

std::vector<uint32_t>& GetMutableShape(size_t edgeIndex)
{
return m_shapes[edgeIndex];
}

size_t EdgeCount() const { return m_shapes.size(); }

void Reset(size_t edge_count)
{
m_shapes.clear();
m_shapes.resize(edge_count);
}

bool operator!=(const EdgeShapes& other) const noexcept
{
return (m_shapes != other.m_shapes);
}

private:
std::vector<std::vector<uint32_t>> m_shapes;
};
}
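Since EdgeShapes is introduced by this PR, a short usage sketch follows (assuming the in-tree include path; the shape values are illustrative only):

    #include <cstdint>
    #include <vector>
    #include "core/providers/dml/DmlExecutionProvider/src/DmlEdgeShapes.h"

    using Windows::AI::MachineLearning::Adapter::EdgeShapes;

    void Example()
    {
        // Shapes observed on a node's two input edges during the previous run.
        EdgeShapes previous(2);
        previous.GetMutableShape(0) = {1, 3, 224, 224};  // a 4-D tensor edge
        // Edge 1 stays empty: non-tensor and unused edges keep an empty dims array.

        // Shapes observed on the current run.
        EdgeShapes current(2);
        current.GetMutableShape(0) = {1, 3, 448, 448};  // the dynamic dims changed

        if (current != previous)
        {
            // Shapes changed since the last execution: recompile the fused graph.
        }

        current.Reset(3);  // clear all shapes and resize to three edges
    }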

@@ -1,7 +1,7 @@
#pragma once

#include "DmlGraphFusionHelper.h"

#include "DmlRuntimeFusedGraphKernel.h"

namespace Dml
{

@@ -501,5 +501,171 @@ namespace DmlGraphFusionHelper

graph.FinalizeFuseSubGraph(indexedSubGraph, fusedNode);
}

void RegisterDynamicKernel(
onnxruntime::Graph& graph,
onnxruntime::KernelRegistry* registryForPartitionKernels,
const ExecutionProviderImpl* providerImpl,
std::unordered_map<const onnxruntime::Node*, GraphNodeProperties> graphNodePropertyMap,
const std::unordered_set<std::string>& dynamicCpuInputMap,
std::shared_ptr<const onnxruntime::IndexedSubGraph> indexedSubGraph,
std::unordered_map<std::string, std::pair<const ONNX_NAMESPACE::TensorProto*, bool>>&& isInitializerTransferable)
{
struct NodeInfo
{
std::string name;
std::string opType;
std::string description;
std::string domain;
onnxruntime::NodeAttributes attributes;
std::vector<onnxruntime::NodeArg*> inputDefPointers;
std::vector<onnxruntime::NodeArg*> outputDefPointers;
};

auto partitionNodePropsMap = DmlGraphFusionHelper::CreatePartitionNodePropsMap(
graph,
*indexedSubGraph,
std::move(graphNodePropertyMap));

auto modelPath = graph.ModelPath();

const gsl::span<const std::string> subGraphInputArgNames = indexedSubGraph->GetMetaDef()->inputs;
const gsl::span<const std::string> subGraphOutputArgNames = indexedSubGraph->GetMetaDef()->outputs;

std::vector<NodeInfo> nodesInfo;
nodesInfo.reserve(indexedSubGraph->nodes.size());

std::vector<const onnxruntime::NodeArg*> subgraphInputs;
subgraphInputs.reserve(subGraphInputArgNames.size());

std::vector<const onnxruntime::NodeArg*> subgraphOutputs;
subgraphOutputs.reserve(subGraphOutputArgNames.size());

std::vector<onnxruntime::NodeAttributes> nodeAttributes;
nodeAttributes.reserve(indexedSubGraph->nodes.size());

std::vector<std::shared_ptr<onnxruntime::NodeArg>> intermediateNodeArgs;

for (size_t sortedNodeIndex : indexedSubGraph->nodes)
{
auto node = graph.GetNode(sortedNodeIndex);

nodeAttributes.push_back(node->GetAttributes());

NodeInfo nodeInfo{};
nodeInfo.name = node->Name();
nodeInfo.opType = node->OpType();
nodeInfo.description = node->Description();
nodeInfo.domain = node->Domain();
nodeInfo.attributes = node->GetAttributes();
nodeInfo.inputDefPointers.reserve(node->InputDefs().size());
nodeInfo.outputDefPointers.reserve(node->OutputDefs().size());

for (const onnxruntime::NodeArg* inputDef : node->InputDefs())
{
intermediateNodeArgs.emplace_back(std::make_shared<onnxruntime::NodeArg>(inputDef->Name(), inputDef->TypeAsProto()));
nodeInfo.inputDefPointers.push_back(intermediateNodeArgs.back().get());
}

for (const onnxruntime::NodeArg* outputDef : node->OutputDefs())
{
intermediateNodeArgs.emplace_back(std::make_shared<onnxruntime::NodeArg>(outputDef->Name(), outputDef->TypeAsProto()));
nodeInfo.outputDefPointers.push_back(intermediateNodeArgs.back().get());
}

nodesInfo.push_back(std::move(nodeInfo));
}

for (const std::string& graphInputName : subGraphInputArgNames)
{
subgraphInputs.push_back(graph.GetNodeArg(graphInputName));
}

for (const std::string& graphOutputName : subGraphOutputArgNames)
{
subgraphOutputs.push_back(graph.GetNodeArg(graphOutputName));
}

// We need to keep the initializers alive since they will be freed once the nodes are removed from the graph
std::vector<ONNX_NAMESPACE::TensorProto> ownedInitializers;
ownedInitializers.reserve(isInitializerTransferable.size());

for (auto& kvp : isInitializerTransferable)
{
ONNX_NAMESPACE::TensorProto tensorProto;
tensorProto.set_data_type(kvp.second.first->data_type());
tensorProto.set_raw_data(kvp.second.first->raw_data());
tensorProto.set_name(kvp.second.first->name());

for (int i = 0; i < kvp.second.first->dims_size(); ++i)
{
tensorProto.add_dims(kvp.second.first->dims(i));
}
ownedInitializers.push_back(std::move(tensorProto));
kvp.second.first = &ownedInitializers.back();
}

// lambda captures for the kernel registration
auto fused_kernel_func = [
indexedSubGraph,
&modelPath,
nodesInfo = std::move(nodesInfo),
intermediateNodeArgs = std::move(intermediateNodeArgs),
subgraphInputs = std::move(subgraphInputs),
subgraphOutputs = std::move(subgraphOutputs),
partitionNodePropsMap = std::move(partitionNodePropsMap),
ownedInitializers = std::move(ownedInitializers)] (onnxruntime::FuncManager& func_mgr, const onnxruntime::OpKernelInfo& info, std::unique_ptr<onnxruntime::OpKernel>& out) mutable ->onnxruntime::Status
{
std::vector<std::shared_ptr<onnxruntime::Node>> subgraphNodes;
subgraphNodes.reserve(nodesInfo.size());

for (const NodeInfo& nodeInfo : nodesInfo)
{
subgraphNodes.emplace_back(std::make_shared<onnxruntime::Node>(
nodeInfo.name,
nodeInfo.opType,
nodeInfo.description,
nodeInfo.inputDefPointers,
nodeInfo.outputDefPointers,
&nodeInfo.attributes,
nodeInfo.domain));
}

out.reset(CreateRuntimeFusedGraphKernel(
info,
indexedSubGraph,
modelPath,
std::move(subgraphNodes),
std::move(subgraphInputs),
std::move(subgraphOutputs),
std::move(intermediateNodeArgs),
std::move(partitionNodePropsMap),
std::move(ownedInitializers)));
return Status::OK();
};

// build the kernel definition on the fly, and register it to the fused_kernel_registry.
onnxruntime::KernelDefBuilder builder;
builder.SetName(indexedSubGraph->GetMetaDef()->name)
.SetDomain(indexedSubGraph->GetMetaDef()->domain)
.SinceVersion(indexedSubGraph->GetMetaDef()->since_version)
.Provider(onnxruntime::kDmlExecutionProvider);

// Force the CPU inputs to be allocated on the CPU
for (int i = 0; i < subGraphInputArgNames.size(); ++i)
{
if (dynamicCpuInputMap.find(subGraphInputArgNames[i]) != dynamicCpuInputMap.end())
{
builder.InputMemoryType(OrtMemTypeCPUInput, i);
}
}

ORT_THROW_IF_ERROR(registryForPartitionKernels->Register(builder, fused_kernel_func));

auto& fusedNode = graph.BeginFuseSubGraph(*indexedSubGraph, indexedSubGraph->GetMetaDef()->name);
fusedNode.SetExecutionProviderType(onnxruntime::kDmlExecutionProvider);

graph.FinalizeFuseSubGraph(*indexedSubGraph, fusedNode);
}
}
}
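One subtlety in RegisterDynamicKernel above: ownedInitializers.reserve(...) is what keeps the &ownedInitializers.back() pointers stable, since any later reallocation of the vector would invalidate pointers taken in earlier iterations. A standalone sketch of the same pattern, with illustrative types that are not from the PR:

    #include <string>
    #include <unordered_map>
    #include <vector>

    struct TensorBlob { std::string name; std::string rawData; };

    // Deep-copy initializers into storage the fused kernel will own, then
    // repoint the map entries at the copies.
    void TakeOwnership(std::unordered_map<std::string, const TensorBlob*>& byName,
                       std::vector<TensorBlob>& owned)
    {
        // reserve() is required: push_back must never reallocate, or pointers
        // taken from earlier iterations would dangle.
        owned.reserve(byName.size());
        for (auto& [name, ptr] : byName)
        {
            owned.push_back(*ptr);  // copy while the source initializer is alive
            ptr = &owned.back();    // safe only because of the reserve() above
        }
    }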

@@ -80,5 +80,14 @@ namespace DmlGraphFusionHelper
std::vector<uint8_t>&& isInputsUploadedByDmlEP,
const GraphDescBuilder::GraphDesc& graphDesc,
Microsoft::WRL::ComPtr<IDMLCompiledOperator> compiledExecutionPlanOperator);

+ void RegisterDynamicKernel(
+ onnxruntime::Graph& graph,
+ onnxruntime::KernelRegistry* registryForPartitionKernels,
+ const ExecutionProviderImpl* providerImpl,
+ std::unordered_map<const onnxruntime::Node*, GraphNodeProperties> graphNodePropertyMap,
+ const std::unordered_set<std::string>& dynamicCpuInputMap,
+ std::shared_ptr<const onnxruntime::IndexedSubGraph> indexedSubGraph,
+ std::unordered_map<std::string, std::pair<const ONNX_NAMESPACE::TensorProto*, bool>>&& isInitializerTransferable);
}
}