diff --git a/backends/vulkan/runtime/graph/ops/PrepackNode.cpp b/backends/vulkan/runtime/graph/ops/PrepackNode.cpp index 0507b679e13..d84d893540c 100644 --- a/backends/vulkan/runtime/graph/ops/PrepackNode.cpp +++ b/backends/vulkan/runtime/graph/ops/PrepackNode.cpp @@ -32,7 +32,8 @@ PrepackNode::PrepackNode( const ValueRef tref, const ValueRef packed, const vkapi::ParamsBindList& params, - const vkapi::SpecVarList& spec_vars) + const vkapi::SpecVarList& spec_vars, + const std::vector& push_constants) : shader_(shader), noop_shader_(get_noop_shader(graph, packed)), global_workgroup_size_(global_workgroup_size), @@ -40,7 +41,8 @@ PrepackNode::PrepackNode( tref_(tref), packed_(packed), params_(params), - spec_vars_(spec_vars) { + spec_vars_(spec_vars), + push_constants_(push_constants) { graph.update_descriptor_counts(shader, /*execute = */ false); graph.update_descriptor_counts(noop_shader_, /*execute = */ false); } @@ -75,10 +77,20 @@ void PrepackNode::encode(ComputeGraph* graph) { std::unique_lock cmd_lock = context->dispatch_lock(); + std::array push_constants_data; + uint32_t push_constants_offset = 0; + + for (const auto& push_constant : push_constants_) { + push_constants_offset += push_constant.write( + push_constants_data.data(), + push_constants_offset, + kMaxPushConstantSize); + } + { vkapi::PipelineBarrier pipeline_barrier{}; vkapi::DescriptorSet descriptor_set = context->get_descriptor_set( - shader_, local_workgroup_size_, spec_vars_, 0u); + shader_, local_workgroup_size_, spec_vars_, push_constants_offset); uint32_t idx = 0; bind_tensor_to_descriptor_set( @@ -91,7 +103,12 @@ void PrepackNode::encode(ComputeGraph* graph) { bind_params_to_descriptor_set(params_, descriptor_set, idx); context->register_shader_dispatch( - descriptor_set, pipeline_barrier, shader_, global_workgroup_size_); + descriptor_set, + pipeline_barrier, + shader_, + global_workgroup_size_, + push_constants_data.data(), + push_constants_offset); } // Submit a compute shader that performs a no-op with the packed tensor in diff --git a/backends/vulkan/runtime/graph/ops/PrepackNode.h b/backends/vulkan/runtime/graph/ops/PrepackNode.h index 2d194e7f6a0..a45deb9ff70 100644 --- a/backends/vulkan/runtime/graph/ops/PrepackNode.h +++ b/backends/vulkan/runtime/graph/ops/PrepackNode.h @@ -10,6 +10,7 @@ #include +#include #include namespace vkcompute { @@ -34,7 +35,8 @@ class PrepackNode final { const ValueRef tref, const ValueRef packed, const vkapi::ParamsBindList& params, - const vkapi::SpecVarList& spec_vars = {}); + const vkapi::SpecVarList& spec_vars = {}, + const std::vector& push_constants = {}); ~PrepackNode() = default; @@ -54,6 +56,7 @@ class PrepackNode final { const ValueRef packed_; const vkapi::ParamsBindList params_; const vkapi::SpecVarList spec_vars_; + const std::vector push_constants_; private: api::StagingBuffer create_staging_buffer(ComputeGraph* graph);