diff --git a/notebooks/qat-ptq-workflow.ipynb b/notebooks/qat-ptq-workflow.ipynb new file mode 100644 index 0000000000..c0e719b3b4 --- /dev/null +++ b/notebooks/qat-ptq-workflow.ipynb @@ -0,0 +1,913 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "b971af08", + "metadata": {}, + "outputs": [], + "source": [ + "# Copyright 2022 NVIDIA Corporation. All Rights Reserved.\n", + "#\n", + "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# http://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License.\n", + "# ==============================================================================" + ] + }, + { + "cell_type": "markdown", + "id": "c6384192", + "metadata": {}, + "source": [ + "\n", + "\n", + "# Accelerate Deep Learning Models using Torch-TensorRT " + ] + }, + { + "attachments": { + "img1.JPG": { + "image/jpeg": "" + } + }, + "cell_type": "markdown", + "id": "f5454823", + "metadata": {}, + "source": [ + "## Overview\n", + "\n", + "Deep Learning has touched almost every industry and has transformed the way industries operate and provide services. We perform or experience real-time analytics all the time around us, for example, an advertisement that you saw while swiping through the stories on Instagram, or the video recommendation that floated on your youtube home screen. To cater to these real-time inferences, deep learning practitioners need to maximise model throughput while having highly accurate predictions. 
Quantization is one of the many techniques employed to this end, so let's have a look at how it can be used to accelerate models.\n", + "\n", + "Model quantization is a popular optimization technique that reduces the size of a model and thereby accelerates inference. It also opens up the possibility of deploying on devices with lower computation power, such as Jetson. Simply put, quantization is the process of mapping input values from a larger set to output values in a smaller set. In the context of deep learning, we often train models using 32-bit floating-point arithmetic (FP32), which offers a wide range of representable numbers and high precision. During quantization, the model data (network parameters and activations) is converted from this floating-point representation to a lower-precision representation, typically 8-bit integers (int8). In the case of int8, the range [qmin, qmax] would be [-128, 127].\n", + "\n", + "![img1.JPG](attachment:img1.JPG)\n", + "\n", + "The following thought experiment gives a quick rationale for how quantization improves throughput: compare the complexity of multiplying 3.999x2.999 with that of multiplying 4x3. The latter is easier to perform than the former. This is the simplification in calculation gained by quantizing numbers to a lower precision. The challenge is that this rounding can result in a less accurate model. To address this loss of accuracy, different quantization techniques have been developed. They fall into two categories: post-training quantization (PTQ) and quantization-aware training (QAT).\n", + "\n", + "In this notebook, we illustrate the workflow that you can adopt while quantizing a deep learning model in Torch-TensorRT. The notebook takes you through an example of Mobilenetv2 for a classification task on Imagenette, a subset of the ImageNet dataset that has 10 classes. \n", + "\n", + "1. [Requirements](#1)\n", + "2. [Setup a baseline Mobilenetv2 model](#2)\n", + "3. [Convert to Torch TRT](#3)\n", + "4. 
[Post Training Quantization (PTQ)](#4)\n", + "4. [Quantization Aware training](#4)\n", + "5. [Export to Torchscript](#5)\n", + "6. [Inference using Torch-TensorRT](#6)\n", + "7. [References](#7)\n", + "\n", + "This notebook is implemented using the NGC pytorch container nvcr.io/nvidia/pytorch:22.04-py3. Follow instructions here https://ngc.nvidia.com/setup/api-key to setup your own API key to use the NGC service through the Docker client. " + ] + }, + { + "cell_type": "markdown", + "id": "06b37d07", + "metadata": {}, + "source": [ + "\n", + "## 1. Requirements\n", + "Please install the required dependencies and import these libraries accordingly" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "0a068b12", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install ipywidgets --trusted-host pypi.org --trusted-host pypi.python.org --trusted-host=files.pythonhosted.org\n", + "!pip install wget" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "4e2e58b2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2.1.2\n" + ] + } + ], + "source": [ + "import torch\n", + "import torch.backends.cudnn as cudnn\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "import torch.optim as optim\n", + "import torch.utils.data as data\n", + "import torchvision.transforms as transforms\n", + "from torchvision import models, datasets\n", + "import torch_tensorrt\n", + "\n", + "\n", + "import pytorch_quantization\n", + "from pytorch_quantization import nn as quant_nn\n", + "from pytorch_quantization import quant_modules\n", + "from pytorch_quantization.tensor_quant import QuantDescriptor\n", + "from pytorch_quantization import calib\n", + "from tqdm import tqdm\n", + "\n", + "print(pytorch_quantization.__version__)\n", + "\n", + "import os\n", + "import sys\n", + "import warnings\n", + "import time\n", + "import numpy as np\n", + "import wget\n", + "import tarfile\n", + 
"import shutil\n", + "warnings.simplefilter('ignore')" + ] + }, + { + "cell_type": "markdown", + "id": "0575e590", + "metadata": {}, + "source": [ + "\n", + "## 2. Setup a baseline Mobilenetv2 Model" + ] + }, + { + "cell_type": "markdown", + "id": "a83b886f", + "metadata": {}, + "source": [ + "#### Preapring the Dataset\n", + "\n", + "Imagenette is a subset of ImageNet and has 10 classes. The classes are as follows in the order of their labels : 'tench', 'English springer', 'cassette player', 'chain saw', 'church', 'French horn', 'garbage truck', 'gas pump', 'golf ball' and 'parachute'. " + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "60cead9d", + "metadata": {}, + "outputs": [], + "source": [ + "def download_data(DATA_DIR):\n", + " if os.path.exists(DATA_DIR):\n", + " if not os.path.exists(os.path.join(DATA_DIR, 'imagenette2-320')):\n", + " url = 'https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-320.tgz'\n", + " wget.download(url)\n", + " # open file\n", + " file = tarfile.open('imagenette2-320.tgz')\n", + " # extracting file\n", + " file.extractall(DATA_DIR)\n", + " file.close()\n", + " else:\n", + " print(\"This directory doesn't exist. 
Create the directory and run again\")" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "bfb55992", + "metadata": {}, + "outputs": [], + "source": [ + "if not os.path.exists(\"./data\"):\n", + " os.mkdir(\"./data\")\n", + "download_data(\"./data\")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "07d1fc63", + "metadata": {}, + "outputs": [], + "source": [ + "# Define main data directory\n", + "DATA_DIR = './data/imagenette2-320' \n", + "# Define training and validation data paths\n", + "TRAIN_DIR = os.path.join(DATA_DIR, 'train') \n", + "VAL_DIR = os.path.join(DATA_DIR, 'val')" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "acd3cd99", + "metadata": {}, + "outputs": [], + "source": [ + "#Performing Transformations on the dataset and defining training and validation dataloaders\n", + "transform = transforms.Compose([\n", + " transforms.Resize(256),\n", + " transforms.CenterCrop(224),\n", + " transforms.ToTensor(),\n", + " ])\n", + "train_dataset = datasets.ImageFolder(TRAIN_DIR, transform=transform)\n", + "val_dataset = datasets.ImageFolder(VAL_DIR, transform=transform)\n", + "calib_dataset = torch.utils.data.random_split(val_dataset, [2901, 1024])[1]\n", + "\n", + "train_dataloader = data.DataLoader(train_dataset, batch_size=32, shuffle=True, drop_last=True)\n", + "val_dataloader = data.DataLoader(val_dataset, batch_size=64, shuffle=False, drop_last=True)\n", + "calib_dataloader = data.DataLoader(calib_dataset, batch_size=64, shuffle=False, drop_last=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "a2f8914c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor(0)\n" + ] + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "# Visualising an image from the validation set\n", + "import matplotlib.pyplot as plt\n", + "for images, labels in val_dataloader:\n", + " print(labels[0])\n", + " image = images[0]\n", + " img = image.swapaxes(0, 1)\n", + " img = img.swapaxes(1, 2)\n", + " plt.imshow(img)\n", + " break" + ] + }, + { + "cell_type": "markdown", + "id": "4b7441e6", + "metadata": {}, + "source": [ + "#### Setting up Mobilenetv2\n", + "\n", + "Mobilenetv2 available in Torchvision is pretrained on the ImageNet that has 1000 classes. The Imagenette dataset has 10 classes. \n", + "We set up this model by freezing the weights excpet for the last classification layer and train only the last classification layer to be able to predict the 10 classes of the dataset. " + ] + }, + { + "cell_type": "markdown", + "id": "b9577f2a", + "metadata": {}, + "source": [ + "*Define the Mobilenetv2 model*" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "c29ae7b8", + "metadata": {}, + "outputs": [], + "source": [ + "#This function allows you to set the all the parameters to not have gradients, \n", + "#allowing you to freeze the model and not undergo training during the train step. \n", + "def set_parameter_requires_grad(model, feature_extracting):\n", + " if feature_extracting:\n", + " for param in model.parameters():\n", + " param.requires_grad = False\n", + "\n", + "feature_extract = True #This varaible can be set False if you want to finetune the model by updating all the parameters. 
\n", + "model = models.mobilenet_v2(pretrained=True)\n", + "set_parameter_requires_grad(model, feature_extract)\n", + "# Define a classification head for 10 classes.\n", + "model.classifier[1] = nn.Linear(1280, 10)\n", + "model = model.cuda()" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "5c03df98", + "metadata": {}, + "outputs": [], + "source": [ + "# Declare Learning rate\n", + "lr = 0.0001\n", + "\n", + "# Use cross entropy loss for classification and SGD optimizer\n", + "criterion = nn.CrossEntropyLoss()\n", + "optimizer = optim.SGD(model.parameters(), lr=lr)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "8470846e", + "metadata": {}, + "outputs": [], + "source": [ + "# Define helper functions for training, evaluation, checkpoint saving and benchmarking\n", + "def train(model, dataloader, crit, opt, epoch):\n", + " model.train()\n", + " running_loss = 0.0\n", + " for batch, (data, labels) in enumerate(dataloader):\n", + " data, labels = data.cuda(), labels.cuda(non_blocking=True)\n", + " opt.zero_grad()\n", + " out = model(data)\n", + " loss = crit(out, labels)\n", + " loss.backward()\n", + " opt.step()\n", + " running_loss += loss.item()\n", + " if batch % 100 == 99:\n", + " print(\"Batch: [%5d | %5d] loss: %.3f\" % (batch + 1, len(dataloader), running_loss / 100))\n", + " running_loss = 0.0\n", + " \n", + "def evaluate(model, dataloader, crit, epoch):\n", + " total = 0\n", + " correct = 0\n", + " loss = 0.0\n", + " class_probs = []\n", + " class_preds = []\n", + " model.eval()\n", + " with torch.no_grad():\n", + " for data, labels in dataloader:\n", + " data, labels = data.cuda(), labels.cuda(non_blocking=True)\n", + " out = model(data)\n", + " loss += crit(out, labels)\n", + " preds = torch.max(out, 1)[1]\n", + " class_probs.append([F.softmax(i, dim=0) for i in out])\n", + " class_preds.append(preds)\n", + " total += labels.size(0)\n", + " correct += (preds == labels).sum().item()\n", + "\n", + " 
evaluate_probs = torch.cat([torch.stack(batch) for batch in class_probs])\n", + " evaluate_preds = torch.cat(class_preds)\n", + "\n", + " return loss / total, correct / total\n", + "\n", + "def save_checkpoint(state, ckpt_path=\"checkpoint.pth\"):\n", + " torch.save(state, ckpt_path)\n", + " print(\"Checkpoint saved\")\n", + " \n", + "cudnn.benchmark = True\n", + "# Helper function to benchmark the model\n", + "def benchmark(model, input_shape=(1024, 1, 32, 32), dtype='fp32', nwarmup=50, nruns=1000):\n", + " input_data = torch.randn(input_shape)\n", + " input_data = input_data.to(\"cuda\")\n", + " if dtype=='fp16':\n", + " input_data = input_data.half()\n", + " \n", + " with torch.no_grad():\n", + " for _ in range(nwarmup):\n", + " features = model(input_data)\n", + " torch.cuda.synchronize()\n", + "\n", + " timings = []\n", + " with torch.no_grad():\n", + " for i in range(1, nruns+1):\n", + " start_time = time.time()\n", + " output = model(input_data)\n", + " torch.cuda.synchronize()\n", + " end_time = time.time()\n", + " timings.append(end_time - start_time)\n", + " print('Average batch time: %.2f ms'%(np.mean(timings)*1000))" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "02a625c9", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch: [ 1 / 3] LR: 0.000100\n", + "Batch: [ 100 | 295] loss: 2.367\n", + "Batch: [ 200 | 295] loss: 2.271\n", + "Test Loss: 0.03319 Test Acc: 28.71%\n", + "Epoch: [ 2 / 3] LR: 0.000100\n", + "Batch: [ 100 | 295] loss: 2.110\n", + "Batch: [ 200 | 295] loss: 2.023\n", + "Test Loss: 0.02954 Test Acc: 56.79%\n", + "Epoch: [ 3 / 3] LR: 0.000100\n", + "Batch: [ 100 | 295] loss: 1.884\n", + "Batch: [ 200 | 295] loss: 1.816\n", + "Test Loss: 0.02633 Test Acc: 75.56%\n", + "Checkpoint saved\n" + ] + } + ], + "source": [ + "# Train the model for 3 epochs to attain an acceptable accuracy.\n", + "num_epochs=3\n", + "for epoch in range(num_epochs):\n", + " print('Epoch: [%5d / 
%5d] LR: %f' % (epoch + 1, num_epochs, lr))\n", + "\n", + " train(model, train_dataloader, criterion, optimizer, epoch)\n", + " test_loss, test_acc = evaluate(model, val_dataloader, criterion, epoch)\n", + "\n", + " print(\"Test Loss: {:.5f} Test Acc: {:.2f}%\".format(test_loss, 100 * test_acc))\n", + " \n", + "save_checkpoint({'epoch': epoch + 1,\n", + " 'model_state_dict': model.state_dict(),\n", + " 'acc': test_acc,\n", + " 'opt_state_dict': optimizer.state_dict()\n", + " },\n", + " ckpt_path=\"mobilenetv2_base_ckpt\")" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "411d0ebc", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Mobilenetv2 Baseline accuracy: 75.56%\n", + "Average batch time: 11.92 ms\n" + ] + } + ], + "source": [ + "# Evaluate and benchmark the performance of the baseline model\n", + "test_loss, test_acc = evaluate(model, val_dataloader, criterion, 0)\n", + "print(\"Mobilenetv2 Baseline accuracy: {:.2f}%\".format(100 * test_acc))\n", + "\n", + "benchmark(model, input_shape=(64, 3, 224, 224))" + ] + }, + { + "cell_type": "markdown", + "id": "71fdd581", + "metadata": {}, + "source": [ + "\n", + "## 3. Convert to Torch TRT\n", + "\n", + "Torch-TensorRT is a compiler that converts TorchScript graphs into TensorRT engines, helping you leverage TensorRT's inference optimizations on NVIDIA GPUs. Unlike PyTorch's Just-In-Time (JIT) compiler, Torch-TensorRT is an Ahead-of-Time (AOT) compiler, meaning that before you deploy your TorchScript code, you go through an explicit compile step to convert a standard TorchScript program into a module targeting a TensorRT engine. Its simple API gives a performance speedup on NVIDIA GPUs with just one line of code. " + ] + }, + { + "cell_type": "markdown", + "id": "f75ab9fd", + "metadata": {}, + "source": [ + "Let us convert the above FP32 Mobilenetv2 into a TensorRT engine. Before we do that, we need to first export our model to TorchScript. 
To learn more about TorchScript, please refer to https://pytorch.org/docs/stable/jit.html. The model will then be evaluated and benchmarked for performance so we can compare these numbers against the quantized int8 model later. " + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "e24451cf", + "metadata": {}, + "outputs": [], + "source": [ + "# Exporting to TorchScript\n", + "with torch.no_grad():\n", + " # Use a separate name so the torch.utils.data module imported as `data` is not shadowed\n", + " data_iter = iter(val_dataloader)\n", + " images, _ = next(data_iter)\n", + " jit_model = torch.jit.trace(model, images.to(\"cuda\"))\n", + " torch.jit.save(jit_model, \"mobilenetv2_base.jit.pt\")\n", + "\n", + "# Loading the TorchScript model and compiling it into a TensorRT model\n", + "baseline_model = torch.jit.load(\"mobilenetv2_base.jit.pt\").eval()\n", + "compile_spec = {\"inputs\": [torch_tensorrt.Input([64, 3, 224, 224])],\n", + " \"enabled_precisions\": {torch.float}\n", + " }\n", + "trt_base = torch_tensorrt.compile(baseline_model, **compile_spec)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "f3fd416f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Mobilenetv2 TRT Baseline accuracy: 75.59%\n", + "Average batch time: 6.78 ms\n" + ] + } + ], + "source": [ + "# Evaluate and benchmark the performance of the baseline TRT model (TRT FP32 Model)\n", + "test_loss, test_acc = evaluate(trt_base, val_dataloader, criterion, 0)\n", + "print(\"Mobilenetv2 TRT Baseline accuracy: {:.2f}%\".format(100 * test_acc))\n", + "\n", + "benchmark(trt_base, input_shape=(64, 3, 224, 224))" + ] + }, + { + "cell_type": "markdown", + "id": "0a079b97", + "metadata": {}, + "source": [ + "\n", + "## 4. 
Post Training Quantization (PTQ)" + ] + }, + { + "attachments": { + "img4.JPG": { + "image/jpeg": "" + } + }, + "cell_type": "markdown", + "id": "bf3d4397", + "metadata": {}, + "source": [ + "As the name suggests, PTQ is performed on a trained model that has already achieved acceptable accuracy. It is effective and quick to implement because it does not require any retraining of the network. Now that we have the trained checkpoint ready, let's start quantizing the model. \n", + "\n", + "To get good results from PTQ, we pass calibration data (a subset of the training or validation data) through the model to determine the range of FP32 values that should be representable after quantization. This gives us the scale that is used to map the values to the quantized range. We call this process of choosing the input range \"Calibration\". The three popular calibration techniques are:\n", + "\n", + "- Min-Max: Use the minimum and maximum of the FP32 values seen during calibration. The disadvantage of this method is that a single outlier stretches the range, so the mapping induces a larger rounding error for all other values. \n", + "\n", + "- Entropy: Not all values in the FP32 tensor may be equally important. Hence, for different candidate ranges [T1, T2], we measure the information loss (KL divergence) between the original FP32 tensor and the quantized tensor, and pick the range that minimizes it. \n", + "\n", + "- Percentile: Use the percentile of the distribution of absolute values seen during calibration. 
Say, at 99% calibration, we clip 1% of the largest magnitude values, and determine [P1, P2] as the representable range to be quantized\n", + "\n", + "\n", + "![img4.JPG](attachment:img4.JPG)\n", + "\n", + "\n", + "Torch-TensorRT Python API provides an easy and convenient way to use pytorch dataloaders with TensorRT calibrators.\n", + "Here we use `min-max` calibration technique" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "06744b5d", + "metadata": {}, + "outputs": [], + "source": [ + "calibrator = torch_tensorrt.ptq.DataLoaderCalibrator(calib_dataloader,\n", + " use_cache=False,\n", + " algo_type=torch_tensorrt.ptq.CalibrationAlgo.MINMAX_CALIBRATION,\n", + " device=torch.device('cuda:0'))\n", + "\n", + "compile_spec = {\n", + " \"inputs\": [torch_tensorrt.Input([64, 3, 224, 224])],\n", + " \"enabled_precisions\": torch.int8,\n", + " \"calibrator\": calibrator,\n", + " \"truncate_long_and_double\": True\n", + " \n", + " }\n", + "trt_ptq = torch_tensorrt.compile(baseline_model, **compile_spec)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "a5ec3a81", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Mobilenetv2 PTQ accuracy: 71.41%\n", + "Average batch time: 1.57 ms\n" + ] + } + ], + "source": [ + "# Evaluate the PTQ model\n", + "test_loss, test_acc = evaluate(trt_ptq, val_dataloader, criterion, 0)\n", + "print(\"Mobilenetv2 PTQ accuracy: {:.2f}%\".format(100 * test_acc))\n", + "\n", + "benchmark(trt_ptq, input_shape=(64, 3, 224, 224))" + ] + }, + { + "attachments": { + "img5.JPG": { + "image/jpeg": "" + } + }, + "cell_type": "markdown", + "id": "d3e676e7", + "metadata": {}, + "source": [ + "\n", + "## 4. Quantization Aware Training\n", + "\n", + "PTQ resulted in a ~4% accuracy drop. 
After PTQ, a model may sometimes fail to retain its accuracy because PTQ cannot mitigate the large quantization error induced by low-bit quantization. This can happen when the network contains sensitive layers, such as the depthwise convolutions in MobileNets, which are more susceptible to producing large quantization error. \n", + "\n", + "This is when we might want to consider using QAT. The idea behind QAT is simple: you can recover the accuracy lost by the quantized model if you train the model with the quantization error present. There are several ways of doing this, such as training the model from scratch or fine-tuning a pre-trained model. Whichever method you choose, the quantization error is introduced into the training loss by inserting fake-quantization operations. The operation is called “fake” because we quantize the data and immediately dequantize it, producing an approximate version of the data in which both input and output remain floating-point values. In this way we simulate the effects of quantization without changing much in the model. \n", + "In the forward pass, we fake-quantize the weights and activations and use these fake-quantized values to perform the layer operations.\n", + "\n", + "![img5.JPG](attachment:img5.JPG)\n", + "\n", + "In the backward pass, the derivative of the quantization operation is undefined at the step boundaries and zero everywhere else. To handle this, QAT uses the Straight-Through Estimator (STE), which approximates the derivative as 1 for inputs in the representable range. This estimator essentially lets gradients pass unchanged through the operator in the backward pass. When the QAT process is done, the scales that were used to quantize the weights and activations are stored in the model and can be used for inference. 
\n", + "\n", + "We will be using the PyTorch Quantization toolkit, a toolkit built for training and evaluating PyTorch models with simulated quantization. \n", + "\n", + "`quant_modules.initialize()` will ensure quantized modules are called instead of the original modules. For example, when you define a model with convolution, linear and pooling layers, you will make a call to `QuantConv2d`, `QuantLinear` and `QuantPooling` respectively. `QuantConv2d` basically wraps quantizer nodes around the inputs and weights of a regular `Conv2d`. Please refer to the list of quantized modules in the pytorch-quantization toolkit for more information. " + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "f1520afc", + "metadata": {}, + "outputs": [], + "source": [ + "quant_modules.initialize()" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "ee09402f", + "metadata": {}, + "outputs": [], + "source": [ + "# We define Mobilenetv2 again just like we did above\n", + "# All the regular conv, FC layers will be converted to their quantized counterparts due to quant_modules.initialize()\n", + "feature_extract = False\n", + "q_model = models.mobilenet_v2(pretrained=True)\n", + "set_parameter_requires_grad(q_model, feature_extract)\n", + "q_model.classifier[1] = nn.Linear(1280, 10)\n", + "q_model = q_model.cuda()\n", + "\n", + "# mobilenetv2_base_ckpt is the checkpoint generated from Step 2 : Training a baseline Mobilenetv2 model.\n", + "ckpt = torch.load(\"./mobilenetv2_base_ckpt\")\n", + "modified_state_dict={}\n", + "for key, val in ckpt[\"model_state_dict\"].items():\n", + " # Remove 'module.' 
from the key names\n", + " if key.startswith('module.'):\n", + " modified_state_dict[key[7:]] = val\n", + " else:\n", + " modified_state_dict[key] = val\n", + "\n", + "# Load the pre-trained checkpoint\n", + "q_model.load_state_dict(modified_state_dict)\n", + "optimizer.load_state_dict(ckpt[\"opt_state_dict\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "b8726956", + "metadata": {}, + "outputs": [], + "source": [ + "def compute_amax(model, **kwargs):\n", + " # Load calib result\n", + " for name, module in model.named_modules():\n", + " if isinstance(module, quant_nn.TensorQuantizer):\n", + " if module._calibrator is not None:\n", + " if isinstance(module._calibrator, calib.MaxCalibrator):\n", + " module.load_calib_amax()\n", + " else:\n", + " module.load_calib_amax(**kwargs)\n", + " model.cuda()\n", + "\n", + "def collect_stats(model, data_loader, num_batches):\n", + " \"\"\"Feed data to the network and collect statistics\"\"\"\n", + " # Enable calibrators\n", + " for name, module in model.named_modules():\n", + " if isinstance(module, quant_nn.TensorQuantizer):\n", + " if module._calibrator is not None:\n", + " module.disable_quant()\n", + " module.enable_calib()\n", + " else:\n", + " module.disable()\n", + "\n", + " # Feed data to the network for collecting stats\n", + " for i, (image, _) in tqdm(enumerate(data_loader), total=num_batches):\n", + " # Stop before running the model so that exactly num_batches batches are used\n", + " if i >= num_batches:\n", + " break\n", + " model(image.cuda())\n", + "\n", + " # Disable calibrators\n", + " for name, module in model.named_modules():\n", + " if isinstance(module, quant_nn.TensorQuantizer):\n", + " if module._calibrator is not None:\n", + " module.enable_quant()\n", + " module.disable_calib()\n", + " else:\n", + " module.enable()" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "da627181", + "metadata": {}, + "outputs": [], + "source": [ + "# Calibrate the model using the max calibration technique.\n", + "with torch.no_grad():\n", + " 
collect_stats(q_model, train_dataloader, num_batches=32)\n", + " compute_amax(q_model, method=\"max\")" + ] + }, + { + "cell_type": "markdown", + "id": "bcc10e0f", + "metadata": {}, + "source": [ + "Fine-tuning a QAT model is usually quick compared to the full training of the original model. For this Mobilenetv2 model, it is enough to fine-tune for 2 epochs to get acceptable accuracy. \n", + "\n", + "The `tensor_quant` function in the `pytorch_quantization` toolkit is responsible for the above tensor quantization. Usually, per-channel quantization is recommended for weights, while per-tensor quantization is recommended for activations in a network.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "dc144132", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch: [ 1 / 2] LR: 0.001000\n", + "Batch: [ 100 | 295] loss: 1.762\n", + "Batch: [ 200 | 295] loss: 1.771\n", + "Test Loss: 0.02704 Test Acc: 73.74%\n", + "Epoch: [ 2 / 2] LR: 0.001000\n", + "Batch: [ 100 | 295] loss: 1.769\n", + "Batch: [ 200 | 295] loss: 1.774\n", + "Test Loss: 0.02702 Test Acc: 73.54%\n", + "Checkpoint saved\n" + ] + } + ], + "source": [ + "# Finetune the QAT model for 2 epochs\n", + "num_epochs=2\n", + "lr = 0.001\n", + "# The earlier optimizer tracks the parameters of the baseline `model`;\n", + "# create one over q_model's parameters so that fine-tuning updates the QAT model.\n", + "optimizer = optim.SGD(q_model.parameters(), lr=lr)\n", + "for epoch in range(num_epochs):\n", + " print('Epoch: [%5d / %5d] LR: %f' % (epoch + 1, num_epochs, lr))\n", + "\n", + " train(q_model, train_dataloader, criterion, optimizer, epoch)\n", + " test_loss, test_acc = evaluate(q_model, val_dataloader, criterion, epoch)\n", + "\n", + " print(\"Test Loss: {:.5f} Test Acc: {:.2f}%\".format(test_loss, 100 * test_acc))\n", + " \n", + "save_checkpoint({'epoch': epoch + 1,\n", + " 'model_state_dict': q_model.state_dict(),\n", + " 'acc': test_acc,\n", + " 'opt_state_dict': optimizer.state_dict()\n", + " },\n", + " ckpt_path=\"mobilenetv2_qat_ckpt\")" + ] + }, + { + "cell_type": "markdown", + "id": "70bdaeed", + "metadata": {}, + "source": [ + "As you can see, 
accuracy recovered by ~2%. Fine-tuning for more epochs with learning rate annealing can improve accuracy further. It should be noted that the same fine-tuning schedule will improve the accuracy of the unquantized model as well. Please refer to Achieving FP32 Accuracy for INT8 Inference Using Quantization Aware Training with NVIDIA TensorRT for detailed recommendations.\n", + "\n", + "During inference, we use `torch.fake_quantize_per_tensor_affine` and `torch.fake_quantize_per_channel_affine` to perform quantization as this is easier to convert into corresponding TensorRT operators. \n", + "\n", + "Let us now prepare this model to export it into TorchScript. Setting `quant_nn.TensorQuantizer.use_fb_fake_quant = True` enables the QAT model to use `torch.fake_quantize_per_tensor_affine` and `torch.fake_quantize_per_channel_affine` operators instead of `tensor_quant` function to export quantization operators. In torchscript, they are represented as `aten::fake_quantize_per_tensor_affine` and `aten::fake_quantize_per_channel_affine`. 
" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "a234bb71", + "metadata": {}, + "outputs": [], + "source": [ + "quant_nn.TensorQuantizer.use_fb_fake_quant = True\n", + "with torch.no_grad():\n", + " data = iter(val_dataloader)\n", + " images, _ = data.next()\n", + " jit_model = torch.jit.trace(q_model, images.to(\"cuda\"))\n", + " torch.jit.save(jit_model, \"mobilenetv2_qat.jit.pt\")" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "13c5b2c9", + "metadata": {}, + "outputs": [], + "source": [ + "#Loading the Torchscript model and compiling it into a TensorRT model\n", + "qat_model = torch.jit.load(\"mobilenetv2_qat.jit.pt\").eval()\n", + "compile_spec = {\"inputs\": [torch_tensorrt.Input([64, 3, 224, 224])],\n", + " \"enabled_precisions\": torch.int8\n", + " }\n", + "trt_mod = torch_tensorrt.compile(qat_model, **compile_spec)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "b465044d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Mobilenetv2 QAT accuracy using TensorRT: 74.00%\n", + "Average batch time: 2.18 ms\n" + ] + } + ], + "source": [ + "#Evaluate and benchmark the performance of the QAT-TRT model (TRT INT8)\n", + "test_loss, test_acc = evaluate(trt_mod, val_dataloader, criterion, 0)\n", + "print(\"Mobilenetv2 QAT accuracy using TensorRT: {:.2f}%\".format(100 * test_acc))\n", + "benchmark(trt_mod, input_shape=(64, 3, 224, 224))" + ] + }, + { + "cell_type": "markdown", + "id": "20c82807", + "metadata": {}, + "source": [ + "Compared to the TRT FP32 model, we observe a speedup of ~3x with only a ~1.6% loss in accuracy. " + ] + }, + { + "cell_type": "markdown", + "id": "52f311fb", + "metadata": {}, + "source": [ + "### Conclusion\n", + "We put together all the observations that were made in this notebook. 
Note that these numbers can vary from run to run due to the stochastic nature of the training process, but a similar pattern can still be observed.\n", + "\n", + "| Model | Accuracy | Average batch time (batch size 64) |\n", + "| ------------------------ | -------- | ----------- |\n", + "| Baseline MobileNetv2 | 75.56% | 11.92 ms |\n", + "| Base + TRT (TRT FP32) | 75.59% | 6.78 ms |\n", + "| PTQ + TRT (TRT INT8) | 71.41% | 1.57 ms |\n", + "| QAT + TRT (TRT INT8) | 74.00% | 2.18 ms |" + ] + }, + { + "cell_type": "markdown", + "id": "91dfc2c1", + "metadata": {}, + "source": [ + "\n", + "## 7. References\n", + "* Very Deep Convolutional Networks for Large-Scale Image Recognition\n", + "* Achieving FP32 Accuracy for INT8 Inference Using Quantization Aware Training with NVIDIA TensorRT\n", + "* Pytorch-quantization toolkit from NVIDIA\n", + "* Pytorch-quantization toolkit user guide\n", + "* Quantization basics" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.0" + }, + "vscode": { + "interpreter": { + "hash": "b8290132a159428f0004735847c0b4016c8a5153e62fd80cc71ad5cd485f05b0" + } + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}