Skip to content

Commit d3e43de

Browse files
committed
[ET-VK] Support multiple UniformParamsBuffer
Before: Each node contains a `UniformParamsBuffer`. After: Each node contains a `std::vector<std::shared_ptr<UniformParamsBuffer>>`. In follow up changes, we will break up parameters to be passed via multiple UniformParamsBuffer, since 1. some are tensor-specific (e.g. image extents) and 2. others are operator-specific (e.g. alpha for binary ops). Hence, we need **`std::vector`**. We are adding the methods for #1 in #2340. Since #1 and #2 will be owned by different objects, we need **pointers**. Since #1 is owned by `vTensor` which is non-copyable, we can't use unique_ptr so we need **`std::shared_ptr`**. Differential Revision: [D54691831](https://our.internmc.facebook.com/intern/diff/D54691831/) [ghstack-poisoned]
1 parent 2a96ea8 commit d3e43de

File tree

9 files changed

+62
-34
lines changed

9 files changed

+62
-34
lines changed

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,12 @@ class ComputeGraph final {
154154
ValueRef set_input_tensor(const ValueRef idx, const bool use_staging = true);
155155
ValueRef set_output_tensor(const ValueRef idx, const bool use_staging = true);
156156

157+
template <typename Block>
158+
inline std::shared_ptr<api::UniformParamsBuffer> create_params_buffer(
159+
const Block& data) {
160+
return std::make_shared<api::UniformParamsBuffer>(context_.get(), data);
161+
}
162+
157163
/*
158164
* Convenience function to add an input tensor along with its staging buffer
159165
*/

backends/vulkan/runtime/graph/ops/ExecuteNode.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,12 @@ ExecuteNode::ExecuteNode(
2222
const api::utils::uvec3& global_workgroup_size,
2323
const api::utils::uvec3& local_workgroup_size,
2424
const std::vector<ArgGroup>& args,
25-
api::UniformParamsBuffer&& params)
25+
const std::vector<std::shared_ptr<api::UniformParamsBuffer>>& params)
2626
: shader_(shader),
2727
global_workgroup_size_(global_workgroup_size),
2828
local_workgroup_size_(local_workgroup_size),
2929
args_(args),
30-
params_(std::move(params)) {
30+
params_(params) {
3131
graph.update_descriptor_counts(shader, /*execute = */ true);
3232
}
3333

@@ -43,7 +43,7 @@ void ExecuteNode::encode(ComputeGraph* graph) {
4343
uint32_t idx = 0;
4444
idx = bind_values_to_descriptor_set(
4545
graph, args_, pipeline_barrier, descriptor_set, idx);
46-
descriptor_set.bind(idx, params_.buffer());
46+
bind_params_to_descriptor_set(params_, descriptor_set, idx);
4747

4848
context->register_shader_dispatch(
4949
descriptor_set, pipeline_barrier, shader_, global_workgroup_size_);

backends/vulkan/runtime/graph/ops/ExecuteNode.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ class ExecuteNode final {
5353
const api::utils::uvec3& global_workgroup_size,
5454
const api::utils::uvec3& local_workgroup_size,
5555
const std::vector<ArgGroup>& args,
56-
api::UniformParamsBuffer&& params);
56+
const std::vector<std::shared_ptr<api::UniformParamsBuffer>>& params);
5757

5858
~ExecuteNode() = default;
5959

@@ -64,9 +64,8 @@ class ExecuteNode final {
6464
const api::utils::uvec3 global_workgroup_size_;
6565
const api::utils::uvec3 local_workgroup_size_;
6666
const std::vector<ArgGroup> args_;
67-
// TODO(T180906086): pass multiple buffers and index with ValueRef.
6867
// TODO(T180906457): allow re-computing param buffers.
69-
api::UniformParamsBuffer params_;
68+
std::vector<std::shared_ptr<api::UniformParamsBuffer>> params_;
7069
};
7170

7271
} // namespace vulkan

