Skip to content

Commit dad8b5b

Browse files
authored
[ET-VK][Ops] dequantize_per_tensor.default test setup
Differential Revision: D76267054 Pull Request resolved: #11481
1 parent 31e17bd commit dad8b5b

File tree

1 file changed

+385
-0
lines changed

1 file changed

+385
-0
lines changed

backends/vulkan/test/op_tests/dequantize_test.cpp

Lines changed: 385 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
#include <cassert>
2323
#include <iostream>
24+
#include <limits>
2425

2526
namespace torch {
2627
namespace executor {
@@ -180,3 +181,387 @@ void check_dequantize_args(
180181
")");
181182
}
182183
}
184+
185+
//
186+
// Reference Implementation
187+
//
188+
189+
/*
190+
* Reference implementation of dequantize_per_tensor
191+
*/
192+
at::Tensor dequantize_per_tensor_reference_impl(
193+
const at::Tensor& input,
194+
double scale,
195+
int64_t zero_point,
196+
int64_t quant_min,
197+
int64_t quant_max,
198+
at::ScalarType dtype,
199+
at::ScalarType out_dtype) {
200+
// Create output tensor with the target dtype
201+
at::Tensor out = at::empty_like(input, out_dtype);
202+
203+
// Dequantize the input tensor
204+
at::Tensor flat_input = input.flatten();
205+
at::Tensor flat_out = out.flatten();
206+
207+
// Store casted values to avoid repeated casting
208+
const int32_t zero_point_int32 = static_cast<int32_t>(zero_point);
209+
const float scale_float = static_cast<float>(scale);
210+
211+
for (int i = 0; i < flat_input.numel(); i++) {
212+
double dequantized_value = 0.0;
213+
214+
// Extract quantized value and dequantize based on input dtype
215+
// Following the CPU implementation pattern: (input - zero_point) * scale
216+
if (dtype == at::kByte) {
217+
uint8_t qvalue = flat_input[i].item<uint8_t>();
218+
dequantized_value = (qvalue - zero_point_int32) * scale_float;
219+
} else if (dtype == at::kChar) {
220+
int8_t qvalue = flat_input[i].item<int8_t>();
221+
dequantized_value = (qvalue - zero_point_int32) * scale_float;
222+
} else if (dtype == at::kShort) {
223+
int16_t qvalue = flat_input[i].item<int16_t>();
224+
dequantized_value = (qvalue - zero_point_int32) * scale_float;
225+
} else if (dtype == at::kInt) {
226+
int32_t qvalue = flat_input[i].item<int32_t>();
227+
dequantized_value = (qvalue - zero_point_int32) * scale_float;
228+
} else if (dtype == at::kLong) {
229+
int64_t qvalue = flat_input[i].item<int64_t>();
230+
dequantized_value = (qvalue - zero_point_int32) * scale_float;
231+
}
232+
233+
// Store result based on output dtype
234+
if (out_dtype == at::kFloat) {
235+
flat_out[i] = static_cast<float>(dequantized_value);
236+
} else if (out_dtype == at::kDouble) {
237+
flat_out[i] = dequantized_value;
238+
} else if (out_dtype == at::kHalf) {
239+
flat_out[i] = static_cast<c10::Half>(dequantized_value);
240+
}
241+
}
242+
243+
return out.reshape(input.sizes());
244+
}
245+
246+
// Forward declaration of implementation functions
// Runs the Vulkan dequantize_per_tensor op for one (in_storage, out_storage)
// combination and compares the result against the ATen reference. Defined
// below; declared here so the storage-type wrapper can call it.
void test_vulkan_dequantize_per_tensor_impl(
    const std::vector<int>& input_sizes,
    float scale,
    int zero_point,
    int64_t quant_min,
    int64_t quant_max,
    at::ScalarType dtype,
    at::ScalarType out_dtype,
    const vkcompute::utils::StorageType in_storage,
    const vkcompute::utils::StorageType out_storage);
