diff --git a/backends/vulkan/runtime/graph/containers/PushConstantData.cpp b/backends/vulkan/runtime/graph/containers/PushConstantData.cpp new file mode 100644 index 00000000000..7999118443b --- /dev/null +++ b/backends/vulkan/runtime/graph/containers/PushConstantData.cpp @@ -0,0 +1,29 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +namespace vkcompute { + +uint32_t PushConstantDataInfo::write( + void* dst, + const uint32_t dst_offset, + const uint32_t max_dst_size) const { + if (tensorUniformData != nullptr) { + return tensorUniformData->write_attribute( + dst, dst_offset, max_dst_size, payload_.attr); + } + + VK_CHECK_COND( + (dst_offset + payload_.dataSize) <= max_dst_size, + "Attempting to write push constant data outside data boundary."); + memcpy((uint8_t*)dst + dst_offset, payload_.data, payload_.dataSize); + return payload_.dataSize; +} + +} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/containers/PushConstantData.h b/backends/vulkan/runtime/graph/containers/PushConstantData.h new file mode 100644 index 00000000000..39cde4722a7 --- /dev/null +++ b/backends/vulkan/runtime/graph/containers/PushConstantData.h @@ -0,0 +1,65 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +namespace vkcompute { + +class ComputeGraph; + +constexpr uint32_t kMaxPushConstantSize = 128; +/* + * Represents a push constant data entry + * Which is either shared pointer to a tensor's uniform data with an attribute + * Or data with a maximum size of 16 bytes + */ +class PushConstantDataInfo { + std::shared_ptr tensorUniformData; + union Payload { + struct { + api::vTensor::Attribute attr; + }; + struct { + uint8_t data[16]; + uint32_t dataSize; + }; + }; + + Payload payload_; + + public: + explicit PushConstantDataInfo( + const std::shared_ptr& tensorUniformData, + api::vTensor::Attribute attr) + : tensorUniformData(tensorUniformData) { + payload_.attr = attr; + } + + explicit PushConstantDataInfo( + const void* data, + uint32_t dataLen, + uint32_t pushConstantLen = 0) + : tensorUniformData(nullptr) { + VK_CHECK_COND( + dataLen <= 16, "Single push constant data size must be <= 16 bytes"); + payload_.dataSize = pushConstantLen ? pushConstantLen : dataLen; + memcpy(payload_.data, data, dataLen); + } + + /* + * Function writes push constant data to the destination buffer + */ + uint32_t write( + void* dst, + const uint32_t dst_offset, + const uint32_t max_dst_size) const; +}; + +} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/DispatchNode.cpp b/backends/vulkan/runtime/graph/ops/DispatchNode.cpp index 63b8798f2c1..6730d851483 100644 --- a/backends/vulkan/runtime/graph/ops/DispatchNode.cpp +++ b/backends/vulkan/runtime/graph/ops/DispatchNode.cpp @@ -14,22 +14,6 @@ namespace vkcompute { -uint32_t PushConstantDataInfo::write( - void* dst, - const uint32_t dst_offset, - const uint32_t max_dst_size) const { - if (tensorUniformData != nullptr) { - return tensorUniformData->write_attribute( - dst, dst_offset, max_dst_size, payload_.attr); - } - - VK_CHECK_COND( - (dst_offset + payload_.dataSize) <= max_dst_size, - "Attempting to write push constant data outside data boundary."); - memcpy((uint8_t*)dst + dst_offset, payload_.data, payload_.dataSize); - return payload_.dataSize; -} - DispatchNode::DispatchNode( ComputeGraph& graph, const vkapi::ShaderInfo& shader, diff --git a/backends/vulkan/runtime/graph/ops/DispatchNode.h b/backends/vulkan/runtime/graph/ops/DispatchNode.h index 4661b5bf9cf..e3794e9a9e4 100644 --- a/backends/vulkan/runtime/graph/ops/DispatchNode.h +++ b/backends/vulkan/runtime/graph/ops/DispatchNode.h @@ -10,6 +10,7 @@ #include +#include #include #include @@ -18,54 +19,6 @@ namespace vkcompute { class ComputeGraph; -constexpr uint32_t kMaxPushConstantSize = 128; -/* - * Represents a push constant data entry - * Which is either shared pointer to a tensor's uniform data with an attribute - * Or data with a maximum size of 16 bytes - */ -class PushConstantDataInfo { - std::shared_ptr tensorUniformData; - union Payload { - struct { - api::vTensor::Attribute attr; - }; - struct { - uint8_t data[16]; - uint32_t dataSize; - }; - }; - - Payload payload_; - - public: - explicit PushConstantDataInfo( - const std::shared_ptr& tensorUniformData, - api::vTensor::Attribute attr) - : tensorUniformData(tensorUniformData) { - payload_.attr = attr; - } - - explicit PushConstantDataInfo( - const void* data, - uint32_t dataLen, - uint32_t pushConstantLen = 0) - : tensorUniformData(nullptr) { - VK_CHECK_COND( - dataLen <= 16, "Single push constant data size must be <= 16 bytes"); - payload_.dataSize = pushConstantLen ? pushConstantLen : dataLen; - memcpy(payload_.data, data, dataLen); - } - - /* - * Function writes push constant data to the destination buffer - */ - uint32_t write( - void* dst, - const uint32_t dst_offset, - const uint32_t max_dst_size) const; -}; - /* * Represents a single shader execution op in a ML model. */ diff --git a/backends/vulkan/runtime/graph/ops/PrepackNode.cpp b/backends/vulkan/runtime/graph/ops/PrepackNode.cpp index 0507b679e13..d84d893540c 100644 --- a/backends/vulkan/runtime/graph/ops/PrepackNode.cpp +++ b/backends/vulkan/runtime/graph/ops/PrepackNode.cpp @@ -32,7 +32,8 @@ PrepackNode::PrepackNode( const ValueRef tref, const ValueRef packed, const vkapi::ParamsBindList& params, - const vkapi::SpecVarList& spec_vars) + const vkapi::SpecVarList& spec_vars, + const std::vector& push_constants) : shader_(shader), noop_shader_(get_noop_shader(graph, packed)), global_workgroup_size_(global_workgroup_size), @@ -40,7 +41,8 @@ PrepackNode::PrepackNode( tref_(tref), packed_(packed), params_(params), - spec_vars_(spec_vars) { + spec_vars_(spec_vars), + push_constants_(push_constants) { graph.update_descriptor_counts(shader, /*execute = */ false); graph.update_descriptor_counts(noop_shader_, /*execute = */ false); } @@ -75,10 +77,20 @@ void PrepackNode::encode(ComputeGraph* graph) { std::unique_lock cmd_lock = context->dispatch_lock(); + std::array push_constants_data; + uint32_t push_constants_offset = 0; + + for (const auto& push_constant : push_constants_) { + push_constants_offset += push_constant.write( + push_constants_data.data(), + push_constants_offset, + kMaxPushConstantSize); + } + { vkapi::PipelineBarrier pipeline_barrier{}; vkapi::DescriptorSet descriptor_set = context->get_descriptor_set( - shader_, local_workgroup_size_, spec_vars_, 0u); + shader_, local_workgroup_size_, spec_vars_, push_constants_offset); uint32_t idx = 0; bind_tensor_to_descriptor_set( @@ -91,7 +103,12 @@ void PrepackNode::encode(ComputeGraph* graph) { bind_params_to_descriptor_set(params_, descriptor_set, idx); context->register_shader_dispatch( - descriptor_set, pipeline_barrier, shader_, global_workgroup_size_); + descriptor_set, + pipeline_barrier, + shader_, + global_workgroup_size_, + push_constants_data.data(), + push_constants_offset); } // Submit a compute shader that performs a no-op with the packed tensor in diff --git a/backends/vulkan/runtime/graph/ops/PrepackNode.h b/backends/vulkan/runtime/graph/ops/PrepackNode.h index 2d194e7f6a0..a45deb9ff70 100644 --- a/backends/vulkan/runtime/graph/ops/PrepackNode.h +++ b/backends/vulkan/runtime/graph/ops/PrepackNode.h @@ -10,6 +10,7 @@ #include +#include #include namespace vkcompute { @@ -34,7 +35,8 @@ class PrepackNode final { const ValueRef tref, const ValueRef packed, const vkapi::ParamsBindList& params, - const vkapi::SpecVarList& spec_vars = {}); + const vkapi::SpecVarList& spec_vars = {}, + const std::vector& push_constants = {}); ~PrepackNode() = default; @@ -54,6 +56,7 @@ class PrepackNode final { const ValueRef packed_; const vkapi::ParamsBindList params_; const vkapi::SpecVarList spec_vars_; + const std::vector push_constants_; private: api::StagingBuffer create_staging_buffer(ComputeGraph* graph);