Skip to content

Commit 3519560

Browse files
author
morelos
committed
[ET-VK][Ops] enabling double support for quantization and dequantization ops
Pull Request resolved: #11553 # Context Since we enabled the possibility of double support in an earlier diff, this diff enables double support for the quantization and dequantization ops. Since there are limitations on how 64-bit types can be supported, the expectation is that IO is downgraded to 32-bit. # Changes We add test cases for double support and make sure to pass in the double only where it's permitted (it's only allowed in buffers), and we also include double variants in the corresponding YAML files for quantization and dequantization. ghstack-source-id: 290376485 @exported-using-ghexport Differential Revision: [D76289197](https://our.internmc.facebook.com/intern/diff/D76289197/)
1 parent d31ab2b commit 3519560

File tree

8 files changed

+96
-2
lines changed

8 files changed

+96
-2
lines changed

backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ dequantize_buffer:
1111
OUT_DTYPE:
1212
- VALUE: half
1313
- VALUE: float
14+
- VALUE: double
1415
shader_variants:
1516
- NAME: dequantize_per_tensor_buffer
1617
MODE: per_tensor

backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,10 @@ $if MODE == "per_tensor":
6767
[[unroll]] for (int i = 0; i < 4; ++i) {
6868
IN_T qvalue = IN_T(intex[i]);
6969
OUT_T value = dequantize_val(qvalue, scale, zero_point);
70-
outtex[i] = value;
70+
$if OUT_DTYPE == "double":
71+
outtex[i] = float(value);
72+
$else:
73+
outtex[i] = value;
7174
}
7275
write_texel(t_out, pos, outtex);
7376

@@ -110,7 +113,10 @@ $if MODE == "per_token":
110113
[[unroll]] for (int i = 0; i < 4; ++i) {
111114
IN_T qvalue = IN_T(intex[i]);
112115
OUT_T value = dequantize_val(qvalue, scale_val, zero_point_val);
113-
outtex[i] = value;
116+
$if OUT_DTYPE == "double":
117+
outtex[i] = float(value);
118+
$else:
119+
outtex[i] = value;
114120
}
115121

116122
write_texel(t_out, pos, outtex);

backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ dequantize_texture:
1111
OUT_DTYPE:
1212
- VALUE: half
1313
- VALUE: float
14+
- VALUE: double
1415
shader_variants:
1516
- NAME: dequantize_per_tensor_texture3d
1617
MODE: per_tensor

backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ quantize_buffer:
77
IN_DTYPE:
88
- VALUE: half
99
- VALUE: float
10+
- VALUE: double
1011
OUT_DTYPE:
1112
- VALUE: uint8
1213
- VALUE: int8

backends/vulkan/runtime/graph/ops/glsl/quantize_texture.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ quantize_texture:
77
IN_DTYPE:
88
- VALUE: half
99
- VALUE: float
10+
- VALUE: double
1011
OUT_DTYPE:
1112
- VALUE: uint8
1213
- VALUE: int8

