Skip to content

[ET-VK] Serialize all scalar types #2414

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions backends/vulkan/runtime/VulkanBackend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,18 @@ const uint8_t* getConstantDataPtr(

// Translate a serialized VkDataType (from the flatbuffer graph) into the
// runtime api::ScalarType. Must be kept in sync with the VkDataType enum
// in backends/vulkan/serialization/schema.fbs.
api::ScalarType get_scalar_type(const vkgraph::VkDataType& vk_datatype) {
  switch (vk_datatype) {
    case vkgraph::VkDataType::BOOL:
      return api::kBool;
    case vkgraph::VkDataType::UINT8:
      return api::kByte;
    case vkgraph::VkDataType::INT8:
      return api::kChar;
    case vkgraph::VkDataType::INT32:
      return api::kInt;
    case vkgraph::VkDataType::FLOAT16:
      return api::kHalf;
    case vkgraph::VkDataType::FLOAT32:
      return api::kFloat;
  }
  // NOTE(review): if an unhandled enumerator is ever added, control falls off
  // the end of a non-void function (undefined behavior). The switch is
  // currently exhaustive, but consider a default that aborts/raises.
}

Expand Down
10 changes: 7 additions & 3 deletions backends/vulkan/serialization/schema.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,13 @@ table OperatorCall {
args:[int];
}

// Scalar data types that can be serialized in the Vulkan delegate graph.
// Values must stay in sync with VkDataType in vulkan_graph_schema.py and
// with get_scalar_type() in runtime/VulkanBackend.cpp.
enum VkDataType : byte {
  BOOL = 0,
  UINT8 = 1,
  INT8 = 2,
  INT32 = 3,
  FLOAT16 = 4,
  FLOAT32 = 5,
}

table VkTensor {
Expand Down
18 changes: 15 additions & 3 deletions backends/vulkan/serialization/vulkan_graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,21 @@ def __init__(self, program: ExportedProgram) -> None:

@staticmethod
def get_vk_datatype(torch_dtype: torch.dtype) -> vk_graph_schema.VkDataType:
    """Map a ``torch.dtype`` to its serialized ``VkDataType`` equivalent.

    Args:
        torch_dtype: dtype of the tensor being serialized.

    Returns:
        The matching ``vk_graph_schema.VkDataType`` member.

    Raises:
        AssertionError: if the dtype has no Vulkan serialization.
    """
    if torch_dtype == torch.bool:
        return vk_graph_schema.VkDataType.BOOL
    elif torch_dtype == torch.uint8:
        return vk_graph_schema.VkDataType.UINT8
    elif torch_dtype == torch.int8:
        return vk_graph_schema.VkDataType.INT8
    elif torch_dtype == torch.int32:
        return vk_graph_schema.VkDataType.INT32
    elif torch_dtype == torch.float16:
        return vk_graph_schema.VkDataType.FLOAT16
    elif torch_dtype == torch.float32:
        return vk_graph_schema.VkDataType.FLOAT32
    # Narrowing conversion for index tensor produced by max_poolNd_with_indices.
    elif torch_dtype == torch.int64:
        return vk_graph_schema.VkDataType.INT32
    else:
        raise AssertionError(f"Invalid dtype for vulkan_preprocess ({torch_dtype})")

Expand Down
7 changes: 6 additions & 1 deletion backends/vulkan/serialization/vulkan_graph_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,12 @@ class OperatorCall:


class VkDataType(IntEnum):
    """Scalar data types serializable in the Vulkan delegate graph.

    Values must stay in sync with the VkDataType enum in schema.fbs.
    """

    BOOL = 0
    UINT8 = 1
    INT8 = 2
    INT32 = 3
    FLOAT16 = 4
    FLOAT32 = 5


@dataclass
Expand Down
10 changes: 0 additions & 10 deletions backends/vulkan/vulkan_preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@

from typing import final, List

import executorch.backends.vulkan.serialization.vulkan_graph_schema as vk_graph_schema

from executorch.backends.vulkan.serialization.vulkan_graph_builder import VkGraphBuilder
from executorch.backends.vulkan.serialization.vulkan_graph_serialize import (
serialize_vulkan_graph,
Expand All @@ -25,20 +23,12 @@
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass

from executorch.exir.program._program import _copy_module
from torch import dtype, float32

DEFAULT_DEBUG_HANDLE = 65535


@final
class VulkanBackend(BackendDetails):
@staticmethod
def get_vk_datatype(torch_dtype: dtype) -> vk_graph_schema.VkDataType:
if torch_dtype == float32:
return vk_graph_schema.VkDataType.fp32
else:
raise AssertionError(f"Invalid dtype for vulkan_preprocess ({torch_dtype})")

@classmethod
# pyre-ignore
def preprocess( # noqa: C901
Expand Down