27 changes: 24 additions & 3 deletions src/transformers/quantizers/quantizer_hqq.py
@@ -124,7 +124,14 @@ def _find_hqq_quantizable_layers(model, layers):
             # valid modules are Linear layers that have HQQLinear state_dict. We ignore skip_modules and any layers with Linear state_dict() params
             _valid_modules = set()
             _find_hqq_quantizable_layers(model, _valid_modules)
-            _valid_modules -= set(model.config.quantization_config["skip_modules"])
+
+            # Remove skipped modules
+            _skipped_modules = set()
+            for _module in _valid_modules:
+                for _skip_module in model.config.quantization_config["skip_modules"]:
+                    if _skip_module in _module:
+                        _skipped_modules.add(_module)
+            _valid_modules -= _skipped_modules
 
             # Append new expected layers based on _ref_keys
             _ref_keys = HQQLinear(
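
The filter above treats each entry in skip_modules as a substring of the fully qualified module name rather than requiring an exact match, so an entry such as "down_proj" now excludes "model.layers.0.mlp.down_proj" as well. A minimal standalone sketch of that matching, using made-up module names purely for illustration:

# Illustrative sketch only (hypothetical module names): substring-based skip filtering
_valid_modules = {
    "model.layers.0.self_attn.q_proj",
    "model.layers.0.mlp.down_proj",
    "lm_head",
}
skip_modules = ["lm_head", "down_proj"]

_skipped_modules = {m for m in _valid_modules if any(s in m for s in skip_modules)}
_valid_modules -= _skipped_modules
print(sorted(_valid_modules))  # ['model.layers.0.self_attn.q_proj']
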
@@ -243,10 +250,24 @@ def create_quantized_param(
 
         # Step 2: Replace module with either HQQLinear or move it to device. We do this via setattr on the parent as doing on it on the module
         # directly doesn't work.
-        if hasattr(module, "quant_config"):
+        quant_config = model.config.quantization_config["quant_config"]
+        skip_modules = model.config.quantization_config["skip_modules"]
+        module_tag = ".".join(module.name.split(".")[-2:])
+        module_quant_config = None
+        if "weight_quant_params" in quant_config:
+            module_quant_config = quant_config
+        elif module_tag in quant_config:
+            module_quant_config = quant_config[module_tag]
+
+        for skip_module in skip_modules:
+            if skip_module in module.name:
+                module_quant_config = None
+                break
+
+        if module_quant_config is not None:
             hqq_layer = HQQLinear(
                 module,
-                module.quant_config,
+                quant_config=module_quant_config,
                 compute_dtype=self.torch_dtype,
                 device=target_device,
                 del_orig=True,
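
The change to create_quantized_param resolves a per-module config in three steps: a single global config (recognized by the presence of "weight_quant_params") applies to every layer, otherwise a dynamic config is looked up by the last two components of the module path (e.g. "mlp.gate_proj"), and any skip_modules match overrides both so the layer is left unquantized. A stripped-down sketch of that resolution; the helper name and the inputs below are hypothetical, not part of the PR:

# Illustrative sketch only: per-module quant-config resolution mirroring the logic above
def resolve_module_quant_config(module_name, quant_config, skip_modules):
    module_tag = ".".join(module_name.split(".")[-2:])
    module_quant_config = None
    if "weight_quant_params" in quant_config:  # single global config
        module_quant_config = quant_config
    elif module_tag in quant_config:  # dynamic, per-module-tag config
        module_quant_config = quant_config[module_tag]
    if any(skip in module_name for skip in skip_modules):
        module_quant_config = None  # skipped layers are left unquantized
    return module_quant_config


# gate_proj of layer 3 resolves to the 3-bit config; down_proj is skipped
dynamic = {"mlp.gate_proj": {"nbits": 3, "group_size": 64}}
print(resolve_module_quant_config("model.layers.3.mlp.gate_proj", dynamic, ["down_proj"]))
print(resolve_module_quant_config("model.layers.3.mlp.down_proj", dynamic, ["down_proj"]))
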
33 changes: 33 additions & 0 deletions tests/quantization/hqq/test_hqq.py
@@ -207,3 +207,36 @@ def test_model_serialization(self):
         logits_loaded = model_loaded.forward(input_tensor).logits
 
         self.assertEqual((logits_loaded - logits_ref).abs().mean().item(), 0)
+
+    def test_model_serialization_dynamic_quant_with_skip(self):
+        """
+        Simple HQQ LLM save/load test with dynamic quant
+        """
+        q4_config = {"nbits": 4, "group_size": 64}
+        q3_config = {"nbits": 3, "group_size": 64}
+
+        quant_config = HqqConfig(
+            dynamic_config={
+                "self_attn.q_proj": q4_config,
+                "self_attn.k_proj": q4_config,
+                "self_attn.v_proj": q4_config,
+                "self_attn.o_proj": q4_config,
+                "mlp.gate_proj": q3_config,
+                "mlp.up_proj": q3_config,
+            },
+            skip_modules=["lm_head", "down_proj"],
+        )
+
+        hqq_runner = HQQLLMRunner(
+            model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.float16, device=torch_device
+        )
+
+        model = hqq_runner.model
+
+        input_tensor = torch.zeros((1, 8), dtype=torch.int32, device=torch_device)
+        with torch.no_grad():
+            model.forward(input_tensor).logits
+
+        self.assertEqual(isinstance(model.model.layers[1].mlp.down_proj, torch.nn.Linear), True)
+        self.assertEqual(model.model.layers[1].self_attn.v_proj.quant_config["weight_quant_params"]["nbits"], 4)
+        self.assertEqual(model.model.layers[1].mlp.gate_proj.quant_config["weight_quant_params"]["nbits"], 3)
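
For reference, the same dynamic-quant-with-skip setup outside the test harness would look roughly like the sketch below; the model id is a placeholder, standard transformers loading is assumed, and the hqq package must be installed:

import torch
from transformers import AutoModelForCausalLM, HqqConfig

q4_config = {"nbits": 4, "group_size": 64}
q3_config = {"nbits": 3, "group_size": 64}

quant_config = HqqConfig(
    dynamic_config={
        "self_attn.q_proj": q4_config,
        "self_attn.k_proj": q4_config,
        "self_attn.v_proj": q4_config,
        "self_attn.o_proj": q4_config,
        "mlp.gate_proj": q3_config,
        "mlp.up_proj": q3_config,
    },
    skip_modules=["lm_head", "down_proj"],  # these stay as regular nn.Linear
)

# Placeholder model id; any HQQ-compatible causal LM is loaded the same way
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",
    torch_dtype=torch.float16,
    device_map="cuda",
    quantization_config=quant_config,
)
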