Commit efdd436

FIX [PEFT / Trainer] Handle PEFT + quantized compiled models better (#29055)

* handle peft + compiled models
* add tests
* fixup
* adapt from suggestions
* clarify comment

1 parent: 5e95dca

File tree: 2 files changed, +43 −0 lines changed

src/transformers/trainer.py

Lines changed: 6 additions & 0 deletions

@@ -429,6 +429,12 @@ def __init__(
             getattr(model, "hf_quantizer", None) is not None and model.hf_quantizer.is_trainable
         )
 
+        # Filter out quantized + compiled models
+        if _is_quantized_and_base_model and hasattr(model, "_orig_mod"):
+            raise ValueError(
+                "You cannot fine-tune a quantized model with `torch.compile()`; make sure to pass a non-compiled model when fine-tuning a quantized model with PEFT"
+            )
+
         # At this stage the model is already loaded
         if _is_quantized_and_base_model and not _is_peft_model(model):
             raise ValueError(
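For context, a minimal sketch (not part of the commit, assuming PyTorch 2.x where `torch.compile` wraps a module in an `OptimizedModule`) of why `hasattr(model, "_orig_mod")` is used to detect compiled models:

import torch
import torch.nn as nn

# torch.compile() returns a torch._dynamo OptimizedModule wrapper that keeps
# a reference to the original module on the `_orig_mod` attribute, so the
# Trainer check above can spot a compiled model without importing dynamo.
model = nn.Linear(4, 4)
compiled = torch.compile(model)

print(hasattr(model, "_orig_mod"))     # False: a plain module passes the check
print(hasattr(compiled, "_orig_mod"))  # True: quantized base models now fail fast

# If a model was compiled by mistake, the original module is still reachable,
# so the non-compiled module can be handed to the Trainer instead.
assert compiled._orig_mod is model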

tests/trainer/test_trainer.py

Lines changed: 37 additions & 0 deletions

@@ -62,6 +62,7 @@
     require_deepspeed,
     require_intel_extension_for_pytorch,
     require_optuna,
+    require_peft,
     require_ray,
     require_safetensors,
     require_sentencepiece,

@@ -873,6 +874,42 @@ def test_number_of_steps_in_training_with_ipex(self):
         train_output = trainer.train()
         self.assertEqual(train_output.global_step, 10)
 
+    @require_peft
+    @require_bitsandbytes
+    def test_bnb_compile(self):
+        from peft import LoraConfig, get_peft_model
+
+        # Simply tests whether initializing a Trainer with a PEFT + compiled model works out of the box.
+        # QLoRA + torch.compile is not really supported yet, but we should at least support loading the
+        # model and raise a clear error at Trainer init.
+        tiny_model = AutoModelForCausalLM.from_pretrained(
+            "hf-internal-testing/tiny-random-LlamaForCausalLM", load_in_4bit=True
+        )
+
+        peft_config = LoraConfig(
+            r=8,
+            lora_alpha=32,
+            target_modules=["q_proj", "k_proj", "v_proj"],
+            lora_dropout=0.05,
+            bias="none",
+            task_type="CAUSAL_LM",
+        )
+        tiny_model = get_peft_model(tiny_model, peft_config)
+
+        tiny_model = torch.compile(tiny_model)
+
+        x = torch.randint(0, 100, (128,))
+        train_dataset = RepeatDataset(x)
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            args = TrainingArguments(
+                tmp_dir,
+                learning_rate=1e-9,
+                logging_steps=5,
+            )
+            with self.assertRaises(ValueError):
+                _ = Trainer(tiny_model, args, train_dataset=train_dataset)  # noqa
+
     @require_bitsandbytes
     def test_rmsprop_bnb(self):
         config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4)
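As a usage note (not from the commit): after this change a quantized PEFT model has to reach the Trainer un-compiled. A minimal sketch, assuming `peft` and `bitsandbytes` are installed and a GPU is available; the model name and LoRA settings simply mirror the test above:

import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, Trainer, TrainingArguments

model = AutoModelForCausalLM.from_pretrained(
    "hf-internal-testing/tiny-random-LlamaForCausalLM", load_in_4bit=True
)
model = get_peft_model(
    model,
    LoraConfig(r=8, lora_alpha=32, target_modules=["q_proj", "k_proj", "v_proj"], task_type="CAUSAL_LM"),
)
args = TrainingArguments("tmp_output", learning_rate=1e-9, logging_steps=5)

# Supported: hand the non-compiled PEFT model to the Trainer.
trainer = Trainer(model, args)

# Not supported: compiling first now fails fast at Trainer init with a clear
# ValueError instead of an opaque failure later during training.
try:
    Trainer(torch.compile(model), args)
except ValueError as err:
    print(err)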