backends/vulkan/runtime/graph/ops/PrepackNode.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,13 @@ PrepackNode::PrepackNode(
2424
const api::utils::uvec3& local_workgroup_size,
2525
const ValueRef tref,
2626
const ValueRef packed,
27-
api::UniformParamsBuffer&& params)
27+
const std::vector<std::shared_ptr<api::UniformParamsBuffer>>& params)
2828
: shader_(shader),
2929
global_workgroup_size_(global_workgroup_size),
3030
local_workgroup_size_(local_workgroup_size),
3131
tref_(tref),
3232
packed_(packed),
33-
params_(std::move(params)) {
33+
params_(params) {
3434
graph.update_descriptor_counts(shader, /*execute = */ false);
3535
}
3636

@@ -61,7 +61,7 @@ void PrepackNode::encode(ComputeGraph* graph) {
6161
descriptor_set,
6262
idx++);
6363
bind_staging_to_descriptor_set(staging, descriptor_set, idx++);
64-
descriptor_set.bind(idx, params_.buffer());
64+
bind_params_to_descriptor_set(params_, descriptor_set, idx);
6565

6666
context->register_shader_dispatch(
6767
descriptor_set, pipeline_barrier, shader_, global_workgroup_size_);

backends/vulkan/runtime/graph/ops/PrepackNode.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class PrepackNode final {
3737
const api::utils::uvec3& local_workgroup_size,
3838
const ValueRef tref,
3939
const ValueRef packed,
40-
api::UniformParamsBuffer&& params);
40+
const std::vector<std::shared_ptr<api::UniformParamsBuffer>>& params);
4141

4242
~PrepackNode() = default;
4343

@@ -49,9 +49,8 @@ class PrepackNode final {
4949
const api::utils::uvec3 local_workgroup_size_;
5050
const ValueRef tref_;
5151
const ValueRef packed_;
52-
// TODO(T180906086): pass multiple buffers and index with ValueRef.
5352
// TODO(T180906457): allow re-computing param buffers.
54-
api::UniformParamsBuffer params_;
53+
std::vector<std::shared_ptr<api::UniformParamsBuffer>> params_;
5554
};
5655

5756
} // namespace vulkan

backends/vulkan/runtime/graph/ops/impl/Arithmetic.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,6 @@ void add_arithmetic_node(
7272
get_size_as_ivec4(t_in2),
7373
alpha_val,
7474
};
75-
api::UniformParamsBuffer params(graph.context(), block);
7675

7776
graph.execute_nodes().emplace_back(new ExecuteNode(
7877
graph,
@@ -81,7 +80,7 @@ void add_arithmetic_node(
8180
local_size,
8281
{{out, api::MemoryAccessType::WRITE},
8382
{{arg1, arg2}, api::MemoryAccessType::READ}},
84-
std::move(params)));
83+
{graph.create_params_buffer(block)}));
8584
}
8685

8786
REGISTER_OPERATORS {

backends/vulkan/runtime/graph/ops/impl/Staging.cpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -45,17 +45,14 @@ void add_staging_to_tensor_node(
4545
api::utils::uvec3 global_size = t_out.extents();
4646
api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
4747

48-
api::UniformParamsBuffer params(
49-
graph.context(), create_staging_params(t_out));
50-
5148
graph.execute_nodes().emplace_back(new ExecuteNode(
5249
graph,
5350
shader,
5451
global_size,
5552
local_size,
5653
{{out_tensor, api::MemoryAccessType::WRITE},
5754
{in_staging, api::MemoryAccessType::READ}},
58-
std::move(params)));
55+
{graph.create_params_buffer(create_staging_params(t_out))}));
5956
}
6057

6158
void add_tensor_to_staging_node(
@@ -71,7 +68,6 @@ void add_tensor_to_staging_node(
7168
api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
7269

7370
StagingParams sp = create_staging_params(t_in);
74-
api::UniformParamsBuffer params(graph.context(), sp);
7571

7672
// TODO(T181194784): These are workgroup sizes for special cases. Refactor the
7773
// calculation of workgroup sizes to a standalone function. We should use
@@ -98,7 +94,7 @@ void add_tensor_to_staging_node(
9894
local_size,
9995
{{in_tensor, api::MemoryAccessType::READ},
10096
{out_staging, api::MemoryAccessType::WRITE}},
101-
std::move(params)));
97+
{graph.create_params_buffer(sp)}));
10298
}
10399

