Skip to content

Commit 7f04bde

Browse files
committed
[ET-VK] Adding PushConstantDataInfo input to PrepackNode class.
Pull Request resolved: #8649 This diff adds a new input to the PrepackNode class called PushConstantDataInfo. This input is used to pass push constant data to the shader. ghstack-source-id: 268084835 @exported-using-ghexport Differential Revision: [D70102043](https://our.internmc.facebook.com/intern/diff/D70102043/)
1 parent 5c13a90 commit 7f04bde

File tree

2 files changed

+25
-5
lines changed

2 files changed

+25
-5
lines changed

backends/vulkan/runtime/graph/ops/PrepackNode.cpp

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,15 +32,17 @@ PrepackNode::PrepackNode(
3232
const ValueRef tref,
3333
const ValueRef packed,
3434
const vkapi::ParamsBindList& params,
35-
const vkapi::SpecVarList& spec_vars)
35+
const vkapi::SpecVarList& spec_vars,
36+
const std::vector<PushConstantDataInfo>& push_constants)
3637
: shader_(shader),
3738
noop_shader_(get_noop_shader(graph, packed)),
3839
global_workgroup_size_(global_workgroup_size),
3940
local_workgroup_size_(local_workgroup_size),
4041
tref_(tref),
4142
packed_(packed),
4243
params_(params),
43-
spec_vars_(spec_vars) {
44+
spec_vars_(spec_vars),
45+
push_constants_(push_constants) {
4446
graph.update_descriptor_counts(shader, /*execute = */ false);
4547
graph.update_descriptor_counts(noop_shader_, /*execute = */ false);
4648
}
@@ -75,10 +77,20 @@ void PrepackNode::encode(ComputeGraph* graph) {
7577

7678
std::unique_lock<std::mutex> cmd_lock = context->dispatch_lock();
7779

80+
std::array<uint8_t, kMaxPushConstantSize> push_constants_data;
81+
uint32_t push_constants_offset = 0;
82+
83+
for (const auto& push_constant : push_constants_) {
84+
push_constants_offset += push_constant.write(
85+
push_constants_data.data(),
86+
push_constants_offset,
87+
kMaxPushConstantSize);
88+
}
89+
7890
{
7991
vkapi::PipelineBarrier pipeline_barrier{};
8092
vkapi::DescriptorSet descriptor_set = context->get_descriptor_set(
81-
shader_, local_workgroup_size_, spec_vars_, 0u);
93+
shader_, local_workgroup_size_, spec_vars_, push_constants_offset);
8294

8395
uint32_t idx = 0;
8496
bind_tensor_to_descriptor_set(
@@ -91,7 +103,12 @@ void PrepackNode::encode(ComputeGraph* graph) {
91103
bind_params_to_descriptor_set(params_, descriptor_set, idx);
92104

93105
context->register_shader_dispatch(
94-
descriptor_set, pipeline_barrier, shader_, global_workgroup_size_);
106+
descriptor_set,
107+
pipeline_barrier,
108+
shader_,
109+
global_workgroup_size_,
110+
push_constants_data.data(),
111+
push_constants_offset);
95112
}
96113

97114
// Submit a compute shader that performs a no-op with the packed tensor in

backends/vulkan/runtime/graph/ops/PrepackNode.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#include <executorch/backends/vulkan/runtime/api/api.h>
1212

13+
#include <executorch/backends/vulkan/runtime/graph/containers/PushConstantData.h>
1314
#include <executorch/backends/vulkan/runtime/graph/containers/Value.h>
1415

1516
namespace vkcompute {
@@ -34,7 +35,8 @@ class PrepackNode final {
3435
const ValueRef tref,
3536
const ValueRef packed,
3637
const vkapi::ParamsBindList& params,
37-
const vkapi::SpecVarList& spec_vars = {});
38+
const vkapi::SpecVarList& spec_vars = {},
39+
const std::vector<PushConstantDataInfo>& push_constants = {});
3840

3941
~PrepackNode() = default;
4042

@@ -54,6 +56,7 @@ class PrepackNode final {
5456
const ValueRef packed_;
5557
const vkapi::ParamsBindList params_;
5658
const vkapi::SpecVarList spec_vars_;
59+
const std::vector<PushConstantDataInfo> push_constants_;
5760

5861
private:
5962
api::StagingBuffer create_staging_buffer(ComputeGraph* graph);

0 commit comments

Comments
 (0)