Skip to content

[ET-VK] Replacing the use of uvec3 with WorkgroupSize class to reduce memory usage and improve processing speed #8671

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 5 additions & 11 deletions backends/vulkan/runtime/api/Context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,15 +74,15 @@ void Context::cmd_reset_querypool() {
void Context::report_shader_dispatch_start(
const std::string& shader_name,
const utils::uvec3& global_wg_size,
const utils::uvec3& local_wg_size,
const utils::WorkgroupSize& local_wg_size,
const uint32_t dispatch_id) {
if (querypool_) {
querypool_.shader_profile_begin(
cmd_,
dispatch_id,
shader_name,
vkapi::create_extent3d(global_wg_size),
vkapi::create_extent3d(local_wg_size));
vkapi::create_extent3d((utils::uvec3)local_wg_size));
}
}

Expand Down Expand Up @@ -115,7 +115,7 @@ void Context::check_device_capabilities(const vkapi::ShaderInfo& shader) {

vkapi::DescriptorSet Context::get_descriptor_set(
const vkapi::ShaderInfo& shader_descriptor,
const utils::uvec3& local_workgroup_size,
const utils::WorkgroupSize& local_workgroup_size,
const vkapi::SpecVarList& additional_constants,
const uint32_t push_constants_size) {
VkDescriptorSetLayout shader_layout =
Expand All @@ -124,17 +124,11 @@ vkapi::DescriptorSet Context::get_descriptor_set(
VkPipelineLayout pipeline_layout =
pipeline_layout_cache().retrieve(shader_layout, push_constants_size);

vkapi::SpecVarList spec_constants = {
SV(local_workgroup_size[0u]),
SV(local_workgroup_size[1u]),
SV(local_workgroup_size[2u])};

spec_constants.append(additional_constants);

VkPipeline pipeline = pipeline_cache().retrieve(
{pipeline_layout_cache().retrieve(shader_layout, push_constants_size),
shader_cache().retrieve(shader_descriptor),
spec_constants});
additional_constants,
local_workgroup_size});

cmd_.bind_pipeline(pipeline, pipeline_layout, local_workgroup_size);

Expand Down
14 changes: 9 additions & 5 deletions backends/vulkan/runtime/api/Context.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
// @lint-ignore-every CLANGTIDY facebook-hte-BadMemberName

#include <executorch/backends/vulkan/runtime/utils/MacroUtils.h>
#include <executorch/backends/vulkan/runtime/utils/VecUtils.h>

#include <executorch/backends/vulkan/runtime/vk_api/Adapter.h>
#include <executorch/backends/vulkan/runtime/vk_api/Command.h>
Expand Down Expand Up @@ -150,7 +151,7 @@ class Context final {
void report_shader_dispatch_start(
const std::string& shader_name,
const utils::uvec3& global_wg_size,
const utils::uvec3& local_wg_size,
const utils::WorkgroupSize& local_wg_size,
const uint32_t dispatch_id = UINT32_MAX);

/*
Expand Down Expand Up @@ -189,13 +190,13 @@ class Context final {

vkapi::DescriptorSet get_descriptor_set(
const vkapi::ShaderInfo&,
const utils::uvec3&,
const utils::WorkgroupSize&,
const vkapi::SpecVarList&,
const uint32_t push_constants_size);

inline vkapi::DescriptorSet get_descriptor_set(
const vkapi::ShaderInfo& shader_descriptor,
const utils::uvec3& local_work_group_size) {
const utils::WorkgroupSize& local_work_group_size) {
return get_descriptor_set(shader_descriptor, local_work_group_size, {}, 0u);
}

Expand Down Expand Up @@ -362,14 +363,17 @@ inline bool Context::submit_compute_job(
report_shader_dispatch_start(
shader.kernel_name,
global_work_group,
local_work_group_size,
utils::WorkgroupSize(local_work_group_size),
dispatch_id);

// Factor out template parameter independent code to minimize code bloat.
// Note that push constants are not exposed yet via this API, therefore the
// push constants size is assumed to be 0.
vkapi::DescriptorSet descriptor_set = get_descriptor_set(
shader, local_work_group_size, specialization_constants, 0u);
shader,
utils::WorkgroupSize(local_work_group_size),
specialization_constants,
0u);

detail::bind(
descriptor_set,
Expand Down
2 changes: 1 addition & 1 deletion backends/vulkan/runtime/graph/ops/BlitNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ void BlitNode::encode(ComputeGraph* graph) {
kernel_name += vkapi::to_string(dst_tensor->dtype());

context->report_shader_dispatch_start(
kernel_name, utils::uvec3(), utils::uvec3(), node_id_);
kernel_name, utils::uvec3(), utils::WorkgroupSize(), node_id_);

context->register_blit(
pipeline_barrier,
Expand Down
2 changes: 1 addition & 1 deletion backends/vulkan/runtime/graph/ops/DispatchNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ class DispatchNode final : public ExecuteNode {
protected:
const vkapi::ShaderInfo shader_;
const utils::uvec3 global_workgroup_size_;
const utils::uvec3 local_workgroup_size_;
const utils::WorkgroupSize local_workgroup_size_;
const vkapi::ParamsBindList params_;
const vkapi::SpecVarList spec_vars_;
const std::vector<PushConstantDataInfo> push_constants_;
Expand Down
4 changes: 2 additions & 2 deletions backends/vulkan/runtime/graph/ops/PrepackNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,8 @@ void PrepackNode::encode(ComputeGraph* graph) {
// bound with the correct image layout.
{
vkapi::PipelineBarrier pipeline_barrier{};
vkapi::DescriptorSet descriptor_set =
context->get_descriptor_set(noop_shader_, {1, 1, 1});
vkapi::DescriptorSet descriptor_set = context->get_descriptor_set(
noop_shader_, utils::WorkgroupSize(1, 1, 1));

bind_tensor_to_descriptor_set(
*packed,
Expand Down
2 changes: 1 addition & 1 deletion backends/vulkan/runtime/graph/ops/PrepackNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class PrepackNode final {
const vkapi::ShaderInfo shader_;
vkapi::ShaderInfo noop_shader_;
const utils::uvec3 global_workgroup_size_;
const utils::uvec3 local_workgroup_size_;
const utils::WorkgroupSize local_workgroup_size_;
const ValueRef tref_;
const ValueRef packed_;
const vkapi::ParamsBindList params_;
Expand Down
44 changes: 44 additions & 0 deletions backends/vulkan/runtime/utils/VecUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -479,5 +479,49 @@ inline int64_t multiply_integers(Iter begin, Iter end) {
begin, end, static_cast<int64_t>(1), std::multiplies<>());
}

class WorkgroupSize final {
uint32_t val;

public:
explicit WorkgroupSize() : val(0) {}
explicit WorkgroupSize(const uint32_t x, const uint32_t y, const uint32_t z) {
// shift numbers by multiple of 11 bits, since each local workgroup axis can
// be 1024 at most and which is 0x400. only z axis can't store 1024, because
// it would overflow uint32_t storage.
if (z == 1024) {
throw std::runtime_error(
"Workgroup size in z axis cannot be 1024 because it would overflow uint32_t storage");
}
val = x | (y << 11) | (z << 22);
}

explicit WorkgroupSize(const uvec3& vec) {
// shift numbers by multiple of 11 bits, since each local workgroup axis can
// be 1024 at most and which is 0x400. only z axis can't store 1024, because
// it would overflow uint32_t storage.
if (vec[2u] == 1024) {
throw std::runtime_error(
"Workgroup size in z axis cannot be 1024 because it would overflow uint32_t storage");
}
val = vec[0u] | (vec[1u] << 11) | (vec[2u] << 22);
}

explicit inline operator uvec3() const {
return {
val & 0x7ffu,
(val >> 11) & 0x7ffu,
(val >> 22),
};
}

explicit inline operator uint32_t() const {
return val;
}

inline constexpr uint32_t operator[](const int idx) const {
return (val >> (11 * idx)) & 0x7ffu;
}
};

} // namespace utils
} // namespace vkcompute
2 changes: 1 addition & 1 deletion backends/vulkan/runtime/vk_api/Command.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ void CommandBuffer::end() {
void CommandBuffer::bind_pipeline(
VkPipeline pipeline,
VkPipelineLayout pipeline_layout,
const utils::uvec3 local_workgroup_size) {
const utils::WorkgroupSize local_workgroup_size) {
VK_CHECK_COND(
state_ == CommandBuffer::State::RECORDING,
"Vulkan CommandBuffer: called bind_pipeline() on a command buffer whose state "
Expand Down
6 changes: 3 additions & 3 deletions backends/vulkan/runtime/vk_api/Command.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class CommandBuffer final {
struct Bound {
VkPipeline pipeline;
VkPipelineLayout pipeline_layout;
utils::uvec3 local_workgroup_size;
utils::WorkgroupSize local_workgroup_size;
VkDescriptorSet descriptors;

explicit Bound()
Expand All @@ -63,7 +63,7 @@ class CommandBuffer final {
inline void reset() {
pipeline = VK_NULL_HANDLE;
pipeline_layout = VK_NULL_HANDLE;
local_workgroup_size = {0u, 0u, 0u};
local_workgroup_size = utils::WorkgroupSize{0u, 0u, 0u};
descriptors = VK_NULL_HANDLE;
}
};
Expand All @@ -87,7 +87,7 @@ class CommandBuffer final {
void begin();
void end();

void bind_pipeline(VkPipeline, VkPipelineLayout, const utils::uvec3);
void bind_pipeline(VkPipeline, VkPipelineLayout, const utils::WorkgroupSize);
void bind_descriptors(VkDescriptorSet);
void set_push_constants(VkPipelineLayout, const void*, uint32_t);

Expand Down
27 changes: 22 additions & 5 deletions backends/vulkan/runtime/vk_api/Pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,14 @@ void SpecVarList::append(const SpecVarList& other) {
vars.insert(vars.end(), other.vars.begin(), other.vars.end());
}

void SpecVarList::reserve(const size_t size) {
vars.reserve(size);
}

void SpecVarList::append(const SpecVar& other) {
vars.push_back(other);
}

std::vector<VkSpecializationMapEntry> SpecVarList::generate_map_entries()
const {
std::vector<VkSpecializationMapEntry> map_entries;
Expand Down Expand Up @@ -267,14 +275,23 @@ ComputePipeline::ComputePipeline(
const ComputePipeline::Descriptor& descriptor,
VkPipelineCache pipeline_cache)
: device_(device), handle_{VK_NULL_HANDLE} {
std::vector<VkSpecializationMapEntry> map_entries =
descriptor.specialization_constants.generate_map_entries();
SpecVarList specialization_constants;

specialization_constants.reserve(
3 + descriptor.specialization_constants.size());
specialization_constants.append(descriptor.local_wg_size[0]);
specialization_constants.append(descriptor.local_wg_size[1]);
specialization_constants.append(descriptor.local_wg_size[2]);

specialization_constants.append(descriptor.specialization_constants);
const std::vector<VkSpecializationMapEntry> map_entries =
specialization_constants.generate_map_entries();

const VkSpecializationInfo specialization_info{
descriptor.specialization_constants.size(), // mapEntryCount
specialization_constants.size(), // mapEntryCount
map_entries.data(), // pMapEntries
descriptor.specialization_constants.data_nbytes(), // dataSize
descriptor.specialization_constants.data(), // pData
specialization_constants.data_nbytes(), // dataSize
specialization_constants.data(), // pData
};

const VkPipelineShaderStageCreateInfo shader_stage_create_info{
Expand Down
8 changes: 8 additions & 0 deletions backends/vulkan/runtime/vk_api/Pipeline.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@ class SpecVarList final {

void append(const SpecVarList& other);

void reserve(const size_t size);

void append(const SpecVar& other);

std::vector<VkSpecializationMapEntry> generate_map_entries() const;

friend bool operator==(const SpecVarList& lhs, const SpecVarList& rhs);
Expand Down Expand Up @@ -152,6 +156,7 @@ class ComputePipeline final {
VkPipelineLayout pipeline_layout;
VkShaderModule shader_module;
SpecVarList specialization_constants;
utils::WorkgroupSize local_wg_size;
};

explicit ComputePipeline(
Expand Down Expand Up @@ -269,6 +274,9 @@ class ComputePipelineCache final {
seed = utils::hash_combine(seed, new_seed);
}

seed = utils::hash_combine(
seed, std::hash<uint32_t>()((uint32_t)descriptor.local_wg_size));

return seed;
}
};
Expand Down
2 changes: 1 addition & 1 deletion backends/vulkan/runtime/vk_api/Shader.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ struct ShaderInfo final {
ShaderLayout::Signature kernel_layout{};

// Shader Metadata
utils::uvec3 out_tile_size{1u, 1u, 1u};
utils::WorkgroupSize out_tile_size{1u, 1u, 1u};
bool requires_shader_int16 = false;
bool requires_16bit_storage = false;
bool requires_8bit_storage = false;
Expand Down
Loading