...
 # LICENSE file in the root directory of this source tree.

 import copy
-from typing import Callable, List, Tuple
+from typing import Callable, List, Optional, Tuple

 import torch
 from executorch.exir.dialects._ops import bind_pattern_to_op, ops as exir_ops
...
 )
 from torch import fx
 from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib
+from torch.library import impl, impl_abstract


 __all__ = [
...

 quantized_decomposed_lib.define(
     "embedding_byte.dtype(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
+    "int weight_quant_min, int weight_quant_max, Tensor indices, ScalarType? dtype=None) -> Tensor",
+)
+
+quantized_decomposed_lib.define(
+    "embedding_byte.out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
+    "int weight_quant_min, int weight_quant_max, Tensor indices, *, Tensor(a!) out) -> Tensor(a!)",
+)
+
+quantized_decomposed_lib.define(
+    "embedding_byte.dtype_out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
+    "int weight_quant_min, int weight_quant_max, Tensor indices, ScalarType? dtype=None, *, Tensor(a!) out) -> Tensor(a!)",
+)
+
+
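+# Shared sanity checks for the quantized embedding ops: weights must be 2-D
+# int8/uint8; scales must be fp16/fp32 with rank 1 (per row) or rank 2 (per
+# group) and one entry set per weight row; zero points, when present, must
+# match the scales' dtype and row count.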
+def embedding_weight_checks(weight, weight_scales, weight_zero_points):
+    assert weight.dtype in [
+        torch.int8,
+        torch.uint8,
+    ], f"Expecting weights to be of dtype in [torch.int8, torch.uint8], but got {weight.dtype}"
+    assert (
+        weight.dim() == 2
+    ), f"Expecting weight tensor to have dim()==2, but found {weight.dim()}"
+
+    assert weight_scales.dtype in [
+        torch.float16,
+        torch.float32,
+    ], f"Expecting weight_scales to be of dtype in [torch.float16, torch.float32], but got {weight_scales.dtype}"
+    assert (
+        weight_scales.dim() == 1 or weight_scales.dim() == 2
+    ), f"Expecting weight_scales tensor to have rank 1 or 2, but found {weight_scales.dim()}"
+    assert weight_scales.size(0) == weight.size(
+        0
+    ), f"Expecting weight and scale tensor to have same number of rows, but found {weight.size()} and {weight_scales.size()}"
+
+    assert (
+        weight_zero_points is None or weight_zero_points.dtype == weight_scales.dtype
+    ), "Expecting weight_zero_points to be None or have same dtype as weight_scales"
+    assert (
+        weight_zero_points is None or weight_zero_points.dim() == 1
+    ), f"Expecting weight_zero_points tensor to be None or have dim()==1, but found {weight_zero_points.dim()}"
+    assert weight_zero_points is None or weight_zero_points.size(0) == weight.size(
+        0
+    ), f"Expecting weight_zero_points tensor to be None or have same number of rows as weights, but found {weight.size()} and {weight_zero_points.size()}"
+
+
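+# embedding_byte: dequantize the whole (possibly group-wise quantized) weight
+# table to the scales' dtype, then do an ordinary aten.embedding lookup.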
+@impl(quantized_decomposed_lib, "embedding_byte", "CompositeExplicitAutograd")
+def embedding_byte(
+    weight: torch.Tensor,
+    weight_scales: torch.Tensor,
+    weight_zero_points: Optional[torch.Tensor],
+    weight_quant_min: int,
+    weight_quant_max: int,
+    indices: torch.Tensor,
+) -> torch.Tensor:
+    embedding_weight_checks(weight, weight_scales, weight_zero_points)
+    group_size = weight.size(1) // (
+        weight_scales.size(1) if weight_scales.dim() == 2 else 1
+    )
+    weight = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(
+        weight,
+        weight_scales,
+        weight_zero_points,
+        weight_quant_min,
+        weight_quant_max,
+        weight.dtype,
+        group_size,
+        weight_scales.dtype,
+    )
+    return torch.ops.aten.embedding.default(weight, indices)
+
+
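+# Illustrative call (example values, not part of this change):
+#   weight = torch.randint(-128, 127, (100, 64), dtype=torch.int8)
+#   scales = torch.rand(100)  # one fp32 scale per row
+#   out = torch.ops.quantized_decomposed.embedding_byte(
+#       weight, scales, None, -128, 127, torch.tensor([0, 5, 2]))
+
+# Abstract (meta) kernel for the out variant: delegating to the functional op
+# gives export/tracing the right output shape and dtype; the concrete
+# out-variant kernel is registered elsewhere.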
+@impl_abstract("quantized_decomposed::embedding_byte.out")
+def embedding_byte_out_meta(
+    weight: torch.Tensor,
+    weight_scales: torch.Tensor,
+    weight_zero_points: Optional[torch.Tensor],
+    weight_quant_min: int,
+    weight_quant_max: int,
+    indices: torch.Tensor,
+    out: torch.Tensor,
+) -> torch.Tensor:
+    return embedding_byte(
+        weight,
+        weight_scales,
+        weight_zero_points,
+        weight_quant_min,
+        weight_quant_max,
+        indices,
+    )
+
+
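+# Same as embedding_byte, but dequantizes to the caller-requested dtype
+# instead of the scales' dtype.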
+@impl(quantized_decomposed_lib, "embedding_byte.dtype", "CompositeExplicitAutograd")
+def embedding_byte_dtype(
+    weight: torch.Tensor,
+    weight_scales: torch.Tensor,
+    weight_zero_points: Optional[torch.Tensor],
+    weight_quant_min: int,
+    weight_quant_max: int,
+    indices: torch.Tensor,
+    dtype: Optional[torch.dtype],
+) -> torch.Tensor:
+    embedding_weight_checks(weight, weight_scales, weight_zero_points)
+    group_size = weight.size(1) // (
+        weight_scales.size(1) if weight_scales.dim() == 2 else 1
+    )
+    weight = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(
+        weight,
+        weight_scales,
+        weight_zero_points,
+        weight_quant_min,
+        weight_quant_max,
+        weight.dtype,
+        group_size,
+        dtype,
+    )
+    return torch.ops.aten.embedding.default(weight, indices)
+
+
+@impl_abstract("quantized_decomposed::embedding_byte.dtype_out")
+def embedding_byte_dtype_out_meta(
+    weight: torch.Tensor,
+    weight_scales: torch.Tensor,
+    weight_zero_points: Optional[torch.Tensor],
+    weight_quant_min: int,
+    weight_quant_max: int,
+    indices: torch.Tensor,
+    dtype: Optional[torch.dtype],
+    out: torch.Tensor,
+) -> torch.Tensor:
+    return embedding_byte_dtype(
+        weight,
+        weight_scales,
+        weight_zero_points,
+        weight_quant_min,
+        weight_quant_max,
+        indices,
+        dtype,
+    )
+
+
+quantized_decomposed_lib.define(
+    "embedding_4bit(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
+    "int weight_quant_min, int weight_quant_max, Tensor indices) -> Tensor",
+)
+
+quantized_decomposed_lib.define(
+    "embedding_4bit.dtype(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
     "int weight_quant_min, int weight_quant_max, Tensor indices, *, ScalarType? dtype=None) -> Tensor",
 )

+quantized_decomposed_lib.define(
+    "embedding_4bit.out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
+    "int weight_quant_min, int weight_quant_max, Tensor indices, *, Tensor(a!) out) -> Tensor(a!)",
+)
+
+quantized_decomposed_lib.define(
+    "embedding_4bit.dtype_out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
+    "int weight_quant_min, int weight_quant_max, Tensor indices, ScalarType? dtype=None, *, Tensor(a!) out) -> Tensor(a!)",
+)
+
+
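+# embedding_4bit: each byte of `weight` packs two 4-bit codes (high nibble
+# first). Unpack with div/remainder, shift from [0, 15] to signed [-8, 7],
+# then dequantize and look up as usual. The unpacked row is twice as wide as
+# the stored one, hence the 2 * weight.size(1) in the group-size math.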
+@impl(quantized_decomposed_lib, "embedding_4bit", "CompositeExplicitAutograd")
+def embedding_4bit(
+    weight: torch.Tensor,
+    weight_scales: torch.Tensor,
+    weight_zero_points: Optional[torch.Tensor],
+    weight_quant_min: int,
+    weight_quant_max: int,
+    indices: torch.Tensor,
+) -> torch.Tensor:
+    embedding_weight_checks(weight, weight_scales, weight_zero_points)
+    # The packed tensor holds two 4-bit codes per byte, so the logical row
+    # width is 2 * weight.size(1).
+    group_size = (2 * weight.size(1)) // (
+        weight_scales.size(1) if weight_scales.dim() == 2 else 1
+    )
+    # Split each byte into its high and low nibble ...
+    weight_even = weight.div(16, rounding_mode="trunc")  # high nibble
+    weight_odd = weight.remainder(16)  # low nibble
+    # ... interleave them back into one row of unpacked codes ...
+    weight_unpacked = torch.stack((weight_even, weight_odd), dim=-1)
+    weight = weight_unpacked.view(weight.shape[0], -1)
+    # ... then reinterpret as int8 and shift from [0, 15] to [-8, 7].
+    weight = weight.view(torch.int8).add(-8)
+
+    weight = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(
+        weight,
+        weight_scales,
+        weight_zero_points,
+        weight_quant_min,
+        weight_quant_max,
+        weight.dtype,
+        group_size,
+        weight_scales.dtype,
+    )
+    return torch.ops.aten.embedding.default(weight, indices)
+
+
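+# Worked example of the unpacking above (illustrative): byte value 60 holds
+# the codes (60 // 16, 60 % 16) = (3, 12), which become (-5, 4) after add(-8).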
+@impl_abstract("quantized_decomposed::embedding_4bit.out")
+def embedding_4bit_out_meta(
+    weight: torch.Tensor,
+    weight_scales: torch.Tensor,
+    weight_zero_points: Optional[torch.Tensor],
+    weight_quant_min: int,
+    weight_quant_max: int,
+    indices: torch.Tensor,
+    out: torch.Tensor,
+) -> torch.Tensor:
+    return embedding_4bit(
+        weight,
+        weight_scales,
+        weight_zero_points,
+        weight_quant_min,
+        weight_quant_max,
+        indices,
+    )
+
+
+@impl(quantized_decomposed_lib, "embedding_4bit.dtype", "CompositeExplicitAutograd")
+def embedding_4bit_dtype(
+    weight: torch.Tensor,
+    weight_scales: torch.Tensor,
+    weight_zero_points: Optional[torch.Tensor],
+    weight_quant_min: int,
+    weight_quant_max: int,
+    indices: torch.Tensor,
+    dtype: Optional[torch.dtype],
+) -> torch.Tensor:
+    embedding_weight_checks(weight, weight_scales, weight_zero_points)
+    group_size = (2 * weight.size(1)) // (
+        weight_scales.size(1) if weight_scales.dim() == 2 else 1
+    )
+    # Same nibble unpacking as embedding_4bit; only the output dtype differs.
+    weight_even = weight.div(16, rounding_mode="trunc")
+    weight_odd = weight.remainder(16)
+    weight_unpacked = torch.stack((weight_even, weight_odd), dim=-1)
+    weight = weight_unpacked.view(weight.shape[0], -1)
+    weight = weight.view(torch.int8).add(-8)
+
+    weight = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(
+        weight,
+        weight_scales,
+        weight_zero_points,
+        weight_quant_min,
+        weight_quant_max,
+        weight.dtype,
+        group_size,
+        dtype,
+    )
+    return torch.ops.aten.embedding.default(weight, indices)
+
+
+@impl_abstract("quantized_decomposed::embedding_4bit.dtype_out")
+def embedding_4bit_dtype_out_meta(
+    weight: torch.Tensor,
+    weight_scales: torch.Tensor,
+    weight_zero_points: Optional[torch.Tensor],
+    weight_quant_min: int,
+    weight_quant_max: int,
+    indices: torch.Tensor,
+    dtype: Optional[torch.dtype],
+    out: torch.Tensor,
+) -> torch.Tensor:
+    return embedding_4bit_dtype(
+        weight,
+        weight_scales,
+        weight_zero_points,
+        weight_quant_min,
+        weight_quant_max,
+        indices,
+        dtype,
+    )
+
+
 quantized_decomposed_lib.define(
     "mixed_mm(Tensor input, Tensor weight, Tensor weight_scales, Tensor? weight_zero_points) -> Tensor",
 )