
Extend support for scalars and scalar lists in Value class #2271

Closed · wants to merge 1 commit
323 changes: 168 additions & 155 deletions backends/vulkan/runtime/VulkanBackend.cpp
@@ -23,12 +23,15 @@
#include <cstdlib> /* strtol */
#include <memory>
#include <type_traits>
#include <vector>

namespace torch {
namespace executor {
namespace vulkan {
namespace {

using namespace at::native::vulkan;

// Flatbuffer types
using VkGraphPtr = const vkgraph::VkGraph*;
using OpCallPtr = const vkgraph::OperatorCall*;
@@ -51,102 +54,194 @@ const uint8_t* getConstantDataPtr(
return constant_data + constant_bytes->offset();
}

api::ScalarType get_scalar_type(const vkgraph::VkDataType& vk_datatype) {
switch (vk_datatype) {
case (vkgraph::VkDataType::fp32): {
return api::kFloat;
}
}
}
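Note that the switch above handles only fp32 and simply falls off the end for any other tag, which is undefined behavior. A more defensive variant (a sketch only, reusing the ET_CHECK_MSG macro this file already uses) would fail loudly on unsupported datatypes:

```cpp
// Sketch: same fp32 mapping as above, but aborting on datatypes
// the backend does not support yet instead of returning garbage.
api::ScalarType get_scalar_type_checked(const vkgraph::VkDataType& vk_datatype) {
  switch (vk_datatype) {
    case vkgraph::VkDataType::fp32:
      return api::kFloat;
    default:
      ET_CHECK_MSG(false, "Unsupported VkDataType.");
  }
}
```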

GraphConfig generate_config() {
const uint32_t submit_frequency = UINT32_MAX;

const api::CommandPoolConfig cmd_config{
4u, // cmdPoolInitialSize
2u, // cmdPoolBatchSize
};

const api::DescriptorPoolConfig descriptor_pool_config{
1024u, // descriptorPoolMaxSets
1024u, // descriptorUniformBufferCount
1024u, // descriptorStorageBufferCount
1024u, // descriptorCombinedSamplerCount
1024u, // descriptorStorageImageCount
32u, // descriptorPileSizes
};

const api::QueryPoolConfig query_pool_config{};

const api::ContextConfig context_config{
submit_frequency, // cmdSubmitFrequency
cmd_config, // cmdPoolConfig
descriptor_pool_config, // descriptorPoolConfig
query_pool_config, // queryPoolConfig
};

const GraphConfig graph_config{
context_config,
};

return graph_config;
}
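The returned config is what the backend's ComputeGraph is built from; a minimal usage sketch (assuming ComputeGraph is constructible from a GraphConfig, as the builder usage later in this diff implies):

```cpp
// Sketch: constructing a compute graph from the generated config.
GraphConfig config = generate_config();
ComputeGraph compute_graph(config);
```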

class GraphBuilder {
ComputeGraph* compute_graph_;
VkGraphPtr flatbuffer_;
const uint8_t* constant_data_;

std::unordered_map<uint32_t, ValueRef> ref_mapping_;

 public:
  explicit GraphBuilder(
      ComputeGraph* compute_graph,
      VkGraphPtr flatbuffer,
      const uint8_t* constant_data)
      : compute_graph_(compute_graph),
        flatbuffer_(flatbuffer),
        constant_data_(constant_data),
        ref_mapping_() {}

  bool fb_id_exists(const uint32_t fb_id) {
    const std::unordered_map<uint32_t, ValueRef>::iterator found_ref =
        ref_mapping_.find(fb_id);

    return found_ref != ref_mapping_.end();
  }

  ValueRef get_fb_id_valueref(const uint32_t fb_id) {
    const std::unordered_map<uint32_t, ValueRef>::iterator found_ref =
        ref_mapping_.find(fb_id);

    ET_CHECK_MSG(
        found_ref != ref_mapping_.end(),
        "Trying to extract a value that hasn't yet been added to the graph.");

    return found_ref->second;
  }

  void add_tensor_to_graph(const uint32_t fb_id, VkTensorPtr tensor_fb) {
    const api::ScalarType& dtype = get_scalar_type(tensor_fb->datatype());

    UIntVector dims_fb = tensor_fb->dims();
    const std::vector<int64_t> dims_vector(dims_fb->cbegin(), dims_fb->cend());

    ValueRef ref;
    if (tensor_fb->constant_id() >= 0) {
      const uint8_t* tensor_data = getConstantDataPtr(
          flatbuffer_, tensor_fb->constant_id(), constant_data_);

      ref = compute_graph_->add_tensorref(dims_vector, dtype, tensor_data);
    } else {
      ref = compute_graph_->add_tensor(
          dims_vector, dtype, tensor_fb->mem_obj_id());
    }

    ref_mapping_[fb_id] = ref;
  }

template <typename T>
typename std::enable_if<is_valid_scalar_type<T>::value, void>::type
add_scalar_to_graph(const uint32_t fb_id, T value) {
ValueRef ref = compute_graph_->add_scalar(value);
ref_mapping_[fb_id] = ref;
}
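The is_valid_scalar_type trait used in the SFINAE guard above is defined elsewhere in the codebase; a plausible shape (a hypothetical sketch only, with the accepted types inferred from the Int/Double/Bool cases handled in add_value_to_graph below) is:

```cpp
// Hypothetical sketch of the trait; the real definition lives
// elsewhere in the codebase and may differ.
template <typename T>
struct is_valid_scalar_type : std::false_type {};

template <>
struct is_valid_scalar_type<int64_t> : std::true_type {};
template <>
struct is_valid_scalar_type<double> : std::true_type {};
template <>
struct is_valid_scalar_type<bool> : std::true_type {};
```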

void add_string_to_graph(const uint32_t fb_id, VkValuePtr value) {
const auto fb_str = value->value_as_String()->string_val();
std::string string(fb_str->cbegin(), fb_str->cend());
ValueRef ref = compute_graph_->add_string(std::move(string));
ref_mapping_[fb_id] = ref;
}

void add_value_to_graph(const uint32_t fb_id, VkValuePtr value) {
    ET_CHECK_MSG(
        !fb_id_exists(fb_id),
        "Trying to add a value that has already been added to the graph.");

switch (value->value_type()) {
case vkgraph::GraphTypes::Int:
add_scalar_to_graph(fb_id, value->value_as_Int()->int_val());
break;
case vkgraph::GraphTypes::Double:
add_scalar_to_graph(fb_id, value->value_as_Double()->double_val());
break;
case vkgraph::GraphTypes::Bool:
add_scalar_to_graph(fb_id, value->value_as_Bool()->bool_val());
break;
case vkgraph::GraphTypes::VkTensor:
add_tensor_to_graph(fb_id, value->value_as_VkTensor());
break;
case vkgraph::GraphTypes::String:
add_string_to_graph(fb_id, value);
break;
default:
ET_CHECK_MSG(false, "Unsupported value type.");
}
}

void build_graph() {
// First, add all values to the graph
for (uint32_t fb_id = 0; fb_id < flatbuffer_->values()->size(); ++fb_id) {
VkValuePtr value = flatbuffer_->values()->Get(fb_id);
add_value_to_graph(fb_id, value);
}

// Parse the inputs
for (const uint32_t fb_id : *flatbuffer_->input_ids()) {
const ValueRef ref = get_fb_id_valueref(fb_id);
compute_graph_->set_input_tensor(ref);
}

// Parse the operators
for (OpCallPtr op_call : *(flatbuffer_->chain())) {
std::string op_name = op_call->name()->str();
ET_CHECK_MSG(hasOpsFn(op_name), "Missing operator: %s", op_name.c_str());

const std::vector<int> arg_fb_ids(
op_call->args()->cbegin(), op_call->args()->cend());

std::vector<ValueRef> args;
for (const int arg_fb_id : arg_fb_ids) {
args.push_back(get_fb_id_valueref(arg_fb_id));
}

auto vkFn = getOpsFn(op_name);
vkFn(*compute_graph_, args);
}

// Parse the outputs
for (const uint32_t fb_id : *flatbuffer_->output_ids()) {
const ValueRef ref = get_fb_id_valueref(fb_id);
compute_graph_->set_output_tensor(ref);
}
}
};
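Taken together, the builder replaces the hand-rolled graph construction that previously lived in compileModel(); a minimal usage sketch (names taken from this diff, with flatbuffer_graph and constant_data assumed to be already parsed, as in compileModel() below):

```cpp
// Sketch: wiring up a ComputeGraph from a parsed flatbuffer.
ComputeGraph compute_graph(generate_config());
GraphBuilder builder(&compute_graph, flatbuffer_graph, constant_data);
builder.build_graph(); // adds values, then inputs, operators, and outputs
```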

class VulkanBackend final : public PyTorchBackendInterface {
public:
~VulkanBackend() override = default;

bool is_available() const override {
// TODO(ssjia): replace with an actual Vulkan runtime availability check
return true;
}

__ET_NODISCARD Error
compileModel(const void* buffer_pointer, ComputeGraph* compute_graph) const {
Result<VulkanDelegateHeader> header =
VulkanDelegateHeader::Parse(buffer_pointer);

const uint8_t* flatbuffer_data = nullptr;
const uint8_t* constant_data = nullptr;

@@ -169,92 +264,10 @@ class VulkanBackend final : public PyTorchBackendInterface {

VkGraphPtr flatbuffer_graph = vkgraph::GetVkGraph(flatbuffer_data);

GraphBuilder builder =
GraphBuilder(compute_graph, flatbuffer_graph, constant_data);

builder.build_graph();

compute_graph->encode_prepack();
compute_graph->prepack();
6 changes: 6 additions & 0 deletions backends/vulkan/runtime/graph/ComputeGraph.cpp
@@ -77,6 +77,12 @@ ValueRef ComputeGraph::add_staging(
return idx;
}

ValueRef ComputeGraph::add_string(std::string&& str) {
ValueRef idx(static_cast<int>(values_.size()));
values_.emplace_back(std::move(str));
return idx;
}
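A minimal usage sketch (assuming, as the emplace_back above implies, that the graph's Value type is constructible from a std::string; the instance and attribute value are hypothetical):

```cpp
// Sketch: interning an operator-attribute string into the graph.
std::string name = "conv2d"; // hypothetical attribute value
ValueRef ref = graph.add_string(std::move(name));
```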

ValueRef ComputeGraph::set_input_tensor(
const ValueRef idx,
const bool use_staging) {