diff --git a/backends/vulkan/runtime/utils/VecUtils.h b/backends/vulkan/runtime/utils/VecUtils.h index ad4434cf5af..c084a563544 100644 --- a/backends/vulkan/runtime/utils/VecUtils.h +++ b/backends/vulkan/runtime/utils/VecUtils.h @@ -479,5 +479,49 @@ inline int64_t multiply_integers(Iter begin, Iter end) { begin, end, static_cast(1), std::multiplies<>()); } +class WorkgroupSize final { + uint32_t val; + + public: + explicit WorkgroupSize() : val(0) {} + explicit WorkgroupSize(const uint32_t x, const uint32_t y, const uint32_t z) { + // shift numbers by multiple of 11 bits, since each local workgroup axis can + // be 1024 at most and which is 0x400. only z axis can't store 1024, because + // it would overflow uint32_t storage. + if (z == 1024) { + throw std::runtime_error( + "Workgroup size in z axis cannot be 1024 because it would overflow uint32_t storage"); + } + val = x | (y << 11) | (z << 22); + } + + explicit WorkgroupSize(const uvec3& vec) { + // shift numbers by multiple of 11 bits, since each local workgroup axis can + // be 1024 at most and which is 0x400. only z axis can't store 1024, because + // it would overflow uint32_t storage. + if (vec[2u] == 1024) { + throw std::runtime_error( + "Workgroup size in z axis cannot be 1024 because it would overflow uint32_t storage"); + } + val = vec[0u] | (vec[1u] << 11) | (vec[2u] << 22); + } + + explicit inline operator uvec3() const { + return { + val & 0x7ffu, + (val >> 11) & 0x7ffu, + (val >> 22), + }; + } + + explicit inline operator uint32_t() const { + return val; + } + + inline constexpr uint32_t operator[](const int idx) const { + return (val >> (11 * idx)) & 0x7ffu; + } +}; + } // namespace utils } // namespace vkcompute