19 changes: 19 additions & 0 deletions src/transformers/quantizers/base.py
@@ -247,6 +247,25 @@ def _dequantize(self, model):
             f"{self.quantization_config.quant_method} has no implementation of `dequantize`, please raise an issue on GitHub."
         )
 
+    @staticmethod
+    def get_modules_to_not_convert(
+        model: "PreTrainedModel",
+        skip_modules: Optional[List[str]] = None,
+        keep_in_fp32_modules: Optional[List[str]] = None,
+    ):
+        from ..integrations import get_keys_to_not_convert
+
+        modules_to_not_convert = []
+        if skip_modules is None:
+            modules_to_not_convert = get_keys_to_not_convert(model)
+        else:
+            modules_to_not_convert = skip_modules
+
+        if keep_in_fp32_modules is not None:
+            modules_to_not_convert.extend(keep_in_fp32_modules)
+
+        return modules_to_not_convert
+
     @property
     def is_qat_trainable(self) -> bool:
         """Flag indicating whether the quantized model can carry out quantization aware training"""
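Note: this new static helper centralizes the skip-list resolution that each quantizer previously duplicated. A minimal sketch of the call-site pattern, assuming a hypothetical `MyQuantizer` subclass (the real call sites are the quantizer diffs below):

    # Sketch only: `MyQuantizer` is hypothetical; it mirrors the pattern used in the diffs below.
    from transformers.quantizers.base import HfQuantizer

    class MyQuantizer(HfQuantizer):
        def _process_model_before_weight_loading(
            self, model, keep_in_fp32_modules=None, **kwargs
        ):
            # Explicit config entries win; otherwise fall back to
            # get_keys_to_not_convert(model); fp32-only modules are appended.
            self.modules_to_not_convert = self.get_modules_to_not_convert(
                model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
            )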
15 changes: 8 additions & 7 deletions src/transformers/quantizers/quantizer_awq.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import importlib.metadata
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, List, Optional
 
 from packaging import version
 
@@ -96,13 +96,14 @@ def update_torch_dtype(self, torch_dtype):
             logger.warning("We suggest you to set `torch_dtype=torch.float16` for better efficiency with AWQ.")
         return torch_dtype
 
-    def _process_model_before_weight_loading(self, model: "PreTrainedModel", **kwargs):
-        from ..integrations import get_keys_to_not_convert, replace_quantization_scales, replace_with_awq_linear
+    def _process_model_before_weight_loading(
+        self, model: "PreTrainedModel", keep_in_fp32_modules: Optional[List[str]] = None, **kwargs
+    ):
+        from ..integrations import replace_quantization_scales, replace_with_awq_linear
 
-        self.modules_to_not_convert = get_keys_to_not_convert(model)
-
-        if self.quantization_config.modules_to_not_convert is not None:
-            self.modules_to_not_convert.extend(self.quantization_config.modules_to_not_convert)
+        self.modules_to_not_convert = self.get_modules_to_not_convert(
+            model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
+        )
 
         model, has_been_replaced = replace_with_awq_linear(
             model, quantization_config=self.quantization_config, modules_to_not_convert=self.modules_to_not_convert
14 changes: 6 additions & 8 deletions src/transformers/quantizers/quantizer_bitnet.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import TYPE_CHECKING, Dict, List, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Union
 
 from .base import HfQuantizer
 
@@ -81,16 +81,14 @@ def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs
     def _process_model_before_weight_loading(
         self,
         model: "PreTrainedModel",
-        device_map,
-        keep_in_fp32_modules: List[str] = [],
+        keep_in_fp32_modules: Optional[List[str]] = None,
         **kwargs,
     ):
-        from ..integrations import get_keys_to_not_convert, replace_with_bitnet_linear
+        from ..integrations import replace_with_bitnet_linear
 
-        self.modules_to_not_convert = get_keys_to_not_convert(model)
-
-        if self.quantization_config.modules_to_not_convert is not None:
-            self.modules_to_not_convert.extend(self.quantization_config.modules_to_not_convert)
+        self.modules_to_not_convert = self.get_modules_to_not_convert(
+            model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
+        )
 
         model = replace_with_bitnet_linear(
             model,
17 changes: 5 additions & 12 deletions src/transformers/quantizers/quantizer_bnb_4bit.py
@@ -288,23 +288,16 @@ def _process_model_before_weight_loading(
         self,
         model: "PreTrainedModel",
         device_map,
-        keep_in_fp32_modules: List[str] = [],
+        keep_in_fp32_modules: Optional[List[str]] = None,
         **kwargs,
     ):
-        from ..integrations import get_keys_to_not_convert, replace_with_bnb_linear
+        from ..integrations import replace_with_bnb_linear
 
         llm_int8_enable_fp32_cpu_offload = self.quantization_config.llm_int8_enable_fp32_cpu_offload
 
-        # We keep some modules such as the lm_head in their original dtype for numerical stability reasons
-        if self.quantization_config.llm_int8_skip_modules is None:
-            self.modules_to_not_convert = get_keys_to_not_convert(model)
-        else:
-            self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules
-
-        if not isinstance(self.modules_to_not_convert, list):
-            self.modules_to_not_convert = [self.modules_to_not_convert]
-
-        self.modules_to_not_convert.extend(keep_in_fp32_modules)
+        self.modules_to_not_convert = self.get_modules_to_not_convert(
+            model, self.quantization_config.llm_int8_skip_modules, keep_in_fp32_modules
+        )
 
         # Extend `self.modules_to_not_convert` to keys that are supposed to be offloaded to `cpu` or `disk`
         if isinstance(device_map, dict) and len(device_map.keys()) > 1:
17 changes: 5 additions & 12 deletions src/transformers/quantizers/quantizer_bnb_8bit.py
@@ -245,23 +245,16 @@ def _process_model_before_weight_loading(
         self,
         model: "PreTrainedModel",
         device_map,
-        keep_in_fp32_modules: List[str] = [],
+        keep_in_fp32_modules: Optional[List[str]] = None,
        **kwargs,
     ):
-        from ..integrations import get_keys_to_not_convert, replace_with_bnb_linear
+        from ..integrations import replace_with_bnb_linear
 
         llm_int8_enable_fp32_cpu_offload = self.quantization_config.llm_int8_enable_fp32_cpu_offload
 
-        # We keep some modules such as the lm_head in their original dtype for numerical stability reasons
-        if self.quantization_config.llm_int8_skip_modules is None:
-            self.modules_to_not_convert = get_keys_to_not_convert(model)
-        else:
-            self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules
-
-        if not isinstance(self.modules_to_not_convert, list):
-            self.modules_to_not_convert = [self.modules_to_not_convert]
-
-        self.modules_to_not_convert.extend(keep_in_fp32_modules)
+        self.modules_to_not_convert = self.get_modules_to_not_convert(
+            model, self.quantization_config.llm_int8_skip_modules, keep_in_fp32_modules
+        )
 
         # Extend `self.modules_to_not_convert` to keys that are supposed to be offloaded to `cpu` or `disk`
         if isinstance(device_map, dict) and len(device_map.keys()) > 1:
12 changes: 5 additions & 7 deletions src/transformers/quantizers/quantizer_eetq.py
@@ -155,16 +155,14 @@ def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs
     def _process_model_before_weight_loading(
         self,
         model: "PreTrainedModel",
-        device_map,
-        keep_in_fp32_modules: List[str] = [],
+        keep_in_fp32_modules: Optional[List[str]] = None,
         **kwargs,
     ):
-        from ..integrations import get_keys_to_not_convert, replace_with_eetq_linear
+        from ..integrations import replace_with_eetq_linear
 
-        self.modules_to_not_convert = get_keys_to_not_convert(model)
-
-        if self.quantization_config.modules_to_not_convert is not None:
-            self.modules_to_not_convert.extend(self.quantization_config.modules_to_not_convert)
+        self.modules_to_not_convert = self.get_modules_to_not_convert(
+            model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
+        )
 
         model = replace_with_eetq_linear(
             model,
12 changes: 5 additions & 7 deletions src/transformers/quantizers/quantizer_fbgemm_fp8.py
@@ -161,16 +161,14 @@ def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs
     def _process_model_before_weight_loading(
         self,
         model: "PreTrainedModel",
-        device_map,
-        keep_in_fp32_modules: List[str] = [],
+        keep_in_fp32_modules: Optional[List[str]] = None,
         **kwargs,
     ):
-        from ..integrations import get_keys_to_not_convert, replace_with_fbgemm_fp8_linear
+        from ..integrations import replace_with_fbgemm_fp8_linear
 
-        self.modules_to_not_convert = get_keys_to_not_convert(model)
-
-        if self.quantization_config.modules_to_not_convert is not None:
-            self.modules_to_not_convert.extend(self.quantization_config.modules_to_not_convert)
+        self.modules_to_not_convert = self.get_modules_to_not_convert(
+            model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
+        )
 
         model = replace_with_fbgemm_fp8_linear(
             model,
10 changes: 4 additions & 6 deletions src/transformers/quantizers/quantizer_finegrained_fp8.py
@@ -162,16 +162,14 @@ def check_quantized_param(
     def _process_model_before_weight_loading(
         self,
         model: "PreTrainedModel",
-        device_map,
-        modules_to_not_convert: List[str] = [],
+        keep_in_fp32_modules: Optional[List[str]] = None,
         **kwargs,
     ):
         from ..integrations.finegrained_fp8 import replace_with_fp8_linear
 
-        self.modules_to_not_convert = ["lm_head"] + modules_to_not_convert
-
-        if self.quantization_config.modules_to_not_convert:
-            self.modules_to_not_convert.extend(self.quantization_config.modules_to_not_convert)
+        self.modules_to_not_convert = self.get_modules_to_not_convert(
+            model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
+        )
 
         model = replace_with_fp8_linear(
             model,
4 changes: 0 additions & 4 deletions src/transformers/quantizers/quantizer_hqq.py
@@ -273,12 +273,8 @@ def forward_with_device(self, x):
     def _process_model_before_weight_loading(
         self,
         model: "PreTrainedModel",
-        device_map,
-        keep_in_fp32_modules: List[str] = None,
         **kwargs,
     ):
-        keep_in_fp32_modules = keep_in_fp32_modules if keep_in_fp32_modules is not None else []
-
         # Add the corresponding quant_config to each valid module. This allows us to do the actual nn.Linear -> HQQLinear conversion in create_quantized_param().
         # prepare_for_hqq_linear() also sets the right quantization config inside the model (model.config.quantization_config) and the layers (hqq_layer.quant_config)
         model = prepare_for_hqq_linear(model, quantization_config=self.quantization_config)
17 changes: 5 additions & 12 deletions src/transformers/quantizers/quantizer_quanto.py
@@ -177,20 +177,13 @@ def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
             )
 
     def _process_model_before_weight_loading(
-        self, model: "PreTrainedModel", keep_in_fp32_modules: List[str] = [], **kwargs
+        self, model: "PreTrainedModel", keep_in_fp32_modules: Optional[List[str]] = None, **kwargs
     ):
-        from ..integrations import get_keys_to_not_convert, replace_with_quanto_layers
+        from ..integrations import replace_with_quanto_layers
 
-        # We keep some modules such as the lm_head in their original dtype for numerical stability reasons
-        if self.quantization_config.modules_to_not_convert is None:
-            self.modules_to_not_convert = get_keys_to_not_convert(model)
-        else:
-            self.modules_to_not_convert = self.quantization_config.modules_to_not_convert
-
-        if not isinstance(self.modules_to_not_convert, list):
-            self.modules_to_not_convert = [self.modules_to_not_convert]
-
-        self.modules_to_not_convert.extend(keep_in_fp32_modules)
+        self.modules_to_not_convert = self.get_modules_to_not_convert(
+            model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
+        )
 
         model, _ = replace_with_quanto_layers(
             model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config
9 changes: 7 additions & 2 deletions src/transformers/quantizers/quantizer_spqr.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, List, Optional
 
 from .base import HfQuantizer
 
@@ -65,12 +65,17 @@ def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
     def _process_model_before_weight_loading(
         self,
         model: "PreTrainedModel",
+        keep_in_fp32_modules: Optional[List[str]] = None,
         **kwargs,
     ):
+        self.modules_to_not_convert = self.get_modules_to_not_convert(
+            model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
+        )
+
         replace_with_spqr_linear(
             model,
             quantization_config=self.quantization_config,
-            modules_to_not_convert=self.quantization_config.modules_to_not_convert,
+            modules_to_not_convert=self.modules_to_not_convert,
         )
         model.config.quantization_config = self.quantization_config
 
16 changes: 7 additions & 9 deletions src/transformers/quantizers/quantizer_torchao.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 import importlib
 import types
-from typing import TYPE_CHECKING, Union
+from typing import TYPE_CHECKING, Optional, Union
 
 from packaging import version
 
@@ -144,14 +144,12 @@ def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str,
         max_memory = {key: val * 0.9 for key, val in max_memory.items()}
         return max_memory
 
-    def _process_model_before_weight_loading(self, model: "PreTrainedModel", **kwargs):
-        from ..integrations import get_keys_to_not_convert
-
-        self.modules_to_not_convert = get_keys_to_not_convert(model)
-
-        if self.quantization_config.modules_to_not_convert is not None:
-            self.modules_to_not_convert.extend(self.quantization_config.modules_to_not_convert)
-
+    def _process_model_before_weight_loading(
+        self, model: "PreTrainedModel", keep_in_fp32_modules: Optional[List[str]] = None, **kwargs
+    ):
+        self.modules_to_not_convert = self.get_modules_to_not_convert(
+            model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
+        )
         return
 
     def check_quantized_param(
9 changes: 5 additions & 4 deletions src/transformers/quantizers/quantizer_vptq.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, List, Optional
 
 from .base import HfQuantizer
 
@@ -68,6 +68,7 @@ def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
     def _process_model_before_weight_loading(
         self,
         model: "PreTrainedModel",
+        keep_in_fp32_modules: Optional[List[str]] = None,
         **kwargs,
     ):
         """
@@ -76,14 +77,14 @@ def _process_model_before_weight_loading(
         """
         from ..integrations import replace_with_vptq_linear
 
-        modules_to_not_convert = kwargs.get("modules_to_not_convert", []) + (
-            self.quantization_config.modules_to_not_convert or []
+        self.modules_to_not_convert = self.get_modules_to_not_convert(
+            model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
         )
 
         replace_with_vptq_linear(
             model,
             quantization_config=self.quantization_config,
-            modules_to_not_convert=modules_to_not_convert,
+            modules_to_not_convert=self.modules_to_not_convert,
         )
         model.config.quantization_config = self.quantization_config
 
8 changes: 0 additions & 8 deletions src/transformers/utils/quantization_config.py
@@ -1424,8 +1424,6 @@ def __init__(
         tune_metadata: Optional[Dict[str, Any]] = None,
         **kwargs,
     ):
-        if modules_to_not_convert is None:
-            modules_to_not_convert = ["lm_head"]
        if tune_metadata is None:
            tune_metadata = {}
        self.quant_method = QuantizationMethod.HIGGS
@@ -1652,8 +1650,6 @@ def __init__(
        self.bits = bits
        self.beta1 = beta1
        self.beta2 = beta2
-        if modules_to_not_convert is None:
-            modules_to_not_convert = []
        self.modules_to_not_convert = modules_to_not_convert
        self.post_init()
 
@@ -1674,10 +1670,6 @@ def post_init(self):
            raise ValueError("SpQR currently only supports beta1 = 16")
        if self.beta2 != 16:
            raise ValueError("SpQR currently only supports beta2 = 16")
-
-        if self.modules_to_not_convert is not None and not isinstance(self.modules_to_not_convert, list):
-            raise ValueError("modules_to_not_convert must be a list of strings")
-
        if not isinstance(self.shapes, dict):
            raise TypeError("shapes must be a dict")
 
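Note: with these hard-coded defaults and the list check removed from the configs, leaving `modules_to_not_convert=None` now defers resolution to `HfQuantizer.get_modules_to_not_convert` at load time. A standalone restatement of that resolution order, assuming `transformers.integrations.get_keys_to_not_convert` (which typically returns output-embedding modules such as `lm_head`):

    # Sketch only: mirrors the helper added in base.py above.
    from transformers.integrations import get_keys_to_not_convert

    def resolve_skip_list(model, skip_modules=None, keep_in_fp32_modules=None):
        # An explicit skip list wins; otherwise derive one from the model.
        modules = get_keys_to_not_convert(model) if skip_modules is None else list(skip_modules)
        if keep_in_fp32_modules is not None:
            modules.extend(keep_in_fp32_modules)
        return modules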