104100
ValueRef prepack(ComputeGraph& graph, const ValueRef vref) {
@@ -112,10 +108,15 @@ ValueRef prepack(ComputeGraph& graph, const ValueRef vref) {
112108
api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
113109

114110
StagingParams sp = create_staging_params(t);
115-
api::UniformParamsBuffer params(graph.context(), sp);
116111

117112
graph.prepack_nodes().emplace_back(new PrepackNode(
118-
graph, shader, global_size, local_size, vref, v, std::move(params)));
113+
graph,
114+
shader,
115+
global_size,
116+
local_size,
117+
vref,
118+
v,
119+
{graph.create_params_buffer(sp)}));
119120

120121
return v;
121122
}

backends/vulkan/runtime/graph/ops/utils/BindingUtils.cpp

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,6 @@ void bind_tensor_to_descriptor_set(
2929
}
3030
}
3131

32-
void bind_staging_to_descriptor_set(
33-
api::StorageBuffer& staging,
34-
api::DescriptorSet& descriptor_set,
35-
const uint32_t idx) {
36-
descriptor_set.bind(idx, staging.buffer());
37-
}
38-
3932
uint32_t bind_values_to_descriptor_set(
4033
ComputeGraph* graph,
4134
const std::vector<ArgGroup>& args,
@@ -63,6 +56,24 @@ uint32_t bind_values_to_descriptor_set(
6356
return idx;
6457
}
6558

59+
uint32_t bind_params_to_descriptor_set(
60+
std::vector<std::shared_ptr<api::UniformParamsBuffer>>& params,
61+
api::DescriptorSet& descriptor_set,
62+
const uint32_t base_idx) {
63+
uint32_t idx = base_idx;
64+
for (auto& param : params) {
65+
descriptor_set.bind(idx++, param->buffer());
66+
}
67+
return idx;
68+
}
69+
70+
void bind_staging_to_descriptor_set(
71+
api::StorageBuffer& staging,
72+
api::DescriptorSet& descriptor_set,
73+
const uint32_t idx) {
74+
descriptor_set.bind(idx, staging.buffer());
75+
}
76+
6677
} // namespace vulkan
6778
} // namespace native
6879
} // namespace at

backends/vulkan/runtime/graph/ops/utils/BindingUtils.h

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,25 +16,38 @@ namespace at {
1616
namespace native {
1717
namespace vulkan {
1818

19+
//
20+
// For objects in the graph
21+
//
22+
1923
void bind_tensor_to_descriptor_set(
2024
vTensor& tensor,
2125
api::PipelineBarrier& pipeline_barrier,
2226
const api::MemoryAccessType accessType,
2327
api::DescriptorSet& descriptor_set,
2428
const uint32_t idx);
2529

26-
void bind_staging_to_descriptor_set(
27-
api::StorageBuffer& staging,
28-
api::DescriptorSet& descriptor_set,
29-
const uint32_t idx);
30-
3130
uint32_t bind_values_to_descriptor_set(
3231
ComputeGraph* graph,
3332
const std::vector<ArgGroup>& args,
3433
api::PipelineBarrier& pipeline_barrier,
3534
api::DescriptorSet& descriptor_set,
3635
const uint32_t base_idx);
3736

37+
//
38+
// For objects NOT in the graph
39+
//
40+
41+
uint32_t bind_params_to_descriptor_set(
42+
std::vector<std::shared_ptr<api::UniformParamsBuffer>>& params,
43+
api::DescriptorSet& descriptor_set,
44+
const uint32_t base_idx);
45+
46+
void bind_staging_to_descriptor_set(
47+
api::StorageBuffer& staging,
48+
api::DescriptorSet& descriptor_set,
49+
const uint32_t idx);
50+
3851
} // namespace vulkan
3952
} // namespace native
4053
} // namespace at

0 commit comments

Comments
 (0)