Skip to content

[ET-VK] Dynamic shape support in Vulkan Backend #2367

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 73 additions & 2 deletions backends/vulkan/runtime/VulkanBackend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include <executorch/runtime/backend/interface.h>
#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/evalue.h>
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
#include <executorch/runtime/platform/compiler.h>
#include <executorch/runtime/platform/profiler.h>

Expand Down Expand Up @@ -195,6 +196,68 @@ class GraphBuilder {
}
};

//
// Execution tools
//

// Compares the sizes of the input_i'th runtime input tensor against the
// corresponding Vulkan graph tensor and, if they differ, performs a
// rank-preserving virtual resize of the graph tensor to match.
//
// Aborts (ET_CHECK) if the rank changes, or if the element counts disagree
// after the resize — callers copy numel elements between the ET tensor and
// the staging buffer, so a mismatch would corrupt that copy.
//
// Returns true iff a resize was performed, signalling the caller that new
// sizes must be propagated through the graph before execution.
bool maybe_resize_input(
    ComputeGraph* graph,
    const size_t input_i,
    exec_aten::Tensor& et_tensor) {
  ValueRef in_tensor_ref = graph->inputs()[input_i].value;
  vTensor& in_tensor = graph->get_val(in_tensor_ref).toTensor();

  ET_CHECK_MSG(
      et_tensor.dim() == in_tensor.sizes().size(),
      "Cannot resize input tensor: old ndim %zu does not match new ndim %zu",
      static_cast<size_t>(in_tensor.sizes().size()),
      static_cast<size_t>(et_tensor.dim()));

  // Cache the rank as size_t so the loop compares like-signed types
  // (et_tensor.dim() is signed; mixing it with size_t i triggers
  // -Wsign-compare and relies on implicit conversion).
  const size_t ndim = static_cast<size_t>(et_tensor.dim());

  bool should_resize = false;
  std::vector<int64_t> new_sizes(ndim);
  for (size_t i = 0; i < ndim; i++) {
    if (in_tensor.sizes()[i] != et_tensor.sizes()[i]) {
      should_resize = true;
    }
    new_sizes.at(i) = et_tensor.sizes()[i];
  }

  if (should_resize) {
    graph->resize_input(input_i, new_sizes);
  }

  ET_CHECK_MSG(
      in_tensor.numel() == et_tensor.numel(),
      "Vulkan tensor numel %zu does not match ET tensor numel %zu",
      static_cast<size_t>(in_tensor.numel()),
      static_cast<size_t>(et_tensor.numel()));

  return should_resize;
}

// Resizes the output_i'th runtime output tensor to match the current sizes of
// the corresponding Vulkan graph tensor (which may have changed after shape
// propagation). Aborts (ET_CHECK) if the graph tensor's rank exceeds
// kTensorDimensionLimit or if resize_tensor fails.
void maybe_resize_output(
    ComputeGraph* graph,
    const size_t output_i,
    exec_aten::Tensor& et_tensor) {
  ValueRef out_tensor_ref = graph->outputs()[output_i].value;
  vTensor& out_tensor = graph->get_val(out_tensor_ref).toTensor();

  const size_t ndim = out_tensor.sizes().size();
  // Guard the fixed-size stack array below against out-of-bounds writes; the
  // original code wrote ndim entries without checking the limit.
  ET_CHECK_MSG(
      ndim <= kTensorDimensionLimit,
      "Vulkan tensor ndim %zu exceeds the tensor dimension limit %zu",
      ndim,
      static_cast<size_t>(kTensorDimensionLimit));

  exec_aten::SizesType new_output_size[kTensorDimensionLimit];
  // Use size_t for the index to match ndim's type (avoids -Wsign-compare).
  for (size_t i = 0; i < ndim; ++i) {
    new_output_size[i] = out_tensor.sizes()[i];
  }

  exec_aten::ArrayRef<exec_aten::SizesType> output_size{new_output_size, ndim};
  Error err = resize_tensor(et_tensor, output_size);

  ET_CHECK_MSG(err == Error::Ok, "Failed to resize output tensor.");
}

//
// VulkanBackend class
//

