### Feature request
Let's add support for a new quantization method to LoRA, namely optimum-quanto.
There is some more context in this diffusers issue.
### Motivation
First of all, the more quantization methods we support, the better. Notably, though, quanto also works with MPS, which sets it apart from the other quantization methods.
### Your contribution
I did some preliminary testing, and quanto already partly works with PEFT: its `QLinear` layer is a subclass of `nn.Linear`, so PEFT's existing dispatch applies `lora.Linear` to it.
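To see why this dispatch happens, here is a tiny check (a sketch assuming optimum-quanto is installed; `QLinear` lives in `optimum.quanto.nn`):

```python
import torch.nn as nn
from optimum.quanto.nn import QLinear

# Because QLinear subclasses nn.Linear, PEFT's existing dispatch for
# nn.Linear matches it and wraps it with lora.Linear.
print(issubclass(QLinear, nn.Linear))  # True
```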
Some features, like inference, appear to work already. However, others don't work correctly, like merging. Here is a very quick test:
```python
import torch
from peft import LoraConfig, get_peft_model
from optimum.quanto import quantize, freeze, qint8
from transformers import AutoModelForCausalLM

torch.manual_seed(0)
inputs = torch.arange(5).view(-1, 1)

print("loading model")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m").eval()
with torch.inference_mode():
    output_base = model(inputs).logits

# quantize and freeze the model
print("quantizing model")
quantize(model, weights=qint8)
print("freezing model")
freeze(model)
with torch.inference_mode():
    output_quantized = model(inputs).logits

# add a randomly initialized LoRA adapter
config = LoraConfig(r=8, lora_alpha=16, lora_dropout=0.1, init_lora_weights=False)
print("adding adapter (random)")
model = get_peft_model(model, config)
model.eval()
with torch.inference_mode():
    output_lora = model(inputs).logits
    with model.disable_adapter():
        output_disabled = model(inputs).logits
    output_after_disabled = model(inputs).logits

# merging, unmerging, and unloading
model.merge_adapter()
with torch.inference_mode():
    output_merged = model(inputs).logits
model.unmerge_adapter()
with torch.inference_mode():
    output_unmerged = model(inputs).logits
unloaded = model.merge_and_unload()
with torch.inference_mode():
    output_unloaded = unloaded(inputs).logits

print("output_base")
print(output_base[0, 0, :5])
print("output_quantized")
print(output_quantized[0, 0, :5])
print("output_lora")
print(output_lora[0, 0, :5])
print("output_disabled")
print(output_disabled[0, 0, :5])
print("output_after_disabled")
print(output_after_disabled[0, 0, :5])
print("output_merged")
print(output_merged[0, 0, :5])
print("output_unmerged")
print(output_unmerged[0, 0, :5])
print("output_unloaded")
print(output_unloaded[0, 0, :5])
```
Note that none of the outputs involving merging (`output_merged`, `output_unmerged`, `output_unloaded`) come out as expected.
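One likely reason is the merge step itself: after `freeze`, the weight of a quanto `QLinear` is a quantized tensor, so the LoRA delta presumably can't just be added in place the way it is for a plain `nn.Linear`. A minimal sketch of what a correct merge would have to do (the helper is hypothetical, not PEFT API, and assumes quanto's quantized weights expose `dequantize()`):

```python
def sketch_merge(qlinear, lora_A, lora_B, scaling):
    # Hypothetical helper, not PEFT API: dequantize the frozen quanto
    # weight, then apply the standard LoRA merge math in full precision.
    w = qlinear.weight.dequantize()
    w_merged = w + scaling * (lora_B @ lora_A)
    # The result would then still need to be re-quantized and frozen
    # back into the QLinear layer, which is the part that's missing.
    return w_merged
```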
I can certainly take this on when I have time, but contributions are highly welcome. For inspiration, check out past PRs that add new quantization methods.
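For reference, the existing integrations (e.g., for bitsandbytes) hook into LoRA's layer dispatch. A very rough sketch of what a quanto version could look like, with all names below illustrative rather than actual PEFT API:

```python
import torch.nn as nn
from optimum.quanto.nn import QLinear


class QuantoLoraLinear(nn.Module):
    """Hypothetical LoRA wrapper for quanto layers (placeholder only)."""

    def __init__(self, base_layer, adapter_name, **kwargs):
        super().__init__()
        self.base_layer = base_layer
        self.adapter_name = adapter_name


def dispatch_quanto(target, adapter_name, **kwargs):
    # Same shape as PEFT's existing dispatch helpers: return a wrapped
    # module for layers this backend handles, or None so that other
    # dispatchers (and finally the default lora.Linear) get a chance.
    new_module = None
    if isinstance(target, QLinear):
        new_module = QuantoLoraLinear(target, adapter_name, **kwargs)
    return new_module
```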