docs/source/en/quantization/torchao.md (3 additions, 0 deletions)

````diff
@@ -73,6 +73,9 @@ output = bf16_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
 print("bf16 model:", benchmark_fn(bf16_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS, cache_implementation="static"))
 ```
 
+> [!TIP]
+> For best performance, you can use recommended settings by calling `torchao.quantization.utils.recommended_inductor_config_setter()`
+
 </hfoption>
 <hfoption id="automatic">
````
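As a hedged illustration of the new tip, here is a minimal sketch of a quantized generation run that applies the recommended inductor settings up front. The model id and prompt are placeholders, and the `TorchAoConfig` usage follows the surrounding torchao docs; with this PR, transformers no longer applies these settings implicitly.

```python
import torch
from torchao.quantization.utils import recommended_inductor_config_setter
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig

# Apply torchao's recommended torch._inductor settings once, up front;
# the quantizer no longer sets them as a side effect (set_inductor_config=False).
recommended_inductor_config_setter()

model_id = "meta-llama/Llama-3.1-8B"  # placeholder; any torchao-supported causal LM
quant_config = TorchAoConfig("int4_weight_only", group_size=128)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=quant_config,
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

input_ids = tokenizer("What are we having for dinner?", return_tensors="pt").to(model.device)
output = model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
print(tokenizer.decode(output[0], skip_special_tokens=True))
```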
src/transformers/quantizers/quantizer_torchao.py (2 additions, 2 deletions)

````diff
@@ -203,7 +203,7 @@ def create_quantized_param(
             module.extra_repr = types.MethodType(_linear_extra_repr, module)
         else:
             module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
-            quantize_(module, self.quantization_config.get_apply_tensor_subclass())
+            quantize_(module, self.quantization_config.get_apply_tensor_subclass(), set_inductor_config=False)
 
     def _process_model_after_weight_loading(self, model, **kwargs):
         """No process required for torchao quantized model"""
````
````diff
@@ -213,7 +213,7 @@ def _process_model_after_weight_loading(self, model, **kwargs):
 
             model = torch.compile(model, mode="max-autotune")
             model = autoquant(
-                model, qtensor_class_list=ALL_AUTOQUANT_CLASS_LIST, **self.quantization_config.quant_type_kwargs
+                model, qtensor_class_list=ALL_AUTOQUANT_CLASS_LIST, set_inductor_config=False, **self.quantization_config.quant_type_kwargs
             )
             return model
         return
````
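The `autoquant` path gets the same treatment, so `torch.compile(mode="max-autotune")` runs against whatever inductor configuration the caller chose. A sketch of the equivalent standalone torchao usage mirroring the updated call; the toy model is a placeholder and a CUDA device is assumed:

```python
import torch
from torchao import autoquant

# Toy stand-in; inside transformers this is the loaded checkpoint.
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(torch.bfloat16).cuda()

# Mirror of the updated call: compile first, then let autoquant pick kernels,
# with set_inductor_config=False so global inductor settings stay untouched.
model = torch.compile(model, mode="max-autotune")
model = autoquant(model, set_inductor_config=False)

# Autoquant selects kernels per input shape; run a representative batch to
# trigger the benchmarking and quantization.
x = torch.randn(8, 1024, dtype=torch.bfloat16, device="cuda")
model(x)
```

Callers that still want torchao's recommended settings can opt in explicitly with `torchao.quantization.utils.recommended_inductor_config_setter()`, as the new doc tip suggests.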