Skip to content

Commit 21ea435

Browse files
author
morelos
committed
[ET-VK][Ops] enabling double support for quantization and dequantization ops
Pull Request resolved: #11553 # Context Since we enabled the possibility for double support in an earlier diff, this enables double support for quantization and dequantization. Since there are limitations to how 64bit can be supported, the expectation is that IO is to be downgraded to 32bit. # Changes We create additional test cases for double support and make sure to pass in the double if its permitted (it's only allowed in buffers), and we also make sure to include double variants in the corresponding YAML files for quantization and dequantization. ghstack-source-id: 290156616 @exported-using-ghexport Differential Revision: [D76289197](https://our.internmc.facebook.com/intern/diff/D76289197/)
1 parent e2c5b16 commit 21ea435

File tree

8 files changed

+96
-2
lines changed

8 files changed

+96
-2
lines changed

backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ dequantize_buffer:
1111
OUT_DTYPE:
1212
- VALUE: half
1313
- VALUE: float
14+
- VALUE: double
1415
shader_variants:
1516
- NAME: dequantize_per_tensor_buffer
1617
MODE: per_tensor

backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,10 @@ $if MODE == "per_tensor":
6767
[[unroll]] for (int i = 0; i < 4; ++i) {
6868
IN_T qvalue = IN_T(intex[i]);
6969
OUT_T value = dequantize_val(qvalue, scale, zero_point);
70-
outtex[i] = value;
70+
$if OUT_DTYPE == "double":
71+
outtex[i] = float(value);
72+
$else:
73+
outtex[i] = value;
7174
}
7275
write_texel(t_out, pos, outtex);
7376

@@ -110,7 +113,10 @@ $if MODE == "per_token":
110113
[[unroll]] for (int i = 0; i < 4; ++i) {
111114
IN_T qvalue = IN_T(intex[i]);
112115
OUT_T value = dequantize_val(qvalue, scale_val, zero_point_val);
113-
outtex[i] = value;
116+
$if OUT_DTYPE == "double":
117+
outtex[i] = float(value);
118+
$else:
119+
outtex[i] = value;
114120
}
115121

116122
write_texel(t_out, pos, outtex);

backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ dequantize_texture:
1111
OUT_DTYPE:
1212
- VALUE: half
1313
- VALUE: float
14+
- VALUE: double
1415
shader_variants:
1516
- NAME: dequantize_per_tensor_texture3d
1617
MODE: per_tensor

backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ quantize_buffer:
77
IN_DTYPE:
88
- VALUE: half
99
- VALUE: float
10+
- VALUE: double
1011
OUT_DTYPE:
1112
- VALUE: uint8
1213
- VALUE: int8

backends/vulkan/runtime/graph/ops/glsl/quantize_texture.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ quantize_texture:
77
IN_DTYPE:
88
- VALUE: half
99
- VALUE: float
10+
- VALUE: double
1011
OUT_DTYPE:
1112
- VALUE: uint8
1213
- VALUE: int8

backends/vulkan/runtime/graph/ops/impl/Quantize.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,7 @@ void quantize_per_tensor_impl(
162162

163163
// Verify input is a floating point type
164164
VK_CHECK_COND(
165+
graph.dtype_of(input) == vkapi::kDouble ||
165166
graph.dtype_of(input) == vkapi::kFloat ||
166167
graph.dtype_of(input) == vkapi::kHalf);
167168

@@ -185,6 +186,7 @@ void quantize_per_token_impl(
185186

186187
// Verify input is a floating point type
187188
VK_CHECK_COND(
189+
graph.dtype_of(input) == vkapi::kDouble ||
188190
graph.dtype_of(input) == vkapi::kFloat ||
189191
graph.dtype_of(input) == vkapi::kHalf);
190192

backends/vulkan/test/op_tests/dequantize_test.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,12 @@ void test_vulkan_dequantize_per_tensor(
365365
vkcompute::utils::kBuffer,
366366
vkcompute::utils::kBuffer);
367367

368+
// Telling the system to expect a float instead of a double
369+
// since the shader can only return 32bit anyways
370+
if (out_dtype == at::kDouble) {
371+
out_dtype = at::kFloat;
372+
}
373+
368374
// Test with texture storage
369375
test_vulkan_dequantize_per_tensor_impl(
370376
input_sizes,
@@ -399,6 +405,12 @@ void test_vulkan_dequantize_per_token(
399405
vkcompute::utils::kBuffer,
400406
vkcompute::utils::kBuffer);
401407

408+
// Telling the system to expect a float instead of a double
409+
// since the shader can only return 32bit anyways
410+
if (out_dtype == at::kDouble) {
411+
out_dtype = at::kFloat;
412+
}
413+
402414
// Test with texture storage
403415
test_vulkan_dequantize_per_token_impl(
404416
input_sizes,
@@ -768,6 +780,19 @@ TEST(
768780
at::kHalf); // output dtype
769781
}
770782

783+
TEST(
784+
VulkanDequantizePerTensorTest,
785+
test_vulkan_dequantize_per_tensor_int32_to_double) {
786+
test_vulkan_dequantize_per_tensor(
787+
{2, 4, 3}, // input sizes
788+
0.0001, // scale
789+
100, // zero_point
790+
-2147483648, // quant_min
791+
2147483647, // quant_max
792+
at::kInt, // input dtype
793+
at::kDouble); // output dtype
794+
}
795+
771796
void test_reference_dequantize_per_token(
772797
const std::vector<int>& input_sizes,
773798
const std::vector<float>& scales,
@@ -1233,3 +1258,19 @@ TEST(
12331258
at::kInt, // input dtype
12341259
at::kHalf); // output dtype
12351260
}
1261+
1262+
TEST(
1263+
VulkanDequantizePerTokenTest,
1264+
test_vulkan_dequantize_per_token_int32_to_double) {
1265+
std::vector<float> scales = {0.0001, 0.0002, 0.0003, 0.0};
1266+
std::vector<int> zero_points = {100, -100, 50, -50};
1267+
1268+
test_vulkan_dequantize_per_token(
1269+
{2, 2, 8}, // input sizes (2*2=4 tokens)
1270+
scales,
1271+
zero_points,
1272+
-2147483648, // quant_min
1273+
2147483647, // quant_max
1274+
at::kInt, // input dtype
1275+
at::kDouble); // output dtype
1276+
}

backends/vulkan/test/op_tests/quantize_test.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,12 @@ void test_vulkan_quantize_per_tensor(
314314
vkcompute::utils::kBuffer,
315315
vkcompute::utils::kBuffer);
316316

317+
// If the in_dtype is a double, convert to float for texture implementation
318+
// since they don't support 64bit as inputs
319+
if (in_dtype == at::kDouble) {
320+
in_dtype = at::kFloat;
321+
}
322+
317323
// Test with texture storage
318324
test_vulkan_quantize_per_tensor_impl(
319325
input_sizes,
@@ -348,6 +354,12 @@ void test_vulkan_quantize_per_token(
348354
vkcompute::utils::kBuffer,
349355
vkcompute::utils::kBuffer);
350356

357+
// If the in_dtype is a double, convert to float for texture implementation
358+
// since they don't support 64bit as inputs
359+
if (in_dtype == at::kDouble) {
360+
in_dtype = at::kFloat;
361+
}
362+
351363
// Test with texture storage
352364
test_vulkan_quantize_per_token_impl(
353365
input_sizes,
@@ -639,6 +651,19 @@ TEST(
639651
at::kChar); // output dtype
640652
}
641653

654+
TEST(
655+
VulkanQuantizePerTensorTest,
656+
test_vulkan_quantize_per_tensor_double_to_int8) {
657+
test_vulkan_quantize_per_tensor(
658+
{2, 3}, // input sizes
659+
0.01, // scale
660+
1, // zero_point
661+
-128, // quant_min
662+
127, // quant_max
663+
at::kDouble, // input dtype
664+
at::kChar); // output dtype
665+
}
666+
642667
void test_reference_quantize_per_token(
643668
const std::vector<int>& input_sizes,
644669
const std::vector<float>& pre_scales,
@@ -1033,3 +1058,19 @@ TEST(VulkanQuantizePerTensorTest, test_vulkan_quantize_per_token_half_to_int8) {
10331058
at::kHalf, // input dtype
10341059
at::kChar); // output dtype
10351060
}
1061+
1062+
TEST(
1063+
VulkanQuantizePerTensorTest,
1064+
test_vulkan_quantize_per_token_double_to_int8) {
1065+
std::vector<float> scales = {0.1, 0.2};
1066+
std::vector<int> zero_points = {0, 5};
1067+
1068+
test_vulkan_quantize_per_token(
1069+
{2, 2}, // input sizes (2*2=4 tokens)
1070+
scales,
1071+
zero_points,
1072+
-128, // quant_min
1073+
127, // quant_max
1074+
at::kDouble, // input dtype
1075+
at::kChar); // output dtype
1076+
}

0 commit comments

Comments
 (0)