257+
258+
// Wrapper function to test both buffer and texture storage types
259+
void test_vulkan_dequantize_per_tensor(
260+
const std::vector<int>& input_sizes,
261+
float scale,
262+
int zero_point,
263+
int64_t quant_min,
264+
int64_t quant_max,
265+
at::ScalarType dtype,
266+
at::ScalarType out_dtype) {
267+
// Test with buffer storage
268+
test_vulkan_dequantize_per_tensor_impl(
269+
input_sizes,
270+
scale,
271+
zero_point,
272+
quant_min,
273+
quant_max,
274+
dtype,
275+
out_dtype,
276+
vkcompute::utils::kBuffer,
277+
vkcompute::utils::kBuffer);
278+
279+
// Test with texture storage
280+
test_vulkan_dequantize_per_tensor_impl(
281+
input_sizes,
282+
scale,
283+
zero_point,
284+
quant_min,
285+
quant_max,
286+
dtype,
287+
out_dtype,
288+
vkcompute::utils::kTexture3D,
289+
vkcompute::utils::kTexture3D);
290+
}
291+
292+
void test_reference_dequantize_per_tensor(
293+
const std::vector<int>& input_sizes,
294+
float scale,
295+
int zero_point,
296+
int64_t quant_min,
297+
int64_t quant_max,
298+
at::ScalarType dtype,
299+
at::ScalarType out_dtype) {
300+
check_dequantize_args(quant_min, quant_max, dtype, out_dtype);
301+
std::vector<int64_t> input_sizes_int64(
302+
input_sizes.begin(), input_sizes.end());
303+
304+
// Create a quantized input tensor with values from quant_min to quant_max
305+
at::Tensor input;
306+
if (dtype == at::kByte) {
307+
input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kByte));
308+
} else if (dtype == at::kChar) {
309+
input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kChar));
310+
} else if (dtype == at::kShort) {
311+
input =
312+
at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kShort));
313+
} else if (dtype == at::kInt) {
314+
input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kInt));
315+
} else {
316+
input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kLong));
317+
}
318+
319+
// Fill with a simple pattern: values from quant_min to quant_max in steps
320+
float step = 1.0f;
321+
if (input.numel() > 1) {
322+
step = static_cast<float>(quant_max - quant_min) / (input.numel() - 1);
323+
}
324+
325+
auto flat_input = input.flatten();
326+
for (int i = 0; i < flat_input.numel(); i++) {
327+
int64_t qvalue = quant_min + i * step;
328+
if (dtype == at::kByte) {
329+
flat_input[i] = static_cast<uint8_t>(qvalue);
330+
} else if (dtype == at::kChar) {
331+
flat_input[i] = static_cast<int8_t>(qvalue);
332+
} else if (dtype == at::kShort) {
333+
flat_input[i] = static_cast<int16_t>(qvalue);
334+
} else if (dtype == at::kInt) {
335+
flat_input[i] = static_cast<int32_t>(qvalue);
336+
} else if (dtype == at::kLong) {
337+
flat_input[i] = static_cast<int64_t>(qvalue);
338+
}
339+
}
340+
341+
// Reshape back to original dimensions
342+
input = flat_input.reshape(input_sizes_int64);
343+
344+
// Get reference output
345+
at::Tensor reference_out = dequantize_per_tensor_reference_impl(
346+
input, scale, zero_point, quant_min, quant_max, dtype, out_dtype);
347+
348+
// Get implementation output
349+
at::Tensor impl_out = torch::executor::native::dequantize_per_tensor_aten(
350+
input, scale, zero_point, quant_min, quant_max, dtype, out_dtype);
351+
352+
// Compare outputs
353+
const bool output_correct = at::allclose(reference_out, impl_out);
354+
if (!output_correct) {
355+
std::cout << "\n"
356+
<< "Failed with parameters: " << std::endl;
357+
std::cout << " scale: " << scale << std::endl;
358+
std::cout << " zero_point: " << zero_point << std::endl;
359+
std::cout << " quant_min: " << quant_min << std::endl;
360+
std::cout << " quant_max: " << quant_max << std::endl;
361+
362+
std::cout << "input:" << std::endl;
363+
std::cout << input << std::endl;
364+
std::cout << "reference:" << std::endl;
365+
std::cout << reference_out << std::endl;
366+
std::cout << "implementation:" << std::endl;
367+
std::cout << impl_out << std::endl;
368+
}
369+
370+
ASSERT_TRUE(output_correct);
371+
}
372+
373+
/*
 * Checks the Vulkan dequantize_per_tensor.default op against the ATen-backed
 * reference (dequantize_per_tensor_aten) for one storage-type combination.
 *
 * Builds a quantized input sweeping quant_min..quant_max, runs the reference
 * on CPU, then builds and executes a Vulkan compute graph for the same op,
 * and asserts the two outputs match via at::allclose.
 */