class VulkanBackend final : public PyTorchBackendInterface {
public:
~VulkanBackend() override = default;
Expand Down Expand Up @@ -273,20 +336,28 @@ class VulkanBackend final : public PyTorchBackendInterface {
ComputeGraph* compute_graph = static_cast<ComputeGraph*>(handle);

const size_t num_inputs = compute_graph->inputs().size();
bool should_propagate_resize = false;
for (size_t i = 0; i < num_inputs; i++) {
bool was_resized =
maybe_resize_input(compute_graph, i, args[i]->toTensor());
should_propagate_resize = should_propagate_resize || was_resized;
compute_graph->copy_into_staging(
compute_graph->inputs()[i],
compute_graph->inputs()[i].staging,
args[i]->toTensor().const_data_ptr(),
args[i]->toTensor().numel());
}

if (should_propagate_resize) {
compute_graph->propagate_resize();
}
compute_graph->execute();

for (size_t i = 0; i < compute_graph->outputs().size(); i++) {
maybe_resize_output(compute_graph, i, args[num_inputs + i]->toTensor());
// args holds inputs directly followed by outputs, so the i'th output
// for compute_graph corresponds to the (i + num_inputs)'th arg
compute_graph->copy_from_staging(
compute_graph->outputs()[i],
compute_graph->outputs()[i].staging,
args[num_inputs + i]->toTensor().mutable_data_ptr(),
args[num_inputs + i]->toTensor().numel());
}
Expand Down
21 changes: 17 additions & 4 deletions backends/vulkan/runtime/graph/ComputeGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,10 +135,10 @@ ValueRef ComputeGraph::set_input_tensor(
vTensor& tensor = get_val(idx).toTensor();
ValueRef staging_idx = add_staging(tensor.dtype(), tensor.gpu_numel());
add_staging_to_tensor_node(*this, staging_idx, idx);
inputs_.push_back(staging_idx);
inputs_.push_back({idx, staging_idx});
return staging_idx;
}
inputs_.push_back(idx);
inputs_.push_back({idx, kDummyValueRef});
return idx;
}

Expand All @@ -149,10 +149,10 @@ ValueRef ComputeGraph::set_output_tensor(
vTensor& tensor = get_val(idx).toTensor();
ValueRef staging_idx = add_staging(tensor.dtype(), tensor.gpu_numel());
add_tensor_to_staging_node(*this, idx, staging_idx);
outputs_.push_back(staging_idx);
outputs_.push_back({idx, staging_idx});
return staging_idx;
}
outputs_.push_back(idx);
outputs_.push_back({idx, kDummyValueRef});
return idx;
}

Expand Down Expand Up @@ -241,6 +241,19 @@ void ComputeGraph::execute() const {
fence.wait();
}

void ComputeGraph::resize_input(
const int64_t idx,
const std::vector<int64_t>& new_sizes) {
IOValueRef io_val = inputs_.at(idx);
get_val(io_val.value).toTensor().virtual_resize(new_sizes);
}

