66 * LICENSE file in the root directory of this source tree.
77 */
88
9+ // @lint-ignore-every CLANGTIDY clang-diagnostic-missing-field-initializers
10+
911#include < executorch/backends/vulkan/runtime/api/Adapter.h>
1012
1113#include < bitset>
@@ -21,15 +23,33 @@ PhysicalDevice::PhysicalDevice(VkPhysicalDevice physical_device_handle)
2123 : handle(physical_device_handle),
2224 properties{},
2325 memory_properties{},
26+ shader_16bit_storage{
27+ VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_16BIT_STORAGE_FEATURES},
28+ shader_8bit_storage{
29+ VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_8BIT_STORAGE_FEATURES},
30+ shader_float16_int8_types{
31+ VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_FLOAT16_INT8_FEATURES_KHR},
2432 queue_families{},
2533 num_compute_queues (0 ),
2634 has_unified_memory (false ),
2735 has_timestamps (properties.limits.timestampComputeAndGraphics),
28- timestamp_period (properties.limits.timestampPeriod) {
36+ timestamp_period (properties.limits.timestampPeriod),
37+ extension_features (&shader_16bit_storage) {
2938 // Extract physical device properties
3039 vkGetPhysicalDeviceProperties (handle, &properties);
3140 vkGetPhysicalDeviceMemoryProperties (handle, &memory_properties);
3241
42+ VkPhysicalDeviceFeatures2 features2{
43+ VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2};
44+
45+ // Create linked list to query availability of extensions
46+ features2.pNext = &shader_16bit_storage;
47+ shader_16bit_storage.pNext = &shader_8bit_storage;
48+ shader_8bit_storage.pNext = &shader_float16_int8_types;
49+ shader_float16_int8_types.pNext = nullptr ;
50+
51+ vkGetPhysicalDeviceFeatures2 (handle, &features2);
52+
3353 // Check if there are any memory types have both the HOST_VISIBLE and the
3454 // DEVICE_LOCAL property flags
3555 const VkMemoryPropertyFlags unified_memory_flags =
@@ -140,6 +160,9 @@ VkDevice create_logical_device(
140160#ifdef VK_KHR_portability_subset
141161 VK_KHR_PORTABILITY_SUBSET_EXTENSION_NAME,
142162#endif /* VK_KHR_portability_subset */
163+ VK_KHR_16BIT_STORAGE_EXTENSION_NAME,
164+ VK_KHR_8BIT_STORAGE_EXTENSION_NAME,
165+ VK_KHR_SHADER_FLOAT16_INT8_EXTENSION_NAME,
143166 };
144167
145168 std::vector<const char *> enabled_device_extensions;
@@ -148,7 +171,7 @@ VkDevice create_logical_device(
148171 enabled_device_extensions,
149172 requested_device_extensions);
150173
151- const VkDeviceCreateInfo device_create_info{
174+ VkDeviceCreateInfo device_create_info{
152175 VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO, // sType
153176 nullptr , // pNext
154177 0u , // flags
@@ -162,6 +185,8 @@ VkDevice create_logical_device(
162185 nullptr , // pEnabledFeatures
163186 };
164187
188+ device_create_info.pNext = physical_device.extension_features ;
189+
165190 VkDevice handle = nullptr ;
166191 VK_CHECK (vkCreateDevice (
167192 physical_device.handle , &device_create_info, nullptr , &handle));
@@ -371,33 +396,53 @@ std::string Adapter::stringize() const {
371396 ss << " deviceType: " << device_type << std::endl;
372397 ss << " deviceName: " << properties.deviceName << std::endl;
373398
374- #define PRINT_LIMIT_PROP ( name ) \
375- ss << " " << std::left << std::setw (36 ) << #name << limits .name \
399+ #define PRINT_PROP ( struct, name ) \
400+ ss << " " << std::left << std::setw (36 ) << #name << struct .name \
376401 << std::endl;
377402
378- #define PRINT_LIMIT_PROP_VEC3 ( name ) \
379- ss << " " << std::left << std::setw (36 ) << #name << limits .name [0 ] \
380- << " ," << limits .name [1 ] << " ," << limits .name [2 ] << std::endl;
403+ #define PRINT_PROP_VEC3 ( struct, name ) \
404+ ss << " " << std::left << std::setw(36 ) << #name << struct .name[0 ] \
405+ << " ," << struct .name[1 ] << " ," << struct .name[2 ] << std::endl;
381406
382407 ss << " Physical Device Limits {" << std::endl;
383- PRINT_LIMIT_PROP (maxImageDimension1D);
384- PRINT_LIMIT_PROP (maxImageDimension2D);
385- PRINT_LIMIT_PROP (maxImageDimension3D);
386- PRINT_LIMIT_PROP (maxTexelBufferElements);
387- PRINT_LIMIT_PROP (maxPushConstantsSize);
388- PRINT_LIMIT_PROP (maxMemoryAllocationCount);
389- PRINT_LIMIT_PROP (maxSamplerAllocationCount);
390- PRINT_LIMIT_PROP (maxComputeSharedMemorySize);
391- PRINT_LIMIT_PROP_VEC3 (maxComputeWorkGroupCount);
392- PRINT_LIMIT_PROP (maxComputeWorkGroupInvocations);
393- PRINT_LIMIT_PROP_VEC3 (maxComputeWorkGroupSize);
408+ PRINT_PROP (limits, maxImageDimension1D);
409+ PRINT_PROP (limits, maxImageDimension2D);
410+ PRINT_PROP (limits, maxImageDimension3D);
411+ PRINT_PROP (limits, maxTexelBufferElements);
412+ PRINT_PROP (limits, maxPushConstantsSize);
413+ PRINT_PROP (limits, maxMemoryAllocationCount);
414+ PRINT_PROP (limits, maxSamplerAllocationCount);
415+ PRINT_PROP (limits, maxComputeSharedMemorySize);
416+ PRINT_PROP_VEC3 (limits, maxComputeWorkGroupCount);
417+ PRINT_PROP (limits, maxComputeWorkGroupInvocations);
418+ PRINT_PROP_VEC3 (limits, maxComputeWorkGroupSize);
419+ ss << " }" << std::endl;
420+
421+ ss << " 16bit Storage Features {" << std::endl;
422+ PRINT_PROP (physical_device_.shader_16bit_storage, storageBuffer16BitAccess);
423+ PRINT_PROP (
424+ physical_device_.shader_16bit_storage,
425+ uniformAndStorageBuffer16BitAccess);
426+ PRINT_PROP (physical_device_.shader_16bit_storage, storagePushConstant16);
427+ PRINT_PROP (physical_device_.shader_16bit_storage, storageInputOutput16);
428+ ss << " }" << std::endl;
429+
430+ ss << " 8bit Storage Features {" << std::endl;
431+ PRINT_PROP (physical_device_.shader_8bit_storage, storageBuffer8BitAccess);
432+ PRINT_PROP (
433+ physical_device_.shader_8bit_storage, uniformAndStorageBuffer8BitAccess);
434+ PRINT_PROP (physical_device_.shader_8bit_storage, storagePushConstant8);
435+ ss << " }" << std::endl;
436+
437+ ss << " Shader 16bit and 8bit Features {" << std::endl;
438+ PRINT_PROP (physical_device_.shader_float16_int8_types, shaderFloat16);
439+ PRINT_PROP (physical_device_.shader_float16_int8_types, shaderInt8);
394440 ss << " }" << std::endl;
395- ss << " }" << std::endl;
396- ;
397441
398442 const VkPhysicalDeviceMemoryProperties& mem_props =
399443 physical_device_.memory_properties;
400444
445+ ss << " }" << std::endl;
401446 ss << " Memory Info {" << std::endl;
402447 ss << " Memory Types [" << std::endl;
403448 for (size_t i = 0 ; i < mem_props.memoryTypeCount; ++i) {
@@ -432,6 +477,9 @@ std::string Adapter::stringize() const {
432477 ss << " ]" << std::endl;
433478 ss << " }" ;
434479
480+ #undef PRINT_PROP
481+ #undef PRINT_PROP_VEC3
482+
435483 return ss.str();
436484}
437485
0 commit comments