Skip to content

Commit fad8702

Browse files
committed
[ET-VK] Serialize all scalar types
Pull Request resolved: #2414 Adds all the [non-quantized scalar types from ATen Vulkan](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/vulkan/api/Types.h#L21-L26). ``` _(uint8_t, VK_FORMAT_R8G8B8A8_UINT, Byte) \ _(int8_t, VK_FORMAT_R8G8B8A8_SINT, Char) \ _(int32_t, VK_FORMAT_R32G32B32A32_SINT, Int) \ _(bool, VK_FORMAT_R8G8B8A8_SINT, Bool) \ _(unsigned short, VK_FORMAT_R16G16B16A16_SFLOAT, Half) \ _(float, VK_FORMAT_FLOAT4, Float) \ ``` Tensor images and buffers will now be created according to this `dtype`. It's up to contributors to associate their shader to one of these dtypes, e.g., in [`all_shaders.yaml`](https://github.com/pytorch/executorch/blob/main/backends/vulkan/runtime/graph/ops/glsl/all_shaders.yaml#L32-L46). ``` image_to_nchw: ... generate_variant_forall: DTYPE: - VALUE: "half" SUFFIX: "half" - VALUE: "float" SUFFIX: "float" ... ``` Differential Revision: [D54873255](https://our.internmc.facebook.com/intern/diff/D54873255/) ghstack-source-id: 218582139
1 parent 4321bc5 commit fad8702

File tree

5 files changed

+39
-19
lines changed

5 files changed

+39
-19
lines changed

backends/vulkan/runtime/VulkanBackend.cpp

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,18 @@ const uint8_t* getConstantDataPtr(
5757

5858
api::ScalarType get_scalar_type(const vkgraph::VkDataType& vk_datatype) {
5959
switch (vk_datatype) {
60-
case (vkgraph::VkDataType::fp32): {
60+
case vkgraph::VkDataType::BOOL:
61+
return api::kBool;
62+
case vkgraph::VkDataType::UINT8:
63+
return api::kByte;
64+
case vkgraph::VkDataType::INT8:
65+
return api::kChar;
66+
case vkgraph::VkDataType::INT32:
67+
return api::kInt;
68+
case vkgraph::VkDataType::FLOAT16:
69+
return api::kHalf;
70+
case vkgraph::VkDataType::FLOAT32:
6171
return api::kFloat;
62-
}
6372
}
6473
}
6574

backends/vulkan/serialization/schema.fbs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,13 @@ table OperatorCall {
1010
args:[int];
1111
}
1212

13-
enum VkDataType : short {
14-
// IEEE754 single-precision floating-point.
15-
fp32 = 0,
13+
enum VkDataType : byte {
14+
BOOL = 0,
15+
UINT8 = 1,
16+
INT8 = 2,
17+
INT32 = 3,
18+
FLOAT16 = 4,
19+
FLOAT32 = 5,
1620
}
1721

1822
table VkTensor {

backends/vulkan/serialization/vulkan_graph_builder.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,21 @@ def __init__(self, program: ExportedProgram) -> None:
3535

3636
@staticmethod
3737
def get_vk_datatype(torch_dtype: torch.dtype) -> vk_graph_schema.VkDataType:
38-
# TODO(T182302927): Support more dtypes including float16, int(32|64).
39-
if torch_dtype == torch.float32:
40-
return vk_graph_schema.VkDataType.fp32
38+
if torch_dtype == torch.bool:
39+
return vk_graph_schema.VkDataType.BOOL
40+
elif torch_dtype == torch.uint8:
41+
return vk_graph_schema.VkDataType.UINT8
42+
elif torch_dtype == torch.int8:
43+
return vk_graph_schema.VkDataType.INT8
44+
elif torch_dtype == torch.int32:
45+
return vk_graph_schema.VkDataType.INT32
46+
elif torch_dtype == torch.float16:
47+
return vk_graph_schema.VkDataType.FLOAT16
48+
elif torch_dtype == torch.float32:
49+
return vk_graph_schema.VkDataType.FLOAT32
50+
# Narrowing conversion for index tensor produced by max_poolNd_with_indices.
51+
elif torch_dtype == torch.int64:
52+
return vk_graph_schema.VkDataType.INT32
4153
else:
4254
raise AssertionError(f"Invalid dtype for vulkan_preprocess ({torch_dtype})")
4355

backends/vulkan/serialization/vulkan_graph_schema.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,12 @@ class OperatorCall:
2222

2323

2424
class VkDataType(IntEnum):
25-
fp32 = 0
25+
BOOL = 0
26+
UINT8 = 1
27+
INT8 = 2
28+
INT32 = 3
29+
FLOAT16 = 4
30+
FLOAT32 = 5
2631

2732

2833
@dataclass

backends/vulkan/vulkan_preprocess.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66

77
from typing import final, List
88

9-
import executorch.backends.vulkan.serialization.vulkan_graph_schema as vk_graph_schema
10-
119
from executorch.backends.vulkan.serialization.vulkan_graph_builder import VkGraphBuilder
1210
from executorch.backends.vulkan.serialization.vulkan_graph_serialize import (
1311
serialize_vulkan_graph,
@@ -25,20 +23,12 @@
2523
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
2624

2725
from executorch.exir.program._program import _copy_module
28-
from torch import dtype, float32
2926

3027
DEFAULT_DEBUG_HANDLE = 65535
3128

3229

3330
@final
3431
class VulkanBackend(BackendDetails):
35-
@staticmethod
36-
def get_vk_datatype(torch_dtype: dtype) -> vk_graph_schema.VkDataType:
37-
if torch_dtype == float32:
38-
return vk_graph_schema.VkDataType.fp32
39-
else:
40-
raise AssertionError(f"Invalid dtype for vulkan_preprocess ({torch_dtype})")
41-
4232
@classmethod
4333
# pyre-ignore
4434
def preprocess( # noqa: C901

0 commit comments

Comments
 (0)