diff --git a/core/conversion/converters/impl/quantization.cpp b/core/conversion/converters/impl/quantization.cpp
index e8fdc69f84..addf629e6b 100644
--- a/core/conversion/converters/impl/quantization.cpp
+++ b/core/conversion/converters/impl/quantization.cpp
@@ -11,6 +11,22 @@ namespace {
 #if NV_TENSORRT_MAJOR > 7
 // clang-format off
+
+bool add_qdq(ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor* input, nvinfer1::ITensor* scale, std::string& opName) {
+  nvinfer1::IQuantizeLayer* quantize_layer = ctx->net->addQuantize(*input, *scale);
+  TORCHTRT_CHECK(quantize_layer, "Unable to create QuantizeLayer from node: " << *n);
+  quantize_layer->setAxis(0);
+
+  nvinfer1::IDequantizeLayer* dequantize_layer = ctx->net->addDequantize(*quantize_layer->getOutput(0), *scale);
+  TORCHTRT_CHECK(dequantize_layer, "Unable to create DequantizeLayer from node: " << *n);
+  dequantize_layer->setAxis(0);
+
+  auto qdq_out = ctx->AssociateValueAndTensor(n->outputs()[0], dequantize_layer->getOutput(0));
+  LOG_DEBUG("[" << opName << "] Output tensor shape: " << qdq_out->getDimensions());
+
+  return true;
+}
+
 auto quantization_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns()
     .pattern({"aten::fake_quantize_per_tensor_affine(Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> (Tensor)",
               [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
@@ -20,18 +36,16 @@ auto quantization_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns
                 auto scale = args[1].unwrapToScalar().to<float>();
                 auto scaleTensor = tensor_to_const(ctx, torch::tensor({scale}));
                 // Add and configure a QuantizeLayer.
-                nvinfer1::IQuantizeLayer* quantize_layer = ctx->net->addQuantize(*input, *scaleTensor);
-                quantize_layer->setAxis(0);
-
-                // Add and configure DequantizeLayer following a QuantizeLayer
-                nvinfer1::IDequantizeLayer* dequantize_layer = ctx->net->addDequantize(*quantize_layer->getOutput(0), *scaleTensor);
-                dequantize_layer->setAxis(0);
-
-                auto qdq_out = ctx->AssociateValueAndTensor(n->outputs()[0], dequantize_layer->getOutput(0));
-                LOG_DEBUG("[fake_quantize_per_tensor_affine] Output tensor shape: " << qdq_out->getDimensions());
-
-                return true;
+                std::string opName("aten::fake_quantize_per_tensor_affine");
+                return add_qdq(ctx, n, input, scaleTensor, opName);
               }})
+    .pattern({"aten::fake_quantize_per_tensor_affine.tensor_qparams(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max) -> (Tensor)",
+              [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+                auto input = args[0].ITensorOrFreeze(ctx);
+                auto scale = args[1].ITensorOrFreeze(ctx);
+                std::string opName("aten::fake_quantize_per_tensor_affine.tensor_qparams");
+                return add_qdq(ctx, n, input, scale, opName);
+              }})
     .pattern({"aten::fake_quantize_per_channel_affine(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> (Tensor)",
               [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
                 // This aten operator is generated from torch.fake_quantize_per_channel_affine op in Pytorch python API.
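For context on the TensorRT calls the new `add_qdq` helper wraps, here is a stand-alone sketch (not part of the patch; the shapes and the `main` scaffolding are illustrative assumptions) that builds the same Quantize → Dequantize pair directly against the TensorRT 8 network API, with the per-tensor scale supplied as a one-element constant the way `tensor_to_const` does:

```cpp
// Illustrative sketch only (assumes TensorRT 8+): the Q/DQ pair that
// add_qdq() emits, built directly with the raw TensorRT network API.
#include <NvInfer.h>

#include <cstdint>
#include <iostream>

namespace {
// Minimal logger required by createInferBuilder.
class Logger : public nvinfer1::ILogger {
  void log(Severity severity, const char* msg) noexcept override {
    if (severity <= Severity::kWARNING) std::cout << msg << std::endl;
  }
};
} // namespace

int main() {
  Logger logger;
  auto* builder = nvinfer1::createInferBuilder(logger);
  auto* network = builder->createNetworkV2(
      1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH));

  // Stand-in for the frozen aten input tensor (shape chosen arbitrarily).
  auto* input = network->addInput("x", nvinfer1::DataType::kFLOAT, nvinfer1::Dims4{1, 5, 5, 5});

  // Per-tensor scale as a 1-element constant, mirroring tensor_to_const().
  // Static storage so the weights outlive network construction.
  static const float kScale = 3.5F;
  nvinfer1::Weights scale_weights{nvinfer1::DataType::kFLOAT, &kScale, 1};
  auto* scale = network->addConstant(nvinfer1::Dims{1, {1}}, scale_weights)->getOutput(0);

  // Quantize then immediately dequantize; TensorRT treats the Q/DQ pair as
  // the INT8 scale hint for surrounding layers. Axis 0 matches add_qdq().
  auto* q = network->addQuantize(*input, *scale);
  q->setAxis(0);
  auto* dq = network->addDequantize(*q->getOutput(0), *scale);
  dq->setAxis(0);

  network->markOutput(*dq->getOutput(0));

  delete network;
  delete builder;
  return 0;
}
```

Note that `addQuantize`/`addDequantize` take only an input and a scale tensor: TensorRT's Q/DQ layers are symmetric, which is why `add_qdq` never forwards the op's zero point.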
diff --git a/tests/core/conversion/converters/test_quantization.cpp b/tests/core/conversion/converters/test_quantization.cpp
index fcbef02e16..d6881bb37e 100644
--- a/tests/core/conversion/converters/test_quantization.cpp
+++ b/tests/core/conversion/converters/test_quantization.cpp
@@ -30,6 +30,40 @@ TEST(Converters, ATenFakeQuantizePerTensorConvertsCorrectly) {
       torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
 }
 
+TEST(Converters, ATenFakeQuantizePerTensorWithParamsConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%x.1 : Tensor):
+        %22 : int = prim::Constant[value=-128]()
+        %14 : int = prim::Constant[value=4]()
+        %9 : None = prim::Constant()
+        %35 : Device = prim::Constant[value="cuda:0"]()
+        %6 : int = prim::Constant[value=6]()
+        %7 : int = prim::Constant[value=3]()
+        %3 : int = prim::Constant[value=1]()
+        %5 : float = prim::Constant[value=3.5]()
+        %13 : int = prim::Constant[value=1]()
+        %23 : int = prim::Constant[value=127]()
+        %4 : int[] = prim::ListConstruct(%3)
+        %11 : Tensor = aten::full(%4, %5, %6, %9, %35, %9)
+        %12 : int[] = prim::ListConstruct(%3)
+        %19 : Tensor = aten::full(%12, %13, %7, %9, %35, %9)
+        %quant_input.1 : Tensor = aten::fake_quantize_per_tensor_affine(%x.1, %11, %19, %22, %23)
+        return (%quant_input.1))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(1, 10, {1, 5, 5, 5}, {at::kCUDA}).to(at::kFloat);
+
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});
+
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in}, nvinfer1::DataType::kINT8);
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
+}
+
 TEST(Converters, ATenFakeQuantizePerChannelConvertsCorrectly) {
   const auto graph = R"IR(
       graph(%x.1 : Tensor):
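As a reference for what the new test asserts numerically: `aten::fake_quantize_per_tensor_affine` quantizes with the given scale and zero point, clamps to `[quant_min, quant_max]`, and dequantizes again. A minimal per-element sketch of that math (not Torch-TensorRT code; it assumes PyTorch's round-to-nearest-even convention, and the function name is illustrative):

```cpp
// Reference semantics of fake_quantize_per_tensor_affine for one element.
// Sketch only; std::nearbyint rounds half to even under the default
// FE_TONEAREST floating-point rounding mode.
#include <algorithm>
#include <cmath>

float fake_quantize(float x, float scale, int zero_point, int quant_min, int quant_max) {
  // Quantize: scale, shift by the zero point, round, and clamp to the int range.
  int q = static_cast<int>(std::nearbyint(x / scale)) + zero_point;
  q = std::clamp(q, quant_min, quant_max);
  // Dequantize: undo the shift and scale back to float.
  return static_cast<float>(q - zero_point) * scale;
}
```

With the test's parameters (scale 3.5, zero point 1, range [-128, 127]), inputs drawn from [1, 10) never clamp, so the expression reduces to `round(x / 3.5) * 3.5` regardless of the zero point, which is why the symmetric TensorRT Q/DQ engine can match the TorchScript result within the 2e-6 tolerance.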