void test_vulkan_dequantize_per_tensor_impl(
    const std::vector<int>& input_sizes,
    float scale,
    int zero_point,
    int64_t quant_min,
    int64_t quant_max,
    at::ScalarType dtype,
    at::ScalarType out_dtype,
    const vkcompute::utils::StorageType in_storage,
    const vkcompute::utils::StorageType out_storage) {
  check_dequantize_args(quant_min, quant_max, dtype, out_dtype);
  std::vector<int64_t> input_sizes_int64(
      input_sizes.begin(), input_sizes.end());

  // Create a quantized input tensor with values from quant_min to quant_max.
  // Unrecognized dtypes fall through to kLong.
  at::Tensor input;
  if (dtype == at::kByte) {
    input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kByte));
  } else if (dtype == at::kChar) {
    input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kChar));
  } else if (dtype == at::kShort) {
    input =
        at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kShort));
  } else if (dtype == at::kInt) {
    input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kInt));
  } else {
    input = at::zeros(input_sizes_int64, at::device(at::kCPU).dtype(at::kLong));
  }

  // Fill with a simple pattern: values from quant_min to quant_max in steps.
  float step = 1.0f;
  if (input.numel() > 1) {
    step = static_cast<float>(quant_max - quant_min) / (input.numel() - 1);
  }

  auto flat_input = input.flatten();
  for (int i = 0; i < flat_input.numel(); i++) {
    int64_t qvalue = quant_min + i * step;
    if (dtype == at::kByte) {
      flat_input[i] = static_cast<uint8_t>(qvalue);
    } else if (dtype == at::kChar) {
      flat_input[i] = static_cast<int8_t>(qvalue);
    } else if (dtype == at::kShort) {
      flat_input[i] = static_cast<int16_t>(qvalue);
    } else if (dtype == at::kInt) {
      flat_input[i] = static_cast<int32_t>(qvalue);
    } else if (dtype == at::kLong) {
      flat_input[i] = static_cast<int64_t>(qvalue);
    }
  }

  // Reshape back to original dimensions.
  input = flat_input.reshape(input_sizes_int64);

  // Get reference output from the ATen-backed implementation.
  at::Tensor reference_out =
      torch::executor::native::dequantize_per_tensor_aten(
          input, scale, zero_point, quant_min, quant_max, dtype, out_dtype);

  // Build Vulkan dequantize_per_tensor graph.
  using namespace vkcompute;

  GraphConfig config;
  config.set_storage_type_override(in_storage);
  ComputeGraph graph(config);

  IOValueRef r_input = graph.add_input_tensor(
      input.sizes().vec(), from_at_scalartype(dtype), in_storage);

  // Scalar operands for the op node.
  const ValueRef r_scale = graph.add_scalar<double>(scale);
  const ValueRef r_zero_point = graph.add_scalar<int64_t>(zero_point);
  const ValueRef r_quant_min = graph.add_scalar<int64_t>(quant_min);
  const ValueRef r_quant_max = graph.add_scalar<int64_t>(quant_max);

  // Output tensor uses out_storage, which may differ from in_storage.
  const ValueRef r_out = graph.add_tensor(
      input.sizes().vec(), from_at_scalartype(out_dtype), out_storage);

  // Look up the registered op implementation and add the node to the graph.
  VK_GET_OP_FN("dequantize_per_tensor.default")
  (graph,
   {
       r_input.value,
       r_scale,
       r_zero_point,
       r_quant_min,
       r_quant_max,
       r_out,
   });

  ValueRef staging_out = graph.set_output_tensor(r_out);

  // NOTE(review): prepare/prepack/encode ordering follows the ComputeGraph
  // lifecycle; do not reorder.
  graph.prepare();
  graph.encode_prepack();
  graph.prepack();
  graph.encode_execute();

  // Run Vulkan dequantize_per_tensor: upload input, execute, read back.
  graph.copy_into_staging(
      r_input.staging, input.const_data_ptr(), input.numel());

  graph.execute();

  at::Tensor vk_out = at::empty_like(reference_out).contiguous();
  graph.copy_from_staging(
      staging_out, vk_out.mutable_data_ptr(), vk_out.numel());

  // Compare outputs; on failure print everything needed to reproduce.
  const bool output_correct = at::allclose(reference_out, vk_out);
  if (!output_correct) {
    std::cout << "\n"
              << "Failed with parameters: " << std::endl;
    std::cout << "  scale: " << scale << std::endl;
    std::cout << "  zero_point: " << zero_point << std::endl;
    std::cout << "  quant_min: " << quant_min << std::endl;
    std::cout << "  quant_max: " << quant_max << std::endl;
    std::cout << "  storage type: "
              << (in_storage == vkcompute::utils::kBuffer ? "buffer"
                                                          : "texture")
              << std::endl;

    std::cout << "input:" << std::endl;
    std::cout << input << std::endl;
    std::cout << "reference:" << std::endl;
    std::cout << reference_out << std::endl;
    std::cout << "vulkan:" << std::endl;
    std::cout << vk_out << std::endl;
  }

  ASSERT_TRUE(output_correct);
}
502+
503+
// Test cases for dequantize_per_tensor
// uint8 input covering the full [0, 255] range, dequantized to float.
TEST(
    VulkanDequantizePerTensorTest,
    test_reference_dequantize_per_tensor_uint8_to_float) {
  test_reference_dequantize_per_tensor(
      {2, 3, 4}, // input sizes
      0.1, // scale
      5, // zero_point
      0, // quant_min
      255, // quant_max
      at::kByte, // input dtype
      at::kFloat); // output dtype
}
516+
517+
// int8 input covering the full [-128, 127] range with a zero zero_point,
// dequantized to float.
TEST(
    VulkanDequantizePerTensorTest,
    test_reference_dequantize_per_tensor_int8_to_float) {
  test_reference_dequantize_per_tensor(
      {3, 4, 5}, // input sizes
      0.05, // scale
      0, // zero_point
      -128, // quant_min
      127, // quant_max
      at::kChar, // input dtype
      at::kFloat); // output dtype
}
529+
530+
// int32 input spanning the entire int32 range, dequantized to float.
TEST(
    VulkanDequantizePerTensorTest,
    test_reference_dequantize_per_tensor_int32_to_float) {
  test_reference_dequantize_per_tensor(
      {4, 6, 2}, // input sizes
      0.2, // scale
      2, // zero_point
      std::numeric_limits<int32_t>::min(), // quant_min
      std::numeric_limits<int32_t>::max(), // quant_max
      at::kInt, // input dtype
      at::kFloat); // output dtype
}
542+
543+
// uint8 input with a 2D shape, dequantized to half precision.
TEST(
    VulkanDequantizePerTensorTest,
    test_reference_dequantize_per_tensor_uint8_to_half) {
  test_reference_dequantize_per_tensor(
      {7, 4}, // input sizes
      0.1, // scale
      10, // zero_point
      0, // quant_min
      255, // quant_max
      at::kByte, // input dtype (uint8)
      at::kHalf); // output dtype
}
555+
556+
// int32 input spanning the entire int32 range with a negative zero_point,
// dequantized to half precision.
TEST(
    VulkanDequantizePerTensorTest,
    test_reference_dequantize_per_tensor_int32_to_half) {
  test_reference_dequantize_per_tensor(
      {2, 6, 5}, // input sizes
      0.3, // scale
      -10, // zero_point
      std::numeric_limits<int32_t>::min(), // quant_min
      std::numeric_limits<int32_t>::max(), // quant_max
      at::kInt, // input dtype
      at::kHalf); // output dtype
}

0 commit comments

Comments
 (0)