void ComputeGraph::propagate_resize() {
for (std::unique_ptr<ExecuteNode>& node : execute_nodes_) {
node->trigger_resize(this);
}
}

} // namespace vulkan
} // namespace native
} // namespace at
15 changes: 11 additions & 4 deletions backends/vulkan/runtime/graph/ComputeGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ class ComputeGraph final {
std::vector<std::unique_ptr<PrepackNode>> prepack_nodes_;
std::vector<std::unique_ptr<ExecuteNode>> execute_nodes_;

std::vector<ValueRef> inputs_;
std::vector<ValueRef> outputs_;
std::vector<IOValueRef> inputs_;
std::vector<IOValueRef> outputs_;

public:
//
Expand All @@ -80,11 +80,11 @@ class ComputeGraph final {
return context_.get();
}

inline std::vector<ValueRef>& inputs() {
inline std::vector<IOValueRef>& inputs() {
return inputs_;
}

inline std::vector<ValueRef>& outputs() {
inline std::vector<IOValueRef>& outputs() {
return outputs_;
}

Expand Down Expand Up @@ -201,6 +201,13 @@ class ComputeGraph final {

void encode_execute();
void execute() const;

//
// Dynamic Shape support
//

void resize_input(const int64_t idx, const std::vector<int64_t>& new_sizes);
void propagate_resize();
};

template <typename T>
Expand Down
8 changes: 6 additions & 2 deletions backends/vulkan/runtime/graph/ops/ExecuteNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,16 @@ ExecuteNode::ExecuteNode(
const api::utils::uvec3& global_workgroup_size,
const api::utils::uvec3& local_workgroup_size,
const std::vector<ArgGroup>& args,
const std::vector<std::shared_ptr<api::UniformParamsBuffer>>& params)
const std::vector<std::shared_ptr<api::UniformParamsBuffer>>& params,
const ResizeFunction& resize_fn,
const std::vector<ValueRef>& resize_args)
: shader_(shader),
global_workgroup_size_(global_workgroup_size),
local_workgroup_size_(local_workgroup_size),
args_(args),
params_(params) {
params_(params),
resize_fn_(resize_fn),
resize_args_(resize_args) {
graph.update_descriptor_counts(shader, /*execute = */ true);
}

Expand Down
17 changes: 16 additions & 1 deletion backends/vulkan/runtime/graph/ops/ExecuteNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,25 +47,40 @@ class ExecuteNode final {
friend class ComputeGraph;

public:
using ResizeFunction = const std::function<void(
ComputeGraph*,
const std::vector<ArgGroup>&,
const std::vector<ValueRef>&)>;

ExecuteNode(
ComputeGraph& graph,
const api::ShaderInfo& shader,
const api::utils::uvec3& global_workgroup_size,
const api::utils::uvec3& local_workgroup_size,
const std::vector<ArgGroup>& args,
const std::vector<std::shared_ptr<api::UniformParamsBuffer>>& params);
const std::vector<std::shared_ptr<api::UniformParamsBuffer>>& params,
const ResizeFunction& resize_fn = nullptr,
const std::vector<ValueRef>& resize_args = {});

~ExecuteNode() = default;

void encode(ComputeGraph* graph);

// Runs this node's size-propagation callback, if one was supplied at
// construction; a node constructed without one is a no-op here.
inline void trigger_resize(ComputeGraph* graph) {
  if (resize_fn_) {
    resize_fn_(graph, args_, resize_args_);
  }
}

protected:
const api::ShaderInfo shader_;
const api::utils::uvec3 global_workgroup_size_;
const api::utils::uvec3 local_workgroup_size_;
const std::vector<ArgGroup> args_;
// TODO(T180906457): allow re-computing param buffers.
std::vector<std::shared_ptr<api::UniformParamsBuffer>> params_;
const ResizeFunction resize_fn_;
const std::vector<ValueRef> resize_args_;
};

} // namespace vulkan
Expand Down
28 changes: 27 additions & 1 deletion backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,28 @@ namespace at {
namespace native {
namespace vulkan {

// Recomputes the output sizes of a binary elementwise op: the output rank is
// the max of the two input ranks, and each dimension — aligned from the
// innermost (last) dimension, since sizes are in NCHW order — is the max of
// the two input sizes at that position (broadcast-style).
void resize_binary_op_node(
    ComputeGraph* graph,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& extra_args) {
  (void)extra_args;
  vTensor& out = graph->get_val(args[0].refs[0]).toTensor();
  vTensor& self = graph->get_val(args[1].refs[0]).toTensor();
  vTensor& other = graph->get_val(args[1].refs[1]).toTensor();

  // Use a signed rank so the reverse loop below is well-defined. The original
  // compared `int i >= -new_out_sizes.size()`, which negates an unsigned
  // size_t and only worked by modular-arithmetic coincidence.
  const int64_t out_ndim = static_cast<int64_t>(
      std::max(self.sizes().size(), other.sizes().size()));
  std::vector<int64_t> new_out_sizes(out_ndim);

  // Match the sizes in reverse because sizes are in NCHW order.
  for (int64_t i = -1; i >= -out_ndim; --i) {
    new_out_sizes.at(out_ndim + i) = std::max(
        api::utils::val_at(i, self.sizes()),
        api::utils::val_at(i, other.sizes()));
  }

  out.virtual_resize(new_out_sizes);
}

void add_binary_op_node(
ComputeGraph& graph,
const ValueRef in1,
Expand Down Expand Up @@ -52,12 +74,16 @@ void add_binary_op_node(
VK_KERNEL_FROM_STR(kernel_name.str()),
global_size,
local_size,
// Inputs and Outputs
{{out, api::MemoryAccessType::WRITE},
{{arg1, arg2}, api::MemoryAccessType::READ}},
// Shader params buffers
{t_out.gpu_sizes_ubo(),
t_in1.gpu_sizes_ubo(),
t_in2.gpu_sizes_ubo(),
graph.create_params_buffer(alpha_val)}));
graph.create_params_buffer(alpha_val)},
// Resizing
resize_binary_op_node));
}

#define DEFINE_BINARY_OP_WITH_ALPHA_FN(op_name) \
Expand Down
15 changes: 8 additions & 7 deletions backends/vulkan/serialization/vulkan_graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,13 +218,14 @@ def process_getattr_node(self, node: Node) -> None:
self.create_tensor_values(node)

def process_output_node(self, node: Node) -> None:
    """Record the serialized value ids for every tensor feeding the output node.

    Supports multiple graph outputs: each input to the output node must
    already have been processed (present in ``node_to_value_ids``) —
    otherwise the graph is being serialized out of topological order,
    which is an error.
    """
    for out_node in node.all_input_nodes:
        if out_node not in self.node_to_value_ids:
            raise AssertionError(
                "Cannot find input to output node in node_to_value_ids. This means "
                "the output node is being serialized before its corresponding "
                "internal node which is not allowed."
            )
        self.output_ids.append(self.node_to_value_ids[out_node])

def process_node(self, node: Node) -> None:
if node.op == "placeholder":
Expand Down
1 change: 1 addition & 0 deletions backends/vulkan/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ def define_common_targets(is_fbcode = False):
":vk_delegate_schema",
":vulkan_graph_runtime",
"//executorch/runtime/backend:interface",
"//executorch/runtime/core/exec_aten/util:tensor_util",
],
define_static_target = False,
# VulkanBackend.cpp needs to compile with executor as whole
Expand Down
Loading