From 45d573ae2fae64a15d00f922ab915f98aed3eca7 Mon Sep 17 00:00:00 2001 From: Jorge Pineda Date: Wed, 13 Mar 2024 15:40:39 -0700 Subject: [PATCH 1/2] [ET-VK] Serialize all scalar types Differential Revision: [D54873255](https://our.internmc.facebook.com/intern/diff/D54873255/) [ghstack-poisoned] --- backends/vulkan/runtime/VulkanBackend.cpp | 13 +++++++++++-- backends/vulkan/serialization/schema.fbs | 10 +++++++--- .../serialization/vulkan_graph_builder.py | 18 +++++++++++++++--- .../serialization/vulkan_graph_schema.py | 7 ++++++- backends/vulkan/vulkan_preprocess.py | 10 ---------- 5 files changed, 39 insertions(+), 19 deletions(-) diff --git a/backends/vulkan/runtime/VulkanBackend.cpp b/backends/vulkan/runtime/VulkanBackend.cpp index 7ac85fc7725..62555adc734 100644 --- a/backends/vulkan/runtime/VulkanBackend.cpp +++ b/backends/vulkan/runtime/VulkanBackend.cpp @@ -57,9 +57,18 @@ const uint8_t* getConstantDataPtr( api::ScalarType get_scalar_type(const vkgraph::VkDataType& vk_datatype) { switch (vk_datatype) { - case (vkgraph::VkDataType::fp32): { + case vkgraph::VkDataType::BOOL: + return api::kBool; + case vkgraph::VkDataType::UINT8: + return api::kByte; + case vkgraph::VkDataType::INT8: + return api::kChar; + case vkgraph::VkDataType::INT32: + return api::kInt; + case vkgraph::VkDataType::FLOAT16: + return api::kHalf; + case vkgraph::VkDataType::FLOAT32: return api::kFloat; - } } } diff --git a/backends/vulkan/serialization/schema.fbs b/backends/vulkan/serialization/schema.fbs index e5139b5fd53..36f6120025a 100644 --- a/backends/vulkan/serialization/schema.fbs +++ b/backends/vulkan/serialization/schema.fbs @@ -10,9 +10,13 @@ table OperatorCall { args:[int]; } -enum VkDataType : short { - // IEEE754 single-precision floating-point. - fp32 = 0, +enum VkDataType : byte { + BOOL = 0, + UINT8 = 1, + INT8 = 2, + INT32 = 3, + FLOAT16 = 4, + FLOAT32 = 5, } table VkTensor { diff --git a/backends/vulkan/serialization/vulkan_graph_builder.py b/backends/vulkan/serialization/vulkan_graph_builder.py index a88f3029d12..f15e1557033 100644 --- a/backends/vulkan/serialization/vulkan_graph_builder.py +++ b/backends/vulkan/serialization/vulkan_graph_builder.py @@ -35,9 +35,21 @@ def __init__(self, program: ExportedProgram) -> None: @staticmethod def get_vk_datatype(torch_dtype: torch.dtype) -> vk_graph_schema.VkDataType: - # TODO(T182302927): Support more dtypes including float16, int(32|64). - if torch_dtype == torch.float32: - return vk_graph_schema.VkDataType.fp32 + if torch_dtype == torch.bool: + return vk_graph_schema.VkDataType.BOOL + elif torch_dtype == torch.uint8: + return vk_graph_schema.VkDataType.UINT8 + elif torch_dtype == torch.int8: + return vk_graph_schema.VkDataType.INT8 + elif torch_dtype == torch.int32: + return vk_graph_schema.VkDataType.INT32 + elif torch_dtype == torch.float16: + return vk_graph_schema.VkDataType.FLOAT16 + elif torch_dtype == torch.float32: + return vk_graph_schema.VkDataType.FLOAT32 + # Narrowing conversion for index tensor produced by max_poolNd_with_indices. + elif torch_dtype == torch.int64: + return vk_graph_schema.VkDataType.INT32 else: raise AssertionError(f"Invalid dtype for vulkan_preprocess ({torch_dtype})") diff --git a/backends/vulkan/serialization/vulkan_graph_schema.py b/backends/vulkan/serialization/vulkan_graph_schema.py index 1c5a05727b0..2edb02a910f 100644 --- a/backends/vulkan/serialization/vulkan_graph_schema.py +++ b/backends/vulkan/serialization/vulkan_graph_schema.py @@ -22,7 +22,12 @@ class OperatorCall: class VkDataType(IntEnum): - fp32 = 0 + BOOL = 0 + UINT8 = 1 + INT8 = 2 + INT32 = 3 + FLOAT16 = 4 + FLOAT32 = 5 @dataclass diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py index 27f42d1ec8f..91a85f15a1b 100644 --- a/backends/vulkan/vulkan_preprocess.py +++ b/backends/vulkan/vulkan_preprocess.py @@ -6,8 +6,6 @@ from typing import final, List -import executorch.backends.vulkan.serialization.vulkan_graph_schema as vk_graph_schema - from executorch.backends.vulkan.serialization.vulkan_graph_builder import VkGraphBuilder from executorch.backends.vulkan.serialization.vulkan_graph_serialize import ( serialize_vulkan_graph, @@ -25,20 +23,12 @@ from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass from executorch.exir.program._program import _copy_module -from torch import dtype, float32 DEFAULT_DEBUG_HANDLE = 65535 @final class VulkanBackend(BackendDetails): - @staticmethod - def get_vk_datatype(torch_dtype: dtype) -> vk_graph_schema.VkDataType: - if torch_dtype == float32: - return vk_graph_schema.VkDataType.fp32 - else: - raise AssertionError(f"Invalid dtype for vulkan_preprocess ({torch_dtype})") - @classmethod # pyre-ignore def preprocess( # noqa: C901 From 444352bac3139fec9ffc445811b5280fdb6ada12 Mon Sep 17 00:00:00 2001 From: Jorge Pineda Date: Wed, 13 Mar 2024 15:41:42 -0700 Subject: [PATCH 2/2] Update on "[ET-VK] Serialize all scalar types" Differential Revision: [D54873255](https://our.internmc.facebook.com/intern/diff/D54873255/) [ghstack-poisoned]