diff --git a/src/transformers/convert_slow_tokenizer.py b/src/transformers/convert_slow_tokenizer.py
index 5716ee4bf5cf..a9e7c9bff5bc 100644
--- a/src/transformers/convert_slow_tokenizer.py
+++ b/src/transformers/convert_slow_tokenizer.py
@@ -1584,7 +1584,9 @@ def __init__(
         self.pattern = pattern
         self.add_prefix_space = add_prefix_space
         self.additional_special_tokens = (
-            additional_special_tokens.keys() if type(additional_special_tokens) is dict else additional_special_tokens
+            additional_special_tokens.keys()
+            if isinstance(additional_special_tokens, dict)
+            else additional_special_tokens
         )

     def extract_vocab_merges_from_model(self, tiktoken_url: str):
diff --git a/src/transformers/integrations/bitnet.py b/src/transformers/integrations/bitnet.py
index d1fa65978dbc..492d6f123c9d 100644
--- a/src/transformers/integrations/bitnet.py
+++ b/src/transformers/integrations/bitnet.py
@@ -124,7 +124,16 @@ def unpack_weights(packed: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:


 class BitLinear(nn.Module):
-    def __init__(self, in_features: int, out_features: int, bias: bool, device=None, dtype=None):
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias: bool,
+        device=None,
+        dtype=None,
+        use_rms_norm: bool = False,
+        rms_norm_eps: float = 1e-6,
+    ):
         super().__init__()
         self.dtype = dtype
         self.in_features = in_features
@@ -150,6 +159,13 @@ def __init__(self, in_features: int, out_features: int, bias: bool, device=None,
         else:
             self.bias = None

+        # Optional RMSNorm (applied on the activations before quantization).
+        self.rms_norm = None
+        if use_rms_norm:
+            from ..models.llama.modeling_llama import LlamaRMSNorm
+
+            self.rms_norm = LlamaRMSNorm(in_features, eps=rms_norm_eps)
+
     @torch.compile
     def activation_quant(self, input, num_bits=8):
         """
@@ -180,6 +196,10 @@ def post_quant_process(self, input, input_scale, weight_scale):
         return out

     def forward(self, input):
+        # Apply RMSNorm on the input if requested.
+        if self.rms_norm is not None:
+            input = self.rms_norm(input)
+
         w = self.weight
         w_quant = unpack_weights(w, dtype=self.dtype)
         input_quant, input_scale = self.activation_quant(input)
@@ -245,9 +265,17 @@ def __init__(
         device=None,
         dtype=None,
         online_quant: bool = False,
+        use_rms_norm: bool = False,
+        rms_norm_eps: float = 1e-6,
     ):
         super().__init__(in_features, out_features, bias)
         self.online_quant = online_quant
+        # Optional RMSNorm
+        self.rms_norm = None
+        if use_rms_norm:
+            from ..models.llama.modeling_llama import LlamaRMSNorm
+
+            self.rms_norm = LlamaRMSNorm(in_features, eps=rms_norm_eps)
         if not online_quant:
             self.register_buffer(
                 "weight_scale",
@@ -271,6 +299,10 @@ def load_hook(
         return state_dict

     def forward(self, input):
+        # Optional RMSNorm on activations prior to quantization.
+        if self.rms_norm is not None:
+            input = self.rms_norm(input)
+
         if self.online_quant:
             weight = WeightQuant.apply(self.weight)
         else:
@@ -318,6 +350,8 @@ def _replace_with_bitnet_linear(
                             device=module.weight.device,
                             dtype=module.weight.dtype,
                             online_quant=(quantization_config.quantization_mode == "online"),
+                            use_rms_norm=quantization_config.use_rms_norm,
+                            rms_norm_eps=quantization_config.rms_norm_eps,
                         )
                         if quantization_config.quantization_mode == "offline":
                             model._modules[name].requires_grad_(False)
@@ -328,6 +362,8 @@ def _replace_with_bitnet_linear(
                             bias=module.bias is not None,
                             device=module.weight.device,
                             dtype=module.weight.dtype,
+                            use_rms_norm=quantization_config.use_rms_norm,
+                            rms_norm_eps=quantization_config.rms_norm_eps,
                         )
                         model._modules[name].requires_grad_(False)
                         has_been_replaced = True
@@ -363,7 +399,7 @@ def replace_with_bitnet_linear(
         model (`torch.nn.Module`):
             Input model or `torch.nn.Module` as the function is run recursively.
         modules_to_not_convert (`List[`str`]`, *optional*, defaults to `["lm_head"]`):
-            Names of the modules to not convert in `EetqLinear`. In practice we keep the `lm_head` in full precision
+            Names of the modules to not convert in `BitLinear`. In practice we keep the `lm_head` in full precision
             for numerical stability reasons.
         current_key_name (`List[`str`]`, *optional*):
             An array to track the current key of the recursion. This is used to check whether the current key (part of
diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py
index ec2a6c76deee..ee9a9c36af2f 100644
--- a/src/transformers/utils/quantization_config.py
+++ b/src/transformers/utils/quantization_config.py
@@ -1791,6 +1791,11 @@ class BitNetQuantConfig(QuantizationConfigMixin):
             In `offline` mode, quantization parameters are pre-calculated *before* inference. These parameters
             are then fixed and loaded into the quantized model. This generally results in lower runtime
             overhead compared to online quantization.
+        use_rms_norm (`bool`, *optional*, defaults to `False`):
+            Whether to apply RMSNorm on the activations before quantization. This matches the original BitNet paper's approach
+            of normalizing activations before quantization/packing.
+        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+            The epsilon value used in the RMSNorm layer for numerical stability.
         kwargs (`Dict[str, Any]`, *optional*):
             Additional keyword arguments that may be used by specific quantization backends or future versions.

@@ -1801,6 +1806,8 @@ def __init__(
         modules_to_not_convert: Optional[List] = None,
         linear_class: Optional[str] = "bitlinear",
         quantization_mode: Optional[str] = "offline",
+        use_rms_norm: Optional[bool] = False,
+        rms_norm_eps: Optional[float] = 1e-6,
         **kwargs,
     ):
         if linear_class not in ["bitlinear", "autobitlinear"]:
@@ -1811,6 +1818,8 @@ def __init__(
         self.modules_to_not_convert = modules_to_not_convert
         self.linear_class = linear_class
         self.quantization_mode = quantization_mode
+        self.use_rms_norm = use_rms_norm
+        self.rms_norm_eps = rms_norm_eps
         self.post_init()

     def post_init(self):
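Usage sketch (reviewer note, not part of the patch): the snippet below shows how the two new config arguments are meant to reach the replaced layers. It constructs a `BitNetQuantConfig` with the new flags and calls `replace_with_bitnet_linear` directly on a toy `nn.Sequential`; the toy model and the direct call are illustrative assumptions, since in normal use the BitNet quantizer triggers this replacement inside `from_pretrained()`. `accelerate` must be installed because the replacement runs under `init_empty_weights`.

```python
# Minimal sketch (assumptions noted above): the two new flags travel from
# BitNetQuantConfig into every replaced BitLinear / AutoBitLinear module.
import torch.nn as nn

from transformers import BitNetQuantConfig
from transformers.integrations.bitnet import replace_with_bitnet_linear

quantization_config = BitNetQuantConfig(
    linear_class="bitlinear",     # default; "autobitlinear" accepts the same two flags
    quantization_mode="offline",  # default
    use_rms_norm=True,            # new: RMSNorm on activations before quantization
    rms_norm_eps=1e-6,            # new: epsilon forwarded to LlamaRMSNorm
)

toy = nn.Sequential(nn.Linear(64, 64, bias=False))  # stand-in for a real checkpoint
toy = replace_with_bitnet_linear(toy, quantization_config=quantization_config)

# The replaced module now owns an RMSNorm that runs at the top of forward().
# Weights are created empty here, so this is a structural check, not a numerical one.
print(type(toy[0]).__name__)  # BitLinear
print(toy[0].rms_norm)        # e.g. LlamaRMSNorm((64,), eps=1e-06)
```

With `use_rms_norm=False` (the default) the `rms_norm` attribute stays `None` and `forward()` is unchanged, so existing configs and checkpoints are unaffected.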