78 changes: 78 additions & 0 deletions examples/quantization/custom_quantization.py
@@ -0,0 +1,78 @@
import json
from typing import Any, Dict

import torch

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.quantizers import HfQuantizer, register_quantization_config, register_quantizer
from transformers.utils.quantization_config import QuantizationConfigMixin


@register_quantization_config("custom")
class CustomConfig(QuantizationConfigMixin):
    def __init__(self):
        self.quant_method = "custom"
        self.bits = 8

    def to_dict(self) -> Dict[str, Any]:
        # Only the bit width is serialized in this minimal example.
        output = {
            "num_bits": self.bits,
        }
        return output

    def __repr__(self):
        config_dict = self.to_dict()
        return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n"

    def to_diff_dict(self) -> Dict[str, Any]:
        # Serialize only the attributes that differ from the default config.
        config_dict = self.to_dict()
        default_config_dict = CustomConfig().to_dict()

        serializable_config_dict = {}
        for key, value in config_dict.items():
            if value != default_config_dict[key]:
                serializable_config_dict[key] = value

        return serializable_config_dict


@register_quantizer("custom")
class CustomQuantizer(HfQuantizer):
    def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
        super().__init__(quantization_config, **kwargs)
        self.quantization_config = quantization_config
        self.scale_map = {}
        self.device = kwargs.get("device", "cuda" if torch.cuda.is_available() else "cpu")
        self.torch_dtype = kwargs.get("torch_dtype", torch.float32)

    def _process_model_before_weight_loading(self, model, **kwargs):
        # No-op: a real quantizer would swap in quantized modules here.
        return True

    def _process_model_after_weight_loading(self, model, **kwargs):
        # No-op: a real quantizer would compute scales and pack weights here.
        return True

    def is_serializable(self, safe_serialization=None) -> bool:
        # Match the base-class signature, which accepts `safe_serialization`.
        return True

    @property
    def is_trainable(self) -> bool:
        # `HfQuantizer.is_trainable` is a property, so override it as one.
        return False


model_8bit = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m", quantization_config=CustomConfig(), torch_dtype="auto"
)

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
input_text = "once there is"
inputs = tokenizer(input_text, return_tensors="pt")
output = model_8bit.generate(
    **inputs,
    max_length=100,
    num_return_sequences=1,
    no_repeat_ngram_size=2,
)
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)

print(generated_text)
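
As a quick sanity check (a sketch layered on top of the script above, not part of the diff), the two decorators populate the same module-level registries that the `auto.py` changes below write to:

from transformers.quantizers.auto import AUTO_QUANTIZATION_CONFIG_MAPPING, AUTO_QUANTIZER_MAPPING

# Both decorators key the registries by method name.
assert AUTO_QUANTIZATION_CONFIG_MAPPING["custom"] is CustomConfig
assert AUTO_QUANTIZER_MAPPING["custom"] is CustomQuantizer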
6 changes: 4 additions & 2 deletions src/transformers/modeling_utils.py
@@ -3706,8 +3706,10 @@ def from_pretrained(
        device_map = hf_quantizer.update_device_map(device_map)

        # In order to ensure popular quantization methods are supported. Can be disabled with `disable_telemetry`
-        user_agent["quant"] = hf_quantizer.quantization_config.quant_method.value
-
+        if hasattr(hf_quantizer.quantization_config.quant_method, "value"):
+            user_agent["quant"] = hf_quantizer.quantization_config.quant_method.value
+        else:
+            user_agent["quant"] = hf_quantizer.quantization_config.quant_method
        # Force-set to `True` for more mem efficiency
        if low_cpu_mem_usage is None:
            low_cpu_mem_usage = True
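
For context on the branch above: built-in configs store `quant_method` as a `QuantizationMethod` enum member (which has a `.value`), while configs registered through the new decorator, like `CustomConfig` above, can store a plain string. A minimal equivalent of the same dispatch using `getattr` (an alternative idiom, not what the patch uses):

# Fall back to the raw attribute when quant_method is a plain string.
quant_method = hf_quantizer.quantization_config.quant_method
user_agent["quant"] = getattr(quant_method, "value", quant_method)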
2 changes: 1 addition & 1 deletion src/transformers/quantizers/__init__.py
@@ -11,5 +11,5 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from .auto import AutoHfQuantizer, AutoQuantizationConfig
+from .auto import AutoHfQuantizer, AutoQuantizationConfig, register_quantization_config, register_quantizer
from .base import HfQuantizer
33 changes: 33 additions & 0 deletions src/transformers/quantizers/auto.py
@@ -35,6 +35,7 @@
    TorchAoConfig,
    VptqConfig,
)
from .base import HfQuantizer
from .quantizer_aqlm import AqlmHfQuantizer
from .quantizer_awq import AwqQuantizer
from .quantizer_bitnet import BitNetHfQuantizer
@@ -226,3 +227,35 @@ def supports_quant_method(quantization_config_dict):
        )
        return False
    return True


def register_quantization_config(method: str):
    """Register a custom quantization configuration."""

    def register_config_fn(cls):
        if method in AUTO_QUANTIZATION_CONFIG_MAPPING:
            raise ValueError(f"Config '{method}' already registered")

        if not issubclass(cls, QuantizationConfigMixin):
            raise ValueError("Config must extend QuantizationConfigMixin")

        AUTO_QUANTIZATION_CONFIG_MAPPING[method] = cls
        return cls

    return register_config_fn


def register_quantizer(name: str):
    """Register a custom quantizer."""

    def register_quantizer_fn(cls):
        if name in AUTO_QUANTIZER_MAPPING:
            raise ValueError(f"Quantizer '{name}' already registered")

        if not issubclass(cls, HfQuantizer):
            raise ValueError("Quantizer must extend HfQuantizer")

        AUTO_QUANTIZER_MAPPING[name] = cls
        return cls

    return register_quantizer_fn
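
Both helpers are plain decorator factories, so they can also be called without decorator syntax. A minimal sketch (hypothetical `MyConfig`/`MyQuantizer` classes, not part of the diff) showing direct registration:

from transformers.quantizers import HfQuantizer, register_quantization_config, register_quantizer
from transformers.utils.quantization_config import QuantizationConfigMixin


class MyConfig(QuantizationConfigMixin):  # hypothetical example config
    def __init__(self):
        self.quant_method = "my_method"


class MyQuantizer(HfQuantizer):  # hypothetical example quantizer
    def _process_model_before_weight_loading(self, model, **kwargs):
        return model

    def _process_model_after_weight_loading(self, model, **kwargs):
        return model


# Each factory returns the registering closure, which validates the subclass
# and stores it in the mapping before handing the class back unchanged.
register_quantization_config("my_method")(MyConfig)
register_quantizer("my_method")(MyQuantizer)

# Registering the same name a second time raises, e.g.:
#   ValueError: Config 'my_method' already registered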