backends/vulkan/runtime/graph/ops/impl/Quantize.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,7 @@ void quantize_per_tensor_impl(
162162

163163
// Verify input is a floating point type
164164
VK_CHECK_COND(
165+
graph.dtype_of(input) == vkapi::kDouble ||
165166
graph.dtype_of(input) == vkapi::kFloat ||
166167
graph.dtype_of(input) == vkapi::kHalf);
167168

@@ -185,6 +186,7 @@ void quantize_per_token_impl(
185186

186187
// Verify input is a floating point type
187188
VK_CHECK_COND(
189+
graph.dtype_of(input) == vkapi::kDouble ||
188190
graph.dtype_of(input) == vkapi::kFloat ||
189191
graph.dtype_of(input) == vkapi::kHalf);
190192

backends/vulkan/test/op_tests/dequantize_test.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,12 @@ void test_vulkan_dequantize_per_tensor(
365365
vkcompute::utils::kBuffer,
366366
vkcompute::utils::kBuffer);
367367

368+
// Tell the texture test below to expect a float instead of a double,
369+
// since the shader can only return 32-bit values anyway
370+
if (out_dtype == at::kDouble) {
371+
out_dtype = at::kFloat;
372+
}
373+
368374
// Test with texture storage
369375
test_vulkan_dequantize_per_tensor_impl(
370376
input_sizes,
@@ -399,6 +405,12 @@ void test_vulkan_dequantize_per_token(
399405
vkcompute::utils::kBuffer,
400406
vkcompute::utils::kBuffer);
401407

408+
// Tell the texture test below to expect a float instead of a double,
409+
// since the shader can only return 32-bit values anyway
410+
if (out_dtype == at::kDouble) {
411+
out_dtype = at::kFloat;
412+
}
413+
402414
// Test with texture storage
403415
test_vulkan_dequantize_per_token_impl(
404416
input_sizes,
@@ -793,6 +805,19 @@ TEST(
793805
at::kHalf); // output dtype
794806
}
795807

808+
TEST(
809+
VulkanDequantizePerTensorTest,
810+
test_vulkan_dequantize_per_tensor_int32_to_double) {
811+
test_vulkan_dequantize_per_tensor(
812+
{2, 4, 3}, // input sizes
813+
0.0001, // scale
814+
100, // zero_point
815+
-2147483648, // quant_min
816+
2147483647, // quant_max
817+
at::kInt, // input dtype
818+
at::kDouble); // output dtype
819+
}
820+
796821
void test_reference_dequantize_per_token(
797822
const std::vector<int>& input_sizes,
798823
const std::vector<float>& scales,
@@ -1283,3 +1308,19 @@ TEST(
12831308
at::kInt, // input dtype
12841309
at::kHalf); // output dtype
12851310
}
1311+
1312+
TEST(
1313+
VulkanDequantizePerTokenTest,
1314+
test_vulkan_dequantize_per_token_int32_to_double) {
1315+
std::vector<float> scales = {0.0001, 0.0002, 0.0003, 0.0};
1316+
std::vector<int> zero_points = {100, -100, 50, -50};
1317+
1318+
test_vulkan_dequantize_per_token(
1319+
{2, 2, 8}, // input sizes (2*2=4 tokens)
1320+
scales,
1321+
zero_points,
1322+
-2147483648, // quant_min
1323+
2147483647, // quant_max
1324+
at::kInt, // input dtype
1325+
at::kDouble); // output dtype
1326+
}

backends/vulkan/test/op_tests/quantize_test.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,12 @@ void test_vulkan_quantize_per_tensor(
315315
vkcompute::utils::kBuffer,
316316
vkcompute::utils::kBuffer);
317317

318+
// If in_dtype is double, convert it to float for the texture implementation,
319+
// since that implementation doesn't support 64-bit inputs
320+
if (in_dtype == at::kDouble) {
321+
in_dtype = at::kFloat;
322+
}
323+
318324
// Test with texture storage
319325
test_vulkan_quantize_per_tensor_impl(
320326
input_sizes,
@@ -349,6 +355,12 @@ void test_vulkan_quantize_per_token(
349355
vkcompute::utils::kBuffer,
350356
vkcompute::utils::kBuffer);
351357

358+
// If in_dtype is double, convert it to float for the texture implementation,
359+
// since that implementation doesn't support 64-bit inputs
360+
if (in_dtype == at::kDouble) {
361+
in_dtype = at::kFloat;
362+
}
363+
352364
// Test with texture storage
353365
test_vulkan_quantize_per_token_impl(
354366
input_sizes,
@@ -655,6 +667,19 @@ TEST(
655667
at::kChar); // output dtype
656668
}
657669

670+
TEST(
671+
VulkanQuantizePerTensorTest,
672+
test_vulkan_quantize_per_tensor_double_to_int8) {
673+
test_vulkan_quantize_per_tensor(
674+
{2, 3}, // input sizes
675+
0.01, // scale
676+
1, // zero_point
677+
-128, // quant_min
678+
127, // quant_max
679+
at::kDouble, // input dtype
680+
at::kChar); // output dtype
681+
}
682+
658683
void test_reference_quantize_per_token(
659684
const std::vector<int>& input_sizes,
660685
const std::vector<float>& pre_scales,
@@ -1069,3 +1094,19 @@ TEST(VulkanQuantizePerTensorTest, test_vulkan_quantize_per_token_half_to_int8) {
10691094
at::kHalf, // input dtype
10701095
at::kChar); // output dtype
10711096
}
1097+
1098+
TEST(
1099+
VulkanQuantizePerTensorTest,
1100+
test_vulkan_quantize_per_token_double_to_int8) {
1101+
std::vector<float> scales = {0.1, 0.2};
1102+
std::vector<int> zero_points = {0, 5};
1103+
1104+
test_vulkan_quantize_per_token(
1105+
{2, 2}, // input sizes (2 tokens, matching the 2 scales/zero_points)
1106+
scales,
1107+
zero_points,
1108+
-128, // quant_min
1109+
127, // quant_max
1110+
at::kDouble, // input dtype
1111+
at::kChar); // output dtype
1112+
}

0 commit comments

Comments
 (0)