diff --git a/examples/models/llama2/quantize.py b/examples/models/llama2/quantize.py
index a1acc7695f6..a892e385bc2 100644
--- a/examples/models/llama2/quantize.py
+++ b/examples/models/llama2/quantize.py
@@ -122,6 +122,10 @@ def dynamically_quantize_per_channel(
     return quant, scales, zero_points
 
 
+#########################################################################
+### QuantHandler API definition ###
+
+
 class QuantHandler:
     def __init__(self, mod):
         self.mod = mod
@@ -132,8 +136,15 @@ def create_quantized_state_dict(self) -> Dict:  # "StateDict"
     def convert_for_runtime(self) -> nn.Module:
         pass
 
+    def quantized_model(self) -> nn.Module:
+        model_updated_state_dict = self.create_quantized_state_dict()
+        self.convert_for_runtime()
+        self.mod.load_state_dict(model_updated_state_dict)
+        return self.mod
 
-##### Weight-only int8 per-channel quantized code ######
+
+#########################################################################
+### Weight-only int8 per-channel quantized code ###
 
 
 def replace_linear_weight_only_int8_per_channel(module, node_type):
@@ -151,16 +162,17 @@ def replace_linear_weight_only_int8_per_channel(module, node_type):
             setattr(
                 module,
                 name,
-                WeightOnlyInt8Linear(child.in_features, child.out_features),
+                WeightOnlyInt8Linear("cpu", child.in_features, child.out_features),
             )
         else:
             replace_linear_weight_only_int8_per_channel(child, node_type)
 
 
-class WeightOnlyInt8QuantHandler:
+class WeightOnlyInt8QuantHandler(QuantHandler):
     def __init__(
         self,
         mod,
+        device="cpu",
         *,
         node_type: str = "*",
         bitwidth: Optional[int] = None,
@@ -200,7 +212,7 @@ def create_quantized_state_dict(self) -> Dict:
                 )
             ):
                 print(
-                    f"quantize {self.node_type} {fqn, mod} with groupsize {self.group_size}, bitwidth {self.bitwidth}"
+                    f"quantize {self.node_type} {fqn, mod} with group_size {self.group_size}, bitwidth {self.bitwidth}"
                 )
 
                 # print(f"initial weight shape {mod.weight.shape}")
@@ -217,7 +229,7 @@ def create_quantized_state_dict(self) -> Dict:
                 )
 
                 cur_state_dict[f"{fqn}.weight"] = weight
-                # squeeze makes groupsize=rowsize unidimensional
+                # squeeze makes group_size=rowsize unidimensional
                 cur_state_dict[f"{fqn}.scales"] = scales.squeeze(dim=-1)
 
         return cur_state_dict
@@ -241,10 +253,10 @@ class WeightOnlyInt8Linear(torch.nn.Module):
 
     def __init__(
         self,
+        device,
         in_features: int,
         out_features: int,
         bias: bool = True,
-        device=None,
         dtype=None,
     ) -> None:
         super().__init__()
@@ -260,11 +272,12 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
         # return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales
         return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales
 
-##### embedding table quantization ######
+#########################################################################
+##### embedding table quantization ######
 
 
 def replace_embedding_weight_only_grouped_int8_per_channel(
-    module, bitwidth: int = 8, group_size: Optional[int] = None
+    module, device, bitwidth: int = 8, group_size: Optional[int] = None, packed=False
 ):
     for name, child in module.named_children():
         # print(f"name: {name}")
@@ -275,25 +288,41 @@ def replace_embedding_weight_only_grouped_int8_per_channel(
                 module,
                 name,
                 QuantizedGroupEmbedding(
+                    device=device,
                     vocab_size=child.weight.shape[0],
                     embedding_dim=child.weight.shape[1],
                     group_size=group_size,
+                    packed=packed,
                 ),
             )
         else:
             replace_embedding_weight_only_grouped_int8_per_channel(
-                child, bitwidth, group_size
+                child, device, bitwidth, group_size, packed
             )
 
 
-class EmbeddingOnlyInt8QuantHandler:
-    def __init__(self, mod, *, bitwidth: int = 8, group_size: Optional[int] = None):
+class EmbeddingOnlyInt8QuantHandler(QuantHandler):
+    def __init__(
+        self,
+        mod,
+        device="cpu",
+        *,
+        bitwidth: int = 8,
+        group_size: Optional[int] = None,
+        packed=False,
+    ):
+        if isinstance(packed, str):
+            packed = packed == "True"
         self.mod = mod
+        self.device = device
         self.group_size = group_size
         self.bitwidth = bitwidth
+        self.packed = packed
+        if (bitwidth != 4) and packed:
+            raise RuntimeError("packed is only supported with bitwidth 4")
 
     @torch.no_grad()
-    def create_quantized_state_dict(self) -> Dict:
+    def create_quantized_state_dict(self, packed=False) -> Dict:
         cur_state_dict = self.mod.state_dict()
 
         if self.bitwidth == 4:
@@ -306,18 +335,14 @@ def create_quantized_state_dict(self) -> Dict:
             raise ValueError(f"Unsupported bitwidth {self.bitwidth}")
 
         for fqn, mod in self.mod.named_modules():
-            if (
-                isinstance(mod, nn.Embedding)
-                or isinstance(mod, fsEmbedding)
-                or isinstance(mod, fsStandardEmbedding)
-            ):
+            if isinstance(mod, nn.Embedding):
                 # print("****")
                 # print(f"Embedding identified: {fqn, mod}")
                 # print(f"weights size: {mod.weight.size()}")
                 # print(f"quantize {fqn}...")
                 print(
-                    f"quantize {fqn, mod} with groupsize {self.group_size}, bitwidth {self.bitwidth}"
+                    f"quantize {fqn, mod} with group_size {self.group_size}, bitwidth {self.bitwidth}"
                 )
                 weight, scales, _ = dynamically_quantize_per_channel(
                     mod.weight.float(),
@@ -328,21 +353,36 @@ def create_quantized_state_dict(self) -> Dict:
                     scales_dtype=mod.weight.dtype,
                 )
 
+                if packed:
+                    if weight.shape[-1] % 2 != 0:
+                        raise RuntimeError("automatic padding not implemented yet")
+
+                    weight_range_shifted = weight.add(8).view(torch.uint8)
+                    weight_view = weight_range_shifted.view(
+                        weight.shape[0], weight.shape[1] // 2, 2
+                    )
+                    weight_even = weight_view[:, :, 0] * 16  # left shift 4
+                    weight_odd = weight_view[:, :, 1]
+                    weight_packed = weight_even + weight_odd
+                    weight = weight_packed
+
+                weight = weight.to(device=self.device)
+                scales = scales.to(device=self.device)
                 # Update state dict
                 cur_state_dict[f"{fqn}.weight"] = weight
-                # squeeze makes groupsize=rowsize unidimensional
+                # squeeze makes group_size=rowsize unidimensional
                 cur_state_dict[f"{fqn}.scales"] = scales.squeeze(dim=-1)
 
         return cur_state_dict
 
     def convert_for_runtime(self) -> nn.Module:
         replace_embedding_weight_only_grouped_int8_per_channel(
-            self.mod, self.bitwidth, self.group_size
+            self.mod, self.device, self.bitwidth, self.group_size, self.packed
         )
         return self.mod
 
     def quantized_model(self) -> nn.Module:
-        model_updated_state_dict = self.create_quantized_state_dict()
+        model_updated_state_dict = self.create_quantized_state_dict(self.packed)
         self.convert_for_runtime()
         self.mod.load_state_dict(model_updated_state_dict)
         return self.mod
@@ -351,39 +391,53 @@ def quantized_model(self) -> nn.Module:
 class QuantizedGroupEmbedding(torch.nn.Module):
     def __init__(
         self,
+        device,
         vocab_size: int,
         embedding_dim: int,
         group_size: Optional[int] = None,
-        device=None,
         dtype=torch.half,
+        packed=False,
     ) -> None:
         super().__init__()
-        if group_size is None:
+        if group_size is None or group_size == 0:
             group_size = embedding_dim
         self.group_size = group_size
         self.dtype = dtype
-        self.register_buffer(
-            "weight", torch.empty((vocab_size, embedding_dim), dtype=torch.int8)
-        )
+        self.packed = packed
+        if not packed:
+            self.register_buffer(
+                "weight",
+                torch.empty(
+                    (vocab_size, embedding_dim), dtype=torch.int8, device=device
+                ),
+            )
+        else:  # packed
+            self.register_buffer(
+                "weight",
+                torch.empty(
+                    (vocab_size, embedding_dim // 2), dtype=torch.uint8, device=device
+                ),
+            )
         groups_per_row = (embedding_dim + group_size - 1) // group_size
         if groups_per_row > 1:
             self.register_buffer(
-                "scales", torch.ones((vocab_size, groups_per_row), dtype=torch.float16)
+                "scales",
+                torch.ones(
+                    (vocab_size, groups_per_row), dtype=torch.float16, device=device
+                ),
            )
         else:
             self.register_buffer(
-                "scales", torch.ones((vocab_size,), dtype=torch.float16)
+                "scales", torch.ones((vocab_size,), dtype=torch.float16, device=device)
             )
 
     @torch.no_grad()
     def forward(self, indices: torch.Tensor) -> torch.Tensor:
-        return torch.ops.llama_quantized.DEPRECATED_DO_NOT_USE_embedding_byte.dtype(
-            self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
-        )
-
-
-# result_weights = self.weight.index_select(0, indices.view(-1))
-# result_scales = self.scales.index_select(0, indices.view(-1))
-#
-# r = result_weights.to(dtype=result_scales.dtype) * result_scales
-# return r.view(indices.size() + (-1,))
+        if not self.packed:  # 8bit
+            return torch.ops.llama_quantized.DEPRECATED_DO_NOT_USE_embedding_byte.dtype(
+                self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
+            )
+        else:  # 4bit packed
+            return torch.ops.llama_quantized.embedding_4bit.dtype(
+                self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
+            )
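Reviewer note: with `quantized_model()` hoisted onto the `QuantHandler` base class, callers get the whole create-state-dict / convert / load sequence in one call. A minimal usage sketch, not part of the patch; the toy module and its sizes are made up, and actually running the converted embedding still requires the `llama_quantized` custom ops to be registered:

```python
import torch.nn as nn

# Hypothetical toy model, for illustration only.
toy = nn.Sequential(nn.Embedding(128, 16), nn.Linear(16, 16))

# quantized_model() runs create_quantized_state_dict(), then
# convert_for_runtime() (which swaps nn.Embedding for
# QuantizedGroupEmbedding), then load_state_dict().
handler = EmbeddingOnlyInt8QuantHandler(toy, "cpu", bitwidth=8, group_size=16)
quantized_toy = handler.quantized_model()
```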
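Reviewer note: the packing path in `create_quantized_state_dict` shifts each int4 value from [-8, 7] into [0, 15] and stores two values per byte: even-indexed elements go in the high nibble (the `* 16` is a left shift by 4) and odd-indexed elements in the low nibble, so a row of n int4 values occupies n // 2 uint8 bytes. A self-contained round-trip sketch, for illustration only; the unpack mirrors what the 4-bit kernel has to undo:

```python
import torch

def pack_int4(weight: torch.Tensor) -> torch.Tensor:
    # weight: int8 tensor of 4-bit values in [-8, 7], even last dim.
    assert weight.shape[-1] % 2 == 0, "automatic padding not implemented yet"
    shifted = weight.add(8).view(torch.uint8)  # [-8, 7] -> [0, 15]
    pairs = shifted.view(weight.shape[0], weight.shape[1] // 2, 2)
    return pairs[:, :, 0] * 16 + pairs[:, :, 1]  # high nibble | low nibble

def unpack_int4(packed: torch.Tensor) -> torch.Tensor:
    # Inverse: split the nibbles and shift back to [-8, 7].
    high = packed // 16
    low = packed % 16
    interleaved = torch.stack([high, low], dim=-1).view(packed.shape[0], -1)
    return interleaved.to(torch.int8).sub(8)

w = torch.randint(-8, 8, (4, 6), dtype=torch.int8)
assert torch.equal(unpack_int4(pack_int4(w)), w)
```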
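Reviewer note: the deleted commented-out block under `forward` was the eager reference for the byte-embedding lookup. Reconstructed here as a runnable sketch for the unpacked, one-scale-per-row case, handy for checking numerics without the custom ops; the `unsqueeze(-1)` is an addition so the per-row scales broadcast, not something the original comment had:

```python
import torch

def eager_embedding_byte(
    weight: torch.Tensor,  # (vocab_size, embedding_dim), int8
    scales: torch.Tensor,  # (vocab_size,), one scale per row
    indices: torch.Tensor,
    dtype=torch.half,
) -> torch.Tensor:
    result_weights = weight.index_select(0, indices.view(-1))
    result_scales = scales.index_select(0, indices.view(-1))
    # unsqueeze so the (N,) scales broadcast over (N, embedding_dim) rows
    r = result_weights.to(dtype=result_scales.dtype) * result_scales.unsqueeze(-1)
    return r.view(indices.size() + (-1,)).to(dtype)
```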