diff --git a/backends/vulkan/runtime/VulkanBackend.cpp b/backends/vulkan/runtime/VulkanBackend.cpp index 9c554a232c9..1222ee38e5c 100644 --- a/backends/vulkan/runtime/VulkanBackend.cpp +++ b/backends/vulkan/runtime/VulkanBackend.cpp @@ -23,12 +23,15 @@ #include /* strtol */ #include #include +#include namespace torch { namespace executor { namespace vulkan { namespace { +using namespace at::native::vulkan; + // Flatbuffer types using VkGraphPtr = const vkgraph::VkGraph*; using OpCallPtr = const vkgraph::OperatorCall*; @@ -51,102 +54,194 @@ const uint8_t* getConstantDataPtr( return constant_data + constant_bytes->offset(); } -using namespace at::native::vulkan; +api::ScalarType get_scalar_type(const vkgraph::VkDataType& vk_datatype) { + switch (vk_datatype) { + case (vkgraph::VkDataType::fp32): { + return api::kFloat; + } + } +} + +GraphConfig generate_config() { + const uint32_t submit_frequency = UINT32_MAX; + + const api::CommandPoolConfig cmd_config{ + 4u, // cmdPoolInitialSize + 2u, // cmdPoolBatchSize + }; + + const api::DescriptorPoolConfig descriptor_pool_config{ + 1024u, // descriptorPoolMaxSets + 1024u, // descriptorUniformBufferCount + 1024u, // descriptorStorageBufferCount + 1024u, // descriptorCombinedSamplerCount + 1024u, // descriptorStorageImageCount + 32u, // descriptorPileSizes + }; + + const api::QueryPoolConfig query_pool_config{}; + + const api::ContextConfig context_config{ + submit_frequency, // cmdSubmitFrequency + cmd_config, // cmdPoolConfig + descriptor_pool_config, // descriptorPoolConfig + query_pool_config, // queryPoolConfig + }; + + const GraphConfig graph_config{ + context_config, + }; + + return graph_config; +} + +class GraphBuilder { + ComputeGraph* compute_graph_; + VkGraphPtr flatbuffer_; + const uint8_t* constant_data_; + + std::unordered_map ref_mapping_; -class VulkanBackend final : public PyTorchBackendInterface { public: - ~VulkanBackend() override = default; + explicit GraphBuilder( + ComputeGraph* compute_graph, + 
VkGraphPtr flatbuffer, + const uint8_t* constant_data) + : compute_graph_(compute_graph), + flatbuffer_(flatbuffer), + constant_data_(constant_data), + ref_mapping_() {} + + bool fb_id_exists(const uint32_t fb_id) { + const std::unordered_map::iterator found_ref = + ref_mapping_.find(fb_id); - bool is_available() const override { - return true; + return found_ref != ref_mapping_.end(); } - api::ScalarType get_scalar_type( - const vkgraph::VkDataType& vk_datatype) const { - switch (vk_datatype) { - case (vkgraph::VkDataType::fp32): { - return api::kFloat; - } - } + ValueRef get_fb_id_valueref(const uint32_t fb_id) { + const std::unordered_map::iterator found_ref = + ref_mapping_.find(fb_id); + + ET_CHECK_MSG( + found_ref != ref_mapping_.end(), + "Trying to extract a value that hasn't yet been added to the graph."); + + return found_ref->second; } - ValueRef get_value_ref( - const uint32_t value_id, - VkGraphPtr flatbuffer_graph, - ComputeGraph* compute_graph, - std::unordered_map& ref_mapping, - VkValuesVector value_mapping, - const uint8_t* constant_data) const { - const std::unordered_map::iterator found_ref = - ref_mapping.find(value_id); + void add_tensor_to_graph(const uint32_t fb_id, VkTensorPtr tensor_fb) { + const api::ScalarType& dtype = get_scalar_type(tensor_fb->datatype()); + + UIntVector dims_fb = tensor_fb->dims(); + const std::vector dims_vector(dims_fb->cbegin(), dims_fb->cend()); - if (found_ref != ref_mapping.end()) { - return found_ref->second; + ValueRef ref; + if (tensor_fb->constant_id() >= 0) { + const uint8_t* tensor_data = getConstantDataPtr( + flatbuffer_, tensor_fb->constant_id(), constant_data_); + + ref = compute_graph_->add_tensorref(dims_vector, dtype, tensor_data); + } else { + ref = compute_graph_->add_tensor( + dims_vector, dtype, tensor_fb->mem_obj_id()); } - VkValuePtr vk_value = value_mapping->Get(value_id); - VkTensorPtr vk_tensor = vk_value->value(); + ref_mapping_[fb_id] = ref; + } + + template + typename 
std::enable_if::value, void>::type + add_scalar_to_graph(const uint32_t fb_id, T value) { + ValueRef ref = compute_graph_->add_scalar(value); + ref_mapping_[fb_id] = ref; + } + + void add_string_to_graph(const uint32_t fb_id, VkValuePtr value) { + const auto fb_str = value->value_as_String()->string_val(); + std::string string(fb_str->cbegin(), fb_str->cend()); + ValueRef ref = compute_graph_->add_string(std::move(string)); + ref_mapping_[fb_id] = ref; + } + void add_value_to_graph(const uint32_t fb_id, VkValuePtr value) { ET_CHECK_MSG( - vk_tensor->constant_id() >= 0, - "Only constant buffers are supported when adding tensors to compute graph (indicated by constant_id < 0), but got constant_id of %d", - vk_tensor->constant_id()); + !fb_id_exists(fb_id), + "Trying to add a value that has already been added to the graph."); + + switch (value->value_type()) { + case vkgraph::GraphTypes::Int: + add_scalar_to_graph(fb_id, value->value_as_Int()->int_val()); + break; + case vkgraph::GraphTypes::Double: + add_scalar_to_graph(fb_id, value->value_as_Double()->double_val()); + break; + case vkgraph::GraphTypes::Bool: + add_scalar_to_graph(fb_id, value->value_as_Bool()->bool_val()); + break; + case vkgraph::GraphTypes::VkTensor: + add_tensor_to_graph(fb_id, value->value_as_VkTensor()); + break; + case vkgraph::GraphTypes::String: + add_string_to_graph(fb_id, value); + break; + default: + ET_CHECK_MSG(false, "Unsupported value type."); + } + } - const api::ScalarType& tensor_dtype = - get_scalar_type(vk_tensor->datatype()); + void build_graph() { + // First, add all values to the graph + for (uint32_t fb_id = 0; fb_id < flatbuffer_->values()->size(); ++fb_id) { + VkValuePtr value = flatbuffer_->values()->Get(fb_id); + add_value_to_graph(fb_id, value); + } - UIntVector tensor_dims_fb = vk_tensor->dims(); - const std::vector tensor_dims_vector( - tensor_dims_fb->cbegin(), tensor_dims_fb->cend()); + // Parse the inputs + for (const uint32_t fb_id : *flatbuffer_->input_ids()) { + 
const ValueRef ref = get_fb_id_valueref(fb_id); + compute_graph_->set_input_tensor(ref); + } - const uint8_t* tensor_data = getConstantDataPtr( - flatbuffer_graph, vk_tensor->constant_id(), constant_data); + // Parse the operators + for (OpCallPtr op_call : *(flatbuffer_->chain())) { + std::string op_name = op_call->name()->str(); + ET_CHECK_MSG(hasOpsFn(op_name), "Missing operator: %s", op_name.c_str()); - const ValueRef value_ref = compute_graph->add_tensorref( - tensor_dims_vector, tensor_dtype, tensor_data); + const std::vector arg_fb_ids( + op_call->args()->cbegin(), op_call->args()->cend()); - ref_mapping[value_id] = value_ref; + std::vector args; + for (const int arg_fb_id : arg_fb_ids) { + args.push_back(get_fb_id_valueref(arg_fb_id)); + } - return value_ref; + auto vkFn = getOpsFn(op_name); + vkFn(*compute_graph_, args); + } + + // Parse the outputs + for (const uint32_t fb_id : *flatbuffer_->output_ids()) { + const ValueRef ref = get_fb_id_valueref(fb_id); + compute_graph_->set_output_tensor(ref); + } } +}; - GraphConfig generate_config() const { - const uint32_t submit_frequency = UINT32_MAX; - - const api::CommandPoolConfig cmd_config{ - 4u, // cmdPoolInitialSize - 2u, // cmdPoolBatchSize - }; - - const api::DescriptorPoolConfig descriptor_pool_config{ - 1024u, // descriptorPoolMaxSets - 1024u, // descriptorUniformBufferCount - 1024u, // descriptorStorageBufferCount - 1024u, // descriptorCombinedSamplerCount - 1024u, // descriptorStorageImageCount - 32u, // descriptorPileSizes - }; - - const api::QueryPoolConfig query_pool_config{}; - - const api::ContextConfig context_config{ - submit_frequency, // cmdSubmitFrequency - cmd_config, // cmdPoolConfig - descriptor_pool_config, // descriptorPoolConfig - query_pool_config, // queryPoolConfig - }; - - const GraphConfig graph_config{ - context_config, - }; - - return graph_config; +class VulkanBackend final : public PyTorchBackendInterface { + public: + ~VulkanBackend() override = default; + + bool 
is_available() const override { + // TODO(ssjia): replace with an actual Vulkan runtime availability check + return true; } __ET_NODISCARD Error compileModel(const void* buffer_pointer, ComputeGraph* compute_graph) const { Result header = VulkanDelegateHeader::Parse(buffer_pointer); + const uint8_t* flatbuffer_data = nullptr; const uint8_t* constant_data = nullptr; @@ -169,92 +264,10 @@ class VulkanBackend final : public PyTorchBackendInterface { VkGraphPtr flatbuffer_graph = vkgraph::GetVkGraph(flatbuffer_data); - // Mapping from serialized VkValue ids to compute graph ValueRefs - // This will be populated as the compute graph is built - std::unordered_map ref_mapping; - - // A vector which acts as a mapping from VkValue ids (vector indices) to - // VkValues - VkValuesVector value_mapping = flatbuffer_graph->values(); + GraphBuilder builder = + GraphBuilder(compute_graph, flatbuffer_graph, constant_data); - // 1. Add all inputs (and corresponding tensors) to the compute graph - UIntVector input_ids = flatbuffer_graph->input_ids(); - - for (size_t input_index = 0; input_index < input_ids->size(); - ++input_index) { - const uint32_t input_id = input_ids->Get(input_index); - VkValuePtr input_vk_value = value_mapping->Get(input_id); - - VkTensorPtr input_vk_tensor = input_vk_value->value(); - - ET_CHECK_MSG( - input_vk_tensor->constant_id() < 0, - "Expected constant buffer index for input at index %zu with id %d to be < 0 (since it is non-constant), but got: %d", - input_index, - input_id, - input_vk_tensor->constant_id()); - - const api::ScalarType& input_dtype = - get_scalar_type(input_vk_tensor->datatype()); - - UIntVector input_dims_fb = input_vk_tensor->dims(); - const std::vector input_dims_vector( - input_dims_fb->cbegin(), input_dims_fb->cend()); - - const ValueRef input_ref = compute_graph->add_tensor( - input_dims_vector, input_dtype, input_vk_tensor->mem_obj_id()); - - ref_mapping[input_id] = input_ref; - compute_graph->set_input_tensor(input_ref); - } - - 
// 2. Add all ops to the graph - // TODO: Generalize for ops that don't have 2 inputs and 1 output. - for (OpCallPtr op_call : *(flatbuffer_graph->chain())) { - std::string op_name = op_call->name()->str(); - - ET_CHECK_MSG( - op_call->args() != nullptr && op_call->args()->size() == 3, - "Vulkan currently only supports OperatorCall with 3 args"); - const auto arg_ids = op_call->args()->data(); - - const uint32_t input1_id = arg_ids[0]; - const uint32_t input2_id = arg_ids[1]; - const uint32_t output_id = arg_ids[2]; - - const ValueRef input1_ref = get_value_ref( - input1_id, - flatbuffer_graph, - compute_graph, - ref_mapping, - value_mapping, - constant_data); - - const ValueRef input2_ref = get_value_ref( - input2_id, - flatbuffer_graph, - compute_graph, - ref_mapping, - value_mapping, - constant_data); - - ET_CHECK_MSG(hasOpsFn(op_name), "Missing operator: %s", op_name.c_str()); - auto vkFn = getOpsFn(op_name); - const at::native::vulkan::ValueRef output_ref = vkFn( - *compute_graph, - {input1_ref, - input2_ref, - 1, - value_mapping->Get(output_id)->value()->mem_obj_id()}); - - ref_mapping[output_id] = output_ref; - } - - // 3. 
Add all outputs to the compute graph - for (const uint32_t output_id : *flatbuffer_graph->output_ids()) { - const ValueRef output_ref = ref_mapping[output_id]; - compute_graph->set_output_tensor(output_ref); - } + builder.build_graph(); compute_graph->encode_prepack(); compute_graph->prepack(); diff --git a/backends/vulkan/runtime/graph/ComputeGraph.cpp b/backends/vulkan/runtime/graph/ComputeGraph.cpp index 647371424c4..0900dfb9c13 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.cpp +++ b/backends/vulkan/runtime/graph/ComputeGraph.cpp @@ -77,6 +77,12 @@ ValueRef ComputeGraph::add_staging( return idx; } +ValueRef ComputeGraph::add_string(std::string&& str) { + ValueRef idx(static_cast(values_.size())); + values_.emplace_back(std::move(str)); + return idx; +} + ValueRef ComputeGraph::set_input_tensor( const ValueRef idx, const bool use_staging) { diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h index ec8d3ba1db4..a45e449ae2d 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.h +++ b/backends/vulkan/runtime/graph/ComputeGraph.h @@ -28,6 +28,19 @@ namespace at { namespace native { namespace vulkan { +// Define valid scalar types that the Value class can accept +template +struct is_valid_scalar_type : std::false_type {}; + +template <> +struct is_valid_scalar_type : std::true_type {}; + +template <> +struct is_valid_scalar_type : std::true_type {}; + +template <> +struct is_valid_scalar_type : std::true_type {}; + /* * This is the core data structure used to execute Vulkan models in graph mode. 
* As opposed to ATen/eager mode where a command buffer is encoded every @@ -123,6 +136,16 @@ class ComputeGraph final { const void* const data); ValueRef add_staging(const api::ScalarType dtype, const size_t numel); + template + typename std::enable_if::value, ValueRef>::type + add_scalar_list(std::vector&& values); + + template + typename std::enable_if::value, ValueRef>::type + add_scalar(T value); + + ValueRef add_string(std::string&& str); + ValueRef set_input_tensor(const ValueRef idx, const bool use_staging = true); ValueRef set_output_tensor(const ValueRef idx, const bool use_staging = true); @@ -163,6 +186,22 @@ class ComputeGraph final { void execute() const; }; +template +inline typename std::enable_if::value, ValueRef>::type +ComputeGraph::add_scalar_list(std::vector&& values) { + ValueRef idx(static_cast(values_.size())); + values_.emplace_back(std::move(values)); + return idx; +} + +template +inline typename std::enable_if::value, ValueRef>::type +ComputeGraph::add_scalar(T value) { + ValueRef idx(static_cast(values_.size())); + values_.emplace_back(value); + return idx; +} + } // namespace vulkan } // namespace native } // namespace at diff --git a/backends/vulkan/runtime/graph/containers/Types.cpp b/backends/vulkan/runtime/graph/containers/Types.cpp index bbfde572b01..0779ed87164 100644 --- a/backends/vulkan/runtime/graph/containers/Types.cpp +++ b/backends/vulkan/runtime/graph/containers/Types.cpp @@ -12,20 +12,25 @@ namespace at { namespace native { namespace vulkan { +#define PRINT_CASE(name) \ + case TypeTag::name: \ + out << #name; \ + break; + std::ostream& operator<<(std::ostream& out, const TypeTag& tag) { switch (tag) { - case TypeTag::NONE: - out << "NONE"; - break; - case TypeTag::TENSOR: - out << "TENSOR"; - break; - case TypeTag::STAGING: - out << "STAGING"; - break; - default: - out << "UNKNOWN"; - break; + PRINT_CASE(NONE) + PRINT_CASE(INT) + PRINT_CASE(DOUBLE) + PRINT_CASE(BOOL) + PRINT_CASE(TENSOR) + PRINT_CASE(STAGING) + 
PRINT_CASE(TENSORREF) + PRINT_CASE(INTLIST) + PRINT_CASE(DOUBLELIST) + PRINT_CASE(BOOLLIST) + PRINT_CASE(VALUELIST) + PRINT_CASE(STRING) } return out; } diff --git a/backends/vulkan/runtime/graph/containers/Types.h b/backends/vulkan/runtime/graph/containers/Types.h index a7162d777ac..d5dee7ea0dd 100644 --- a/backends/vulkan/runtime/graph/containers/Types.h +++ b/backends/vulkan/runtime/graph/containers/Types.h @@ -23,12 +23,21 @@ namespace vulkan { */ enum class TypeTag : uint32_t { NONE, - TENSOR, - STAGING, - TENSORREF, + // Scalar types INT, DOUBLE, BOOL, + // Tensor and tensor adjacent types + TENSOR, + STAGING, + TENSORREF, + // Scalar lists + INTLIST, + DOUBLELIST, + BOOLLIST, + // Special Type + VALUELIST, + STRING, }; std::ostream& operator<<(std::ostream& out, const TypeTag& tag); diff --git a/backends/vulkan/runtime/graph/containers/Value.h b/backends/vulkan/runtime/graph/containers/Value.h index d56791b4fa8..82ba9417137 100644 --- a/backends/vulkan/runtime/graph/containers/Value.h +++ b/backends/vulkan/runtime/graph/containers/Value.h @@ -22,6 +22,19 @@ namespace at { namespace native { namespace vulkan { +using ValueRef = int32_t; + +constexpr ValueRef kDummyValueRef = -1; + +inline bool is_valid(ValueRef value_ref) { + return value_ref >= 0; +} + +struct IOValueRef { + ValueRef value; + ValueRef staging; +}; + /* * This class is modelled after c10::IValue; however, it is simplified and does * not support as many types. However, the core design is the same; it is a @@ -48,6 +61,17 @@ struct Value final { api::StorageBuffer as_staging; TensorRef as_tensorref; + std::vector as_int_list; + std::vector as_double_list; + std::vector as_bool_list; + + // The below is a special type that is used to represent a list of other + // values stored in the graph. One application of the type is to represent + // a list of tensors or a list of optional tensors. 
+ std::vector as_value_list; + + std::string as_string; + Payload() : u() {} // NOLINTNEXTLINE ~Payload(){}; @@ -68,21 +92,48 @@ struct Value final { Value& operator=(Value&&) = delete; +#define CASE_MOVE_TRIVIALLY_COPYABLE_TYPE(type_tag, member_name) \ + case type_tag: \ + payload.u.member_name = rhs.payload.u.member_name; \ + break; + +#define CASE_MOVE_MOVEABLE_TYPE(type_tag, type, member_name) \ + case type_tag: \ + new (&payload.member_name) type(std::move(rhs.payload.member_name)); \ + break; + Value(Value&& rhs) noexcept : tag(rhs.tag) { - if (rhs.isTensor()) { - new (&payload.as_tensor) vTensor(std::move(rhs.payload.as_tensor)); - } else if (rhs.isStaging()) { - new (&payload.as_staging) - api::StorageBuffer(std::move(rhs.payload.as_staging)); - } else if (rhs.isTensorRef()) { - payload.as_tensorref = std::move(rhs.payload.as_tensorref); - } else { - payload.u = rhs.payload.u; + switch (tag) { + // Scalar types + CASE_MOVE_TRIVIALLY_COPYABLE_TYPE(TypeTag::INT, as_int); + CASE_MOVE_TRIVIALLY_COPYABLE_TYPE(TypeTag::DOUBLE, as_double); + CASE_MOVE_TRIVIALLY_COPYABLE_TYPE(TypeTag::BOOL, as_bool); + // Tensor and tensor adjacent types + CASE_MOVE_MOVEABLE_TYPE(TypeTag::TENSOR, vTensor, as_tensor); + CASE_MOVE_MOVEABLE_TYPE(TypeTag::STAGING, api::StorageBuffer, as_staging); + CASE_MOVE_MOVEABLE_TYPE(TypeTag::TENSORREF, TensorRef, as_tensorref); + // Scalar lists + CASE_MOVE_MOVEABLE_TYPE( + TypeTag::INTLIST, std::vector, as_int_list); + CASE_MOVE_MOVEABLE_TYPE( + TypeTag::DOUBLELIST, std::vector, as_double_list); + CASE_MOVE_MOVEABLE_TYPE( + TypeTag::BOOLLIST, std::vector, as_bool_list); + // Special types + CASE_MOVE_MOVEABLE_TYPE( + TypeTag::VALUELIST, std::vector, as_value_list); + CASE_MOVE_MOVEABLE_TYPE(TypeTag::STRING, std::string, as_string); + + case TypeTag::NONE: + clearToNone(); + break; } - tag = rhs.tag; rhs.clearToNone(); } +#undef CASE_MOVE_TRIVIALLY_COPYABLE_TYPE +#undef CASE_MOVE_MOVEABLE_TYPE + // // Accessors // @@ -96,77 +147,127 @@ struct 
Value final { // ~Value() { - if (this->isTensor()) { - payload.as_tensor.~vTensor(); - } else if (this->isStaging()) { - payload.as_staging.~StorageBuffer(); - } else if (this->isTensorRef()) { - payload.as_tensorref.~TensorRef(); + switch (tag) { + case TypeTag::TENSOR: + payload.as_tensor.~vTensor(); + break; + case TypeTag::STAGING: + payload.as_staging.~StorageBuffer(); + break; + case TypeTag::TENSORREF: + payload.as_tensorref.~TensorRef(); + break; + case TypeTag::INTLIST: + payload.as_int_list.~vector(); + break; + case TypeTag::DOUBLELIST: + payload.as_double_list.~vector(); + break; + case TypeTag::BOOLLIST: + payload.as_bool_list.~vector(); + break; + case TypeTag::VALUELIST: + payload.as_value_list.~vector(); + break; + case TypeTag::STRING: + payload.as_string.~basic_string(); + break; + // Manually list out the types so that if a type here is added later and + // not handled the compiler can catch it. + case TypeTag::NONE: + case TypeTag::INT: + case TypeTag::DOUBLE: + case TypeTag::BOOL: + break; } } - // - // Tensor - // - - explicit Value(vTensor&& t) : tag(TypeTag::TENSOR) { - new (&payload.as_tensor) vTensor(std::move(t)); - } - - inline bool isTensor() const { - return TypeTag::TENSOR == tag; - } - - inline vTensor& toTensor() { - VK_CHECK_COND( - isTensor(), - "Expected value to have type TENSOR, got ", - tag, - " instead."); - return payload.as_tensor; +#define SUPPORT_TRIVIALLY_COPYABLE_TYPE( \ + type, type_name, type_tag, member_name) \ + explicit Value(type t) : tag(type_tag) { \ + payload.u.member_name = t; \ + } \ + inline bool is##type_name() const { \ + return tag == type_tag; \ + } \ + inline const type& to##type_name() const { \ + VK_CHECK_COND( \ + is##type_name(), \ + "Expected value to have type " #type_name ", got ", \ + tag, \ + " instead."); \ + return payload.u.member_name; \ } - // - // Staging - // - - explicit Value(api::StorageBuffer&& t) : tag(TypeTag::STAGING) { - new (&payload.as_staging) 
api::StorageBuffer(std::move(t)); - } - - inline bool isStaging() const { - return TypeTag::STAGING == tag; - } - - inline api::StorageBuffer& toStaging() { - VK_CHECK_COND( - isStaging(), - "Expected value to have type STAGING, got ", - tag, - " instead."); - return payload.as_staging; - } - - // - // TensorRef - // - - explicit Value(TensorRef&& t) : tag(TypeTag::TENSORREF) { - payload.as_tensorref = std::move(t); - } - - inline bool isTensorRef() const { - return TypeTag::TENSORREF == tag; + SUPPORT_TRIVIALLY_COPYABLE_TYPE(int64_t, Int, TypeTag::INT, as_int); + SUPPORT_TRIVIALLY_COPYABLE_TYPE(double, Double, TypeTag::DOUBLE, as_double); + SUPPORT_TRIVIALLY_COPYABLE_TYPE(bool, Bool, TypeTag::BOOL, as_bool); + +#undef SUPPORT_TRIVIALLY_COPYABLE_TYPE + +#define SUPPORT_TRIVIALLY_MOVEABLE_TYPE( \ + type, type_name, type_tag, member_name) \ + explicit Value(type&& t) : tag(type_tag) { \ + new (&payload.member_name) type(std::move(t)); \ + } \ + inline bool is##type_name() const { \ + return tag == type_tag; \ + } \ + inline type& to##type_name() { \ + VK_CHECK_COND( \ + is##type_name(), \ + "Expected value to have type " #type_name ", got ", \ + tag, \ + " instead."); \ + return payload.member_name; \ } - inline TensorRef& toTensorRef() { - VK_CHECK_COND( - isTensorRef(), - "Expected value to have type TENSORREF, got ", - tag, - " instead."); - return payload.as_tensorref; - } + SUPPORT_TRIVIALLY_MOVEABLE_TYPE(vTensor, Tensor, TypeTag::TENSOR, as_tensor); + + SUPPORT_TRIVIALLY_MOVEABLE_TYPE( + api::StorageBuffer, + Staging, + TypeTag::STAGING, + as_staging); + + SUPPORT_TRIVIALLY_MOVEABLE_TYPE( + TensorRef, + TensorRef, + TypeTag::TENSORREF, + as_tensorref); + + SUPPORT_TRIVIALLY_MOVEABLE_TYPE( + std::vector, + IntList, + TypeTag::INTLIST, + as_int_list); + + SUPPORT_TRIVIALLY_MOVEABLE_TYPE( + std::vector, + DoubleList, + TypeTag::DOUBLELIST, + as_double_list); + + SUPPORT_TRIVIALLY_MOVEABLE_TYPE( + std::vector, + BoolList, + TypeTag::BOOLLIST, + as_bool_list); + + 
SUPPORT_TRIVIALLY_MOVEABLE_TYPE( + std::vector, + ValueList, + TypeTag::VALUELIST, + as_value_list); + + SUPPORT_TRIVIALLY_MOVEABLE_TYPE( + std::string, + String, + TypeTag::STRING, + as_string); + +#undef SUPPORT_TRIVIALLY_COPYABLE_TYPE +#undef SUPPORT_TRIVIALLY_MOVEABLE_TYPE private: Payload payload; @@ -177,18 +278,11 @@ struct Value final { // inline void clearToNone() noexcept { - payload.u.as_int = 0; + payload.u.as_int = -1; tag = TypeTag::NONE; } }; -using ValueRef = int32_t; - -struct IOValueRef { - ValueRef value; - ValueRef staging; -}; - } // namespace vulkan } // namespace native } // namespace at diff --git a/backends/vulkan/runtime/graph/ops/OpUtils.h b/backends/vulkan/runtime/graph/ops/OpUtils.h index 2a98337721c..b5acb3945ac 100644 --- a/backends/vulkan/runtime/graph/ops/OpUtils.h +++ b/backends/vulkan/runtime/graph/ops/OpUtils.h @@ -12,6 +12,8 @@ #include +#include + namespace at { namespace native { namespace vulkan { @@ -80,6 +82,20 @@ uint32_t dim_at(const vTensor& v_in) { api::utils::uvec3 adaptive_work_group_size( const api::utils::uvec3& global_work_group); +template +T extract_scalar(const Value& value) { + if (value.isInt()) { + return static_cast(value.toInt()); + } + if (value.isDouble()) { + return static_cast(value.toDouble()); + } + if (value.isBool()) { + return static_cast(value.toBool()); + } + VK_THROW("Cannot extract scalar from Value with type ", value.type()); +} + } // namespace vulkan } // namespace native } // namespace at diff --git a/backends/vulkan/runtime/graph/ops/OperatorRegistry.h b/backends/vulkan/runtime/graph/ops/OperatorRegistry.h index c11aa0168e1..06245d889e7 100644 --- a/backends/vulkan/runtime/graph/ops/OperatorRegistry.h +++ b/backends/vulkan/runtime/graph/ops/OperatorRegistry.h @@ -19,11 +19,8 @@ namespace at { namespace native { namespace vulkan { -using OpFunction = const std::function&)>; // TODO: Generalize to - // support float, - // int64_t. 
+using OpFunction = + const std::function&)>; bool hasOpsFn(const std::string& name); diff --git a/backends/vulkan/runtime/graph/ops/Utils.h b/backends/vulkan/runtime/graph/ops/Utils.h index 9d6153e1d1d..918318178b3 100644 --- a/backends/vulkan/runtime/graph/ops/Utils.h +++ b/backends/vulkan/runtime/graph/ops/Utils.h @@ -17,7 +17,7 @@ namespace native { namespace vulkan { #define DECLARE_OP_FN(function) \ - ValueRef function(ComputeGraph& graph, const std::vector& args); + void function(ComputeGraph& graph, const std::vector& args); api::utils::ivec4 get_size_as_ivec4(const vTensor& t); diff --git a/backends/vulkan/runtime/graph/ops/impl/Arithmetic.cpp b/backends/vulkan/runtime/graph/ops/impl/Arithmetic.cpp index d635ea9a7f1..f5895c1544e 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Arithmetic.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Arithmetic.cpp @@ -16,42 +16,35 @@ namespace at { namespace native { namespace vulkan { +#define DEFINE_ARITHMETIC_WITH_ALPHA_FN(function, shader) \ + void function(ComputeGraph& graph, const std::vector& args) { \ + return add_arithmetic_node( \ + graph, args[0], args[1], args[2], args[3], VK_KERNEL(shader)); \ + } + #define DEFINE_ARITHMETIC_FN(function, shader) \ - ValueRef function(ComputeGraph& graph, const std::vector& args) { \ + void function(ComputeGraph& graph, const std::vector& args) { \ return add_arithmetic_node( \ - graph, args[0], args[1], args[2], VK_KERNEL(shader), args[3]); \ + graph, args[0], args[1], kDummyValueRef, args[2], VK_KERNEL(shader)); \ } -DEFINE_ARITHMETIC_FN(add, add); -DEFINE_ARITHMETIC_FN(sub, sub); +DEFINE_ARITHMETIC_WITH_ALPHA_FN(add, add); +DEFINE_ARITHMETIC_WITH_ALPHA_FN(sub, sub); + +// Floor div does not have an alpha, but a string argument (which is unused) is +// passed in at the same location as the alpha argument in other op. 
+DEFINE_ARITHMETIC_WITH_ALPHA_FN(floor_div, floor_divide); + DEFINE_ARITHMETIC_FN(mul, mul); DEFINE_ARITHMETIC_FN(div, div); -DEFINE_ARITHMETIC_FN(floor_div, floor_divide); DEFINE_ARITHMETIC_FN(pow, pow); -// TODO(T180908843): Bypass this entrypoint function by creating `ValueRef out` -// ahead of time. -ValueRef add_arithmetic_node( - ComputeGraph& graph, - const ValueRef in1, - const ValueRef in2, - const float alpha, - const api::ShaderInfo& shader, - const int64_t shared_object_idx) { - std::vector in1_sizes = graph.get_val_sizes(in1); - api::ScalarType in1_dtype = graph.get_val_dtype(in1); - - ValueRef out = graph.add_tensor(in1_sizes, in1_dtype, shared_object_idx); - add_arithmetic_node(graph, in1, in2, out, alpha, shader); - return out; -} - void add_arithmetic_node( ComputeGraph& graph, const ValueRef in1, const ValueRef in2, + const ValueRef alpha, const ValueRef out, - const float alpha, const api::ShaderInfo& shader) { ValueRef arg1 = prepack_if_tensor_ref(graph, in1); ValueRef arg2 = prepack_if_tensor_ref(graph, in2); @@ -63,11 +56,18 @@ void add_arithmetic_node( api::utils::uvec3 global_size = t_out.extents(); api::utils::uvec3 local_size = adaptive_work_group_size(global_size); + float alpha_val = 1.0f; + // String is checked since floor_div passes in an unused string argument in + // place of alpha + if (is_valid(alpha) && !graph.get_val(alpha).isString()) { + alpha_val = extract_scalar(graph.get_val(alpha)); + } + ArithmeticParams block{ get_size_as_ivec4(t_out), get_size_as_ivec4(t_in1), get_size_as_ivec4(t_in2), - 1.0, + alpha_val, }; api::UniformParamsBuffer params(graph.context(), block); diff --git a/backends/vulkan/runtime/graph/ops/impl/Arithmetic.h b/backends/vulkan/runtime/graph/ops/impl/Arithmetic.h index 8017f6c4c45..3ef3cb3e426 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Arithmetic.h +++ b/backends/vulkan/runtime/graph/ops/impl/Arithmetic.h @@ -25,20 +25,12 @@ DECLARE_OP_FN(div); DECLARE_OP_FN(floor_div); DECLARE_OP_FN(pow); 
-ValueRef add_arithmetic_node( - ComputeGraph& graph, - const ValueRef in1, - const ValueRef in2, - const float alpha, - const api::ShaderInfo& shader, - const int64_t shared_object_idx = -1); - void add_arithmetic_node( ComputeGraph& graph, const ValueRef in1, const ValueRef in2, + const ValueRef alpha, const ValueRef out, - const float alpha, const api::ShaderInfo& shader); struct ArithmeticParams final { diff --git a/backends/vulkan/serialization/schema.fbs b/backends/vulkan/serialization/schema.fbs index 3d8dab9a2fd..e5139b5fd53 100644 --- a/backends/vulkan/serialization/schema.fbs +++ b/backends/vulkan/serialization/schema.fbs @@ -26,8 +26,55 @@ table VkTensor { mem_obj_id:int; } +table Null {} + +table Int { + int_val:long; +} + +table Bool { + bool_val:bool; +} + +table Double { + double_val:double; +} + +table String { + string_val:string; +} + +table IntList { + items:[long]; +} + +table DoubleList { + items:[double]; +} + +table BoolList { + items:[bool]; +} + +table ValueList { + items:[int]; +} + +union GraphTypes { + Null, + Int, + Double, + Bool, + VkTensor, + IntList, + DoubleList, + BoolList, + ValueList, + String, +} + table VkValue { - value:VkTensor; + value:GraphTypes; } // Abstraction to represent a region of bytes in a raw data buffer. Useful for referencing raw data diff --git a/backends/vulkan/serialization/vulkan_graph_builder.py b/backends/vulkan/serialization/vulkan_graph_builder.py index 68e54c2bc3b..572ef018bc2 100644 --- a/backends/vulkan/serialization/vulkan_graph_builder.py +++ b/backends/vulkan/serialization/vulkan_graph_builder.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-from typing import Optional +from typing import Optional, Union import executorch.backends.vulkan.serialization.vulkan_graph_schema as vk_graph_schema @@ -15,6 +15,9 @@ from torch.export import ExportedProgram from torch.fx import Node +_ScalarType = Union[int, bool, float] +_Argument = Union[torch.fx.Node, int, bool, float, str] + class VkGraphBuilder: def __init__(self, program: ExportedProgram) -> None: @@ -106,7 +109,7 @@ def maybe_add_constant_tensor(self, node: Node) -> int: return const_buffer_idx - def create_single_vk_value(self, node: Node) -> int: + def create_single_tensor_value(self, node: Node) -> int: constant_id = self.maybe_add_constant_tensor(node) spec = node.meta.get("spec") @@ -138,17 +141,48 @@ def create_single_vk_value(self, node: Node) -> int: ) return new_id - def create_vk_values_for(self, node: Node): + def create_tensor_values(self, node: Node) -> int: spec = node.meta.get("spec") if isinstance(spec, TensorSpec): - return self.create_single_vk_value(node) + return self.create_single_tensor_value(node) else: raise RuntimeError( "Creating values for nodes with collection types is not supported yet." 
) + def create_scalar_value(self, scalar: _ScalarType) -> int: + new_id = len(self.values) + if isinstance(scalar, int): + self.values.append(vk_graph_schema.VkValue(vk_graph_schema.Int(scalar))) + if isinstance(scalar, float): + self.values.append(vk_graph_schema.VkValue(vk_graph_schema.Double(scalar))) + if isinstance(scalar, bool): + self.values.append(vk_graph_schema.VkValue(vk_graph_schema.Bool(scalar))) + return new_id + + def create_string_value(self, string: str) -> int: + new_id = len(self.values) + self.values.append( + vk_graph_schema.VkValue(vk_graph_schema.String(string_val=string)) + ) + return new_id + + def get_or_create_value_for(self, arg: _Argument): + if isinstance(arg, torch.fx.Node): + # If the value has already been created, return the existing id + if arg in self.node_to_value_ids: + return self.node_to_value_ids[arg] + # Return id for a newly created value + return self.create_tensor_values(arg) + elif isinstance(arg, (int, float, bool)): + return self.create_scalar_value(arg) + elif isinstance(arg, str): + return self.create_string_value(arg) + else: + raise RuntimeError(f"Cannot create value for arg of type {type(arg)}") + def process_placeholder_node(self, node: Node) -> None: - ids = self.create_vk_values_for(node) + ids = self.create_tensor_values(node) if not self.is_param_node(node): if isinstance(ids, int): self.input_ids.append(ids) @@ -156,27 +190,32 @@ def process_placeholder_node(self, node: Node) -> None: self.input_ids += ids def process_call_function_node(self, node) -> None: - args = [] - # Add input nodes - for inp_node in node.all_input_nodes: - if inp_node not in self.node_to_value_ids: - raise AssertionError( - "Cannot find input to current node in node_to_value_ids. This means " - "this node is being serialized before its input which is not allowed." 
- ) - args.append(self.node_to_value_ids[inp_node]) + operator_call_args = [] + + for i, schema_arg in enumerate(node.target._schema.arguments): + if not schema_arg.kwarg_only and i < len(node.args): + function_arg = node.args[i] + elif schema_arg.name in node.kwargs: + function_arg = node.kwargs[schema_arg.name] + else: + function_arg = schema_arg.default_value + + # Create a value for each function argument. If the argument has been + # previously encountered, then use the existing value id. + operator_call_args.append(self.get_or_create_value_for(function_arg)) + # Add output node - args.append(self.create_vk_values_for(node)) + operator_call_args.append(self.create_tensor_values(node)) self.chain.append( vk_graph_schema.OperatorCall( name=node.target.__name__, - args=args, + args=operator_call_args, ), ) def process_getattr_node(self, node: Node) -> None: - self.create_vk_values_for(node) + self.create_tensor_values(node) def process_output_node(self, node: Node) -> None: if node.all_input_nodes[0] not in self.node_to_value_ids: diff --git a/backends/vulkan/serialization/vulkan_graph_schema.py b/backends/vulkan/serialization/vulkan_graph_schema.py index eeb1589a2a4..1c5a05727b0 100644 --- a/backends/vulkan/serialization/vulkan_graph_schema.py +++ b/backends/vulkan/serialization/vulkan_graph_schema.py @@ -12,7 +12,7 @@ from dataclasses import dataclass from enum import IntEnum -from typing import List +from typing import List, Union @dataclass @@ -34,13 +34,67 @@ class VkTensor: @dataclass -class VkScalar: +class Null: pass +@dataclass +class Int: + int_val: int + + +@dataclass +class Bool: + bool_val: bool + + +@dataclass +class Double: + double_val: float + + +@dataclass +class IntList: + items: List[int] + + +@dataclass +class DoubleList: + items: List[float] + + +@dataclass +class BoolList: + items: List[bool] + + +@dataclass +class ValueList: + items: List[int] + + +@dataclass +class String: + string_val: str + + +GraphTypes = Union[ + Null, + Int, + 
Double, + Bool, + VkTensor, + IntList, + BoolList, + DoubleList, + ValueList, + String, +] + + @dataclass class VkValue: - value: VkTensor + value: "GraphTypes" @dataclass diff --git a/backends/vulkan/test/vulkan_compute_api_test.cpp b/backends/vulkan/test/vulkan_compute_api_test.cpp index 51b58720c3f..c53444ff0bd 100644 --- a/backends/vulkan/test/vulkan_compute_api_test.cpp +++ b/backends/vulkan/test/vulkan_compute_api_test.cpp @@ -427,6 +427,60 @@ TEST_F(VulkanComputeAPITest, use_non_bound_textures_fails) { graph.get_val(name.value).toTensor().gpu_numel()); \ graph.copy_from_staging(name.staging, data_##name.data(), data_##name.size()); +TEST(VulkanComputeGraphTest, test_values_scalars) { + GraphConfig config = generate_graph_config(); + ComputeGraph graph(config); + + ValueRef idx; + + idx = graph.add_scalar(4); + EXPECT_TRUE(graph.get_val(idx).toInt() == 4); + + idx = graph.add_scalar(5.5f); + EXPECT_TRUE(graph.get_val(idx).toDouble() == 5.5f); +} + +TEST(VulkanComputeGraphTest, test_values_scalar_list_inplace_constructed) { + GraphConfig config = generate_graph_config(); + ComputeGraph graph(config); + + ValueRef idx = graph.add_scalar_list({1, 2, 3, 4}); + auto& arr = graph.get_val(idx).toIntList(); // reference to the list stored in the graph + EXPECT_TRUE(arr.size() == 4); + for (int i = 0; i < 4; i++) { + EXPECT_TRUE(arr[i] == i + 1); + } +} + +TEST(VulkanComputeGraphTest, test_values_scalar_list_outside_constructed) { + GraphConfig config = generate_graph_config(); + ComputeGraph graph(config); + + ValueRef idx; + { + std::vector data = {5.0, 4.0, 3.0, 2.0, 1.0}; + idx = graph.add_scalar_list(std::move(data)); + } + auto& arr = graph.get_val(idx).toDoubleList(); // read back after `data` has gone out of scope + EXPECT_TRUE(arr.size() == 5); + for (int i = 0; i < 5; i++) { + EXPECT_TRUE(arr[i] == (5 - i)); + } +} + +TEST(VulkanComputeGraphTest, test_values_string) { + GraphConfig config = generate_graph_config(); + ComputeGraph graph(config); + + ValueRef idx; + { + std::string data = "hello, world"; + idx = 
graph.add_string(std::move(data)); + } + std::string& stored = graph.get_val(idx).toString(); + EXPECT_TRUE(stored == "hello, world"); +} + TEST(VulkanComputeGraphTest, test_simple_graph) { GraphConfig config = generate_graph_config(); ComputeGraph graph(config); @@ -441,7 +495,10 @@ TEST(VulkanComputeGraphTest, test_simple_graph) { IOValueRef out = {}; - out.value = add_arithmetic_node(graph, a.value, b.value, 1.0, VK_KERNEL(add)); + out.value = graph.add_tensor(size_big, api::kFloat); + + add_arithmetic_node( + graph, a.value, b.value, kDummyValueRef, out.value, VK_KERNEL(add)); out.staging = graph.set_output_tensor(out.value); @@ -487,8 +544,11 @@ TEST(VulkanComputeGraphTest, test_simple_prepacked_graph) { IOValueRef a = graph.add_input_tensor(size_big, api::kFloat); - ValueRef c = add_arithmetic_node(graph, a.value, w1, 1.0, VK_KERNEL(add)); - ValueRef e = add_arithmetic_node(graph, c, w2, 1.0, VK_KERNEL(mul)); + ValueRef c = graph.add_tensor(size_big, api::kFloat); + ValueRef e = graph.add_tensor(size_big, api::kFloat); + + add_arithmetic_node(graph, a.value, w1, kDummyValueRef, c, VK_KERNEL(add)); + add_arithmetic_node(graph, c, w2, kDummyValueRef, e, VK_KERNEL(mul)); IOValueRef out = {}; out.value = e; @@ -541,14 +601,14 @@ TEST(VulkanComputeGraphTest, test_simple_shared_objects) { // 1 staging buffer for each input tensor EXPECT_TRUE(get_vma_allocation_count() == 4); - ValueRef c = add_arithmetic_node( - graph, - a.value, - b.value, - 1.0, - VK_KERNEL(add), + ValueRef c = graph.add_tensor( + size_big, + api::kFloat, /*shared_object_idx = */ 6); + add_arithmetic_node( + graph, a.value, b.value, kDummyValueRef, c, VK_KERNEL(add)); + IOValueRef d = graph.add_input_tensor( size_small, api::kFloat, @@ -560,14 +620,13 @@ TEST(VulkanComputeGraphTest, test_simple_shared_objects) { // 1 staging buffer for the input tensor EXPECT_TRUE(get_vma_allocation_count() == 7); - ValueRef e = add_arithmetic_node( - graph, - c, - d.value, - 1.0, - VK_KERNEL(mul), + ValueRef e = 
graph.add_tensor( + size_big, + api::kFloat, /*shared_object_idx = */ 4); + add_arithmetic_node(graph, c, d.value, kDummyValueRef, e, VK_KERNEL(mul)); + IOValueRef out = {}; out.value = e; out.staging = graph.set_output_tensor(out.value);