|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +import logging |
| 8 | +from enum import Enum |
| 9 | + |
| 10 | +import torch |
| 11 | +import torch.nn as nn |
| 12 | +from executorch.examples.models.llama2.llama_transformer import KVCache |
| 13 | +from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401 |
| 14 | + |
| 15 | + |
| 16 | +""" |
| 17 | + Heavily "inspired" by AO's implementation of the same in torchao/_models/llama/model.py |
| 18 | +""" |
| 19 | + |
| 20 | + |
| 21 | +# Doesnt have to abide by affine quantizaiton laws |
| 22 | +# However, if we do implement quantized sdpa, then this might be handy |
| 23 | +class QuantizedCacheType(Enum): |
| 24 | + AffineSymmetric = 0 |
| 25 | + AffineAsymmetric = 1 |
| 26 | + AffineSymmetricGroupWise = 2 |
| 27 | + AffineAsymmetricGroupWise = 3 |
| 28 | + |
| 29 | + |
| 30 | +class QuantizedKVCache(nn.Module): |
| 31 | + def __init__( |
| 32 | + self, |
| 33 | + max_batch_size, |
| 34 | + max_seq_length, |
| 35 | + n_heads, |
| 36 | + head_dim, |
| 37 | + cache_type: QuantizedCacheType = QuantizedCacheType.AffineSymmetric, |
| 38 | + tranposed=False, |
| 39 | + enable_dynamic_shape=False, |
| 40 | + ): |
| 41 | + super().__init__() |
| 42 | + if cache_type not in ( |
| 43 | + QuantizedCacheType.AffineSymmetric, |
| 44 | + QuantizedCacheType.AffineAsymmetric, |
| 45 | + ): |
| 46 | + |
| 47 | + raise ValueError( |
| 48 | + f"Only affine symmetric and asymmetric cache types are supported: got {cache_type}" |
| 49 | + ) |
| 50 | + # For now supporting int8 only |
| 51 | + self.quantized_cache_dtype = torch.int8 |
| 52 | + self.cache_fp_type = torch.float32 |
| 53 | + self.is_transposed = tranposed |
| 54 | + self.enable_dynamic_shape = enable_dynamic_shape |
| 55 | + if self.is_transposed: |
| 56 | + cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim) |
| 57 | + scale_shape = (max_batch_size, n_heads, max_seq_length, 1) |
| 58 | + else: |
| 59 | + cache_shape = (max_batch_size, max_seq_length, n_heads, head_dim) |
| 60 | + scale_shape = (max_batch_size, max_seq_length, n_heads, 1) |
| 61 | + self.register_buffer( |
| 62 | + "k_cache", torch.zeros(cache_shape, dtype=self.quantized_cache_dtype) |
| 63 | + ) |
| 64 | + self.register_buffer( |
| 65 | + "v_cache", torch.zeros(cache_shape, dtype=self.quantized_cache_dtype) |
| 66 | + ) |
| 67 | + self.register_buffer( |
| 68 | + "k_cache_scales", torch.ones(scale_shape, dtype=torch.double) |
| 69 | + ) |
| 70 | + self.register_buffer( |
| 71 | + "v_cache_scales", torch.ones(scale_shape, dtype=torch.double) |
| 72 | + ) |
| 73 | + if cache_type == QuantizedCacheType.AffineAsymmetric: |
| 74 | + self.register_buffer( |
| 75 | + "k_cache_zero_points", torch.ones(scale_shape, dtype=torch.int64) |
| 76 | + ) |
| 77 | + self.register_buffer( |
| 78 | + "v_cache_zero_points", torch.ones(scale_shape, dtype=torch.int64) |
| 79 | + ) |
| 80 | + |
| 81 | + def _quantize(self, value): |
| 82 | + scales, zero_points = ( |
| 83 | + torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default( |
| 84 | + value, self.quantized_cache_dtype |
| 85 | + ) |
| 86 | + ) |
| 87 | + quantized_value = torch.ops.quantized_decomposed.quantize_per_token( |
| 88 | + value, |
| 89 | + scales, |
| 90 | + zero_points, |
| 91 | + torch.iinfo(self.quantized_cache_dtype).min, |
| 92 | + torch.iinfo(self.quantized_cache_dtype).max, |
| 93 | + self.quantized_cache_dtype, |
| 94 | + ) |
| 95 | + return quantized_value, scales, zero_points |
| 96 | + |
| 97 | + def update(self, input_pos, k_val, v_val): |
| 98 | + # quantize current k_val and store it in the cache |
| 99 | + quantized_k_val, k_scales, k_zero_points = self._quantize(k_val) |
| 100 | + |
| 101 | + quantized_v_val, v_scales, v_zero_points = self._quantize(v_val) |
| 102 | + |
| 103 | + if self.enable_dynamic_shape: |
| 104 | + start_pos = input_pos[0].item() |
| 105 | + torch._check_is_size(start_pos) |
| 106 | + dim_to_slice = 2 if self.is_transposed else 1 |
| 107 | + torch._check(start_pos < self.k_cache.size(dim_to_slice)) |
| 108 | + seq_length = k_val.size(dim_to_slice) |
| 109 | + narrowed_k = self.k_cache.narrow(dim_to_slice, start_pos, seq_length) |
| 110 | + narrowed_k_scales = self.k_cache_scales.narrow( |
| 111 | + dim_to_slice, start_pos, seq_length |
| 112 | + ) |
| 113 | + narrowed_k_zp = self.k_cache_zero_points.narrow( |
| 114 | + dim_to_slice, start_pos, seq_length |
| 115 | + ) |
| 116 | + narrowed_k.copy_(quantized_k_val) |
| 117 | + narrowed_k_scales.copy_(k_scales) |
| 118 | + narrowed_k_zp.copy_(k_zero_points) |
| 119 | + # pyre-ignore: Incompatible parameter type [6] |
| 120 | + narrowed_v = self.v_cache.narrow(dim_to_slice, start_pos, seq_length) |
| 121 | + narrowed_v_scales = self.v_cache_scales.narrow( |
| 122 | + dim_to_slice, start_pos, seq_length |
| 123 | + ) |
| 124 | + narrowed_v_zp = self.v_cache_zero_points.narrow( |
| 125 | + dim_to_slice, start_pos, seq_length |
| 126 | + ) |
| 127 | + narrowed_v.copy_(quantized_v_val) |
| 128 | + narrowed_v_scales.copy_(v_scales) |
| 129 | + narrowed_v_zp.copy_(v_zero_points) |
| 130 | + else: |
| 131 | + if self.is_transposed: |
| 132 | + self.k_cache[:, :, input_pos] = quantized_k_val |
| 133 | + self.k_cache_scales[:, :, input_pos] = k_scales |
| 134 | + self.k_cache_zero_points[:, :, input_pos] = k_zero_points |
| 135 | + self.v_cache[:, :, input_pos] = quantized_v_val |
| 136 | + self.v_cache_scales[:, :, input_pos] = v_scales |
| 137 | + self.v_cache_zero_points[:, :, input_pos] = v_zero_points |
| 138 | + else: |
| 139 | + self.k_cache[:, input_pos] = quantized_k_val |
| 140 | + self.k_cache_scales[:, input_pos] = k_scales |
| 141 | + self.k_cache_zero_points[:, input_pos] = k_zero_points |
| 142 | + self.v_cache[:, input_pos] = quantized_v_val |
| 143 | + self.v_cache_scales[:, input_pos] = v_scales |
| 144 | + self.v_cache_zero_points[:, input_pos] = v_zero_points |
| 145 | + |
| 146 | + k_out = torch.ops.quantized_decomposed.dequantize_per_token( |
| 147 | + self.k_cache, |
| 148 | + self.k_cache_scales, |
| 149 | + self.k_cache_zero_points, |
| 150 | + torch.iinfo(self.quantized_cache_dtype).min, |
| 151 | + torch.iinfo(self.quantized_cache_dtype).max, |
| 152 | + self.quantized_cache_dtype, |
| 153 | + self.cache_fp_type, |
| 154 | + ) |
| 155 | + v_out = torch.ops.quantized_decomposed.dequantize_per_token( |
| 156 | + self.v_cache, |
| 157 | + self.v_cache_scales, |
| 158 | + self.v_cache_zero_points, |
| 159 | + torch.iinfo(self.quantized_cache_dtype).min, |
| 160 | + torch.iinfo(self.quantized_cache_dtype).max, |
| 161 | + self.quantized_cache_dtype, |
| 162 | + self.cache_fp_type, |
| 163 | + ) |
| 164 | + return k_out, v_out |
| 165 | + |
| 166 | + @classmethod |
| 167 | + def from_float(cls, kv_cache, cache_type: QuantizedCacheType): |
| 168 | + cache_shape = kv_cache.k_cache.shape |
| 169 | + if kv_cache.is_tranposed: |
| 170 | + max_batch_size, n_heads, max_seq_length, head_dim = cache_shape |
| 171 | + else: |
| 172 | + max_batch_size, max_seq_length, n_heads, head_dim = cache_shape |
| 173 | + return cls( |
| 174 | + max_batch_size, |
| 175 | + max_seq_length, |
| 176 | + n_heads, |
| 177 | + head_dim, |
| 178 | + cache_type, |
| 179 | + kv_cache.is_tranposed, |
| 180 | + kv_cache.enable_dynamic_shape, |
| 181 | + ) |
| 182 | + |
| 183 | + |
| 184 | +def replace_kv_cache_with_quantized_kv_cache(module): |
| 185 | + logging.warning( |
| 186 | + "Replacing KVCache with QuantizedKVCache. This modifies the model in place." |
| 187 | + ) |
| 188 | + for name, child in module.named_children(): |
| 189 | + if isinstance(child, KVCache): |
| 190 | + setattr( |
| 191 | + module, |
| 192 | + name, |
| 193 | + QuantizedKVCache.from_float(child, QuantizedCacheType.AffineAsymmetric), |
| 194 | + ) |
| 195 | + else: |
| 196 | + replace_kv_cache_with_quantized_kv_cache(child) |
| 197 | + return module |
0 commit comments