Commit ee6f6dc

Authored by BenjaminBossan, younesbelkada, and gante

FIX Issues with transformers 4.36 (#1252)

Adjust for different type of past_key_values when using caching. Also: Fix some seeds for flaky tests.

Co-authored-by: Younes Belkada <[email protected]>
Co-authored-by: Joao Gante <[email protected]>
1 parent 21c304f commit ee6f6dc

File tree: 4 files changed, +30 / -6 lines

src/peft/peft_model.py
src/peft/tuners/adaption_prompt/utils.py
tests/test_adaption_prompt.py
tests/test_multitask_prompt_tuning.py

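For context on what "different type of past_key_values" means: transformers PR 26681 changed some architectures to return a DynamicCache object instead of the legacy tuple of per-layer (key, value) tensors. The sketch below is illustrative only and not part of the commit; it assumes transformers >= 4.36 is installed, and the tensor shapes are invented.

# Illustrative only: the two past_key_values formats this commit has to handle.
# Assumes transformers >= 4.36, where DynamicCache is available.
import torch
from transformers import DynamicCache

# transformers <= 4.35: a tuple with one (key, value) pair per layer,
# each tensor shaped (batch, heads, past_len, head_dim).
kv = (torch.zeros(2, 4, 3, 8), torch.zeros(2, 4, 3, 8))
legacy_past = (kv, kv)                    # two layers
print(isinstance(legacy_past, tuple))     # True
print(legacy_past[0][0].shape[-2])        # cached sequence length -> 3

# transformers >= 4.36 (llama, mistral, persimmon, phi): a DynamicCache object,
# which the affected PEFT code paths now have to recognize as well.
cache = DynamicCache.from_legacy_cache(legacy_past)
print(isinstance(cache, tuple))           # False
print(cache.get_seq_length())             # same information via the Cache API -> 3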

src/peft/peft_model.py
Lines changed: 20 additions & 3 deletions

@@ -23,7 +23,9 @@
 from copy import deepcopy
 from typing import Any, Dict, List, Optional, Union

+import packaging.version
 import torch
+import transformers
 from accelerate import dispatch_model, infer_auto_device_map
 from accelerate.hooks import AlignDevicesHook, add_hook_to_module, remove_hook_from_submodules
 from accelerate.utils import get_balanced_memory

@@ -1136,11 +1138,26 @@ def generate(self, **kwargs):
     def prepare_inputs_for_generation(self, *args, task_ids: torch.Tensor = None, **kwargs):
         peft_config = self.active_peft_config
         model_kwargs = self.base_model_prepare_inputs_for_generation(*args, **kwargs)
+
+        # https://github.com/huggingface/transformers/pull/26681/ introduced new cache format
+        # for some architectures which requires a special fix for prompt tuning etc.
+        # TODO: starting with transformers 4.37, all architectures should support caching.
+        uses_transformers_4_37 = packaging.version.parse(transformers.__version__) >= packaging.version.parse("4.37.0")
+        uses_transformers_4_36 = packaging.version.parse(transformers.__version__) >= packaging.version.parse("4.36.0")
+        transformers_new_cache_archs = ["llama", "mistral", "persimmon", "phi"]
+        uses_cache = uses_transformers_4_37 or (
+            uses_transformers_4_36 and self.base_model.config.model_type in transformers_new_cache_archs
+        )
+
         if peft_config.is_prompt_learning:
             if model_kwargs.get("attention_mask", None) is not None:
-                prefix_attention_mask = torch.ones(
-                    model_kwargs["input_ids"].shape[0], peft_config.num_virtual_tokens
-                ).to(model_kwargs["input_ids"].device)
+                if uses_cache and (model_kwargs["past_key_values"] is not None):
+                    # TODO figure out why this workaround is necessary, see #1252 for context
+                    size = model_kwargs["input_ids"].shape[0], model_kwargs["past_key_values"][0][0].shape[-2]
+                else:
+                    size = model_kwargs["input_ids"].shape[0], peft_config.num_virtual_tokens
+
+                prefix_attention_mask = torch.ones(size).to(model_kwargs["input_ids"].device)
                 model_kwargs["attention_mask"] = torch.cat(
                     (prefix_attention_mask, model_kwargs["attention_mask"]), dim=1
                 )
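For context on the peft_model.py change: when caching is in play and past key values are already present, the prefix attention mask has to cover the cached sequence length rather than only the virtual prompt tokens. Below is a hedged restatement of that branch in isolation; the values and the dummy cache are made up and this is not the actual PEFT code path.

# Sketch of the new mask-sizing branch, with placeholder values.
import torch

batch_size, num_virtual_tokens = 2, 4
uses_cache = True   # outcome of the transformers-version / model_type check above

# Stand-in for model_kwargs["past_key_values"] in the legacy tuple format:
# one layer with key/value tensors shaped (batch, heads, past_len=7, head_dim).
past_key_values = ((torch.zeros(batch_size, 4, 7, 8), torch.zeros(batch_size, 4, 7, 8)),)

if uses_cache and past_key_values is not None:
    # cover the cached length, i.e. past_key_values[0][0].shape[-2]
    size = (batch_size, past_key_values[0][0].shape[-2])
else:
    size = (batch_size, num_virtual_tokens)

prefix_attention_mask = torch.ones(size)
print(prefix_attention_mask.shape)  # torch.Size([2, 7]) with the dummy cache above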

src/peft/tuners/adaption_prompt/utils.py
Lines changed: 6 additions & 1 deletion

@@ -73,7 +73,12 @@ def llama_compute_query_states(model: nn.Module, **kwargs) -> torch.Tensor:

     seq_len = q_len
     if past_key_value is not None:
-        seq_len += past_key_value[0].shape[-2]
+        if isinstance(past_key_value, tuple):
+            # for transformers <= 4.35
+            seq_len += past_key_value[0].shape[-2]
+        else:
+            # since transformers 4.36, this is a DynamicCache instance
+            seq_len += past_key_value.get_seq_length(model.layer_idx)
     cos, sin = model.rotary_emb(value_states, seq_len=seq_len)

     return llama_apply_rotary_pos_emb(query_states, cos, sin, position_ids)
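The two branches above compute the same quantity for equivalent caches: the number of already-cached positions. A hedged helper mirroring that dispatch (not PEFT code; assumes transformers >= 4.36 so DynamicCache exists, with invented shapes):

# Per-layer past length for both cache formats.
import torch
from transformers import DynamicCache

def past_length(past_key_value, layer_idx: int = 0) -> int:
    if isinstance(past_key_value, tuple):
        # transformers <= 4.35: a single layer's (key, value) tensors
        return past_key_value[0].shape[-2]
    # transformers >= 4.36: a Cache/DynamicCache instance shared across layers
    return past_key_value.get_seq_length(layer_idx)

kv = (torch.zeros(1, 4, 5, 8), torch.zeros(1, 4, 5, 8))    # past_len = 5
print(past_length(kv))                                      # 5
print(past_length(DynamicCache.from_legacy_cache((kv,))))   # 5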

tests/test_adaption_prompt.py
Lines changed: 2 additions & 1 deletion

@@ -387,6 +387,7 @@ def test_add_and_set_while_disabled(self):

     def test_use_cache(self) -> None:
         """Test that AdaptionPrompt works when Llama config use_cache=True."""
+        torch.manual_seed(0)
         input_ids = torch.LongTensor([[1, 1, 1], [2, 1, 2]]).to(self.torch_device)
         original = LlamaForCausalLM(
             LlamaConfig(

@@ -397,7 +398,7 @@ def test_use_cache(self) -> None:
                 num_attention_heads=4,
                 use_cache=False,
             )
-        )
+        ).eval()
         adapted = get_peft_model(
             original, AdaptionPromptConfig(adapter_layers=2, adapter_len=4, task_type="CAUSAL_LM")
         )

tests/test_multitask_prompt_tuning.py
Lines changed: 2 additions & 1 deletion

@@ -220,10 +220,11 @@ def test_generate(self) -> None:

     def test_use_cache(self) -> None:
         """Test that MultiTaskPromptTuning works when Llama config use_cache=True."""
+        torch.manual_seed(0)
         input_ids = torch.LongTensor([[1, 1, 1], [2, 1, 2]]).to(self.torch_device)
         task_ids = torch.LongTensor([1, 2]).to(self.torch_device)

-        original = LlamaForCausalLM(self._create_test_llama_config())
+        original = LlamaForCausalLM(self._create_test_llama_config()).eval()
         mpt = get_peft_model(original, self._create_multitask_prompt_tuning_config())
         mpt = mpt.to(self.torch_device)
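Both test changes follow the usual de-flaking pattern: seed the RNG before the randomly initialized model is created, and put the base model in eval mode so dropout cannot make cached and uncached forward passes diverge. A generic hedged sketch of that pattern (the model and inputs are placeholders, not the test fixtures):

# Generic de-flaking pattern, not the actual test code.
import torch
import torch.nn as nn

torch.manual_seed(0)   # make the random weight initialization reproducible
model = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.5), nn.Linear(8, 2)).eval()  # eval() disables dropout

x = torch.randn(2, 8)
out1 = model(x)
out2 = model(x)
assert torch.allclose(out1, out2)  # deterministic in eval mode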
