
Commit cddb420

Add tests; Simplify code; Apply changes to Mistral and Persimmon
1 parent b9e95d4 commit cddb420

4 files changed (+344, -138 lines changed)

src/transformers/models/llama/modeling_llama.py

Lines changed: 5 additions & 11 deletions
@@ -279,7 +279,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 class LlamaAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
-    def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
+    def __init__(self, config: LlamaConfig, layer_idx: int):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
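
Note on the hunk above: layer_idx becomes a required int because each attention layer writes its key/value states into its own slot of the shared cache object. Below is a minimal sketch of that pattern, assuming the transformers.cache_utils.DynamicCache API; the shapes and values are illustrative, not code from this commit.

# Minimal sketch (illustrative, not from this commit): how a per-layer index
# is used with DynamicCache, which this refactor builds on.
import torch
from transformers.cache_utils import DynamicCache

cache = DynamicCache()
layer_idx = 0  # the index the layer received at construction time

# Illustrative shapes: (batch, num_heads, seq_len, head_dim)
key_states = torch.randn(1, 32, 4, 128)
value_states = torch.randn(1, 32, 4, 128)

# update() appends the new states under this layer's index and returns the
# full key/value tensors accumulated so far for that layer.
key_states, value_states = cache.update(key_states, value_states, layer_idx)
print(cache.get_seq_length(layer_idx))  # 4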
@@ -430,7 +430,7 @@ def forward(
         if not output_attentions:
             attn_weights = None
 
-        return attn_output, attn_weights, (past_key_value if use_cache else None)
+        return attn_output, attn_weights, past_key_value
 
 
 class LlamaFlashAttention2(LlamaAttention):
@@ -524,7 +524,7 @@ def forward(
         if not output_attentions:
             attn_weights = None
 
-        return attn_output, attn_weights, (past_key_value if use_cache else None)
+        return attn_output, attn_weights, past_key_value
 
     def _flash_attention_forward(
         self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
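
Note on the two hunks above: both attention implementations now return past_key_value unconditionally instead of gating it on use_cache, so the caller decides what to do with the cache. A hypothetical caller sketch follows (parameter names assumed, not taken from this commit).

# Hypothetical caller sketch: the cache object flows through unchanged; when
# caching is disabled, past_key_value is simply None end to end.
hidden_states, self_attn_weights, present_key_value = self.self_attn(
    hidden_states=hidden_states,
    attention_mask=attention_mask,
    position_ids=position_ids,
    past_key_value=past_key_value,  # a Cache instance, or None
    output_attentions=output_attentions,
    use_cache=use_cache,
)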
@@ -619,7 +619,7 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query
 
 
 class LlamaDecoderLayer(nn.Module):
-    def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
+    def __init__(self, config: LlamaConfig, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size
         self.self_attn = (
@@ -943,7 +943,7 @@ def forward(
 
         next_cache = None
         if use_cache:
-            next_cache = self.to_legacy_cache(next_decoder_cache) if use_legacy_cache else next_decoder_cache
+            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
         if not return_dict:
             return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
         return BaseModelOutputWithPast(
@@ -953,12 +953,6 @@ def forward(
             attentions=all_self_attns,
         )
 
-    def from_legacy_cache(self, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]]) -> Cache:
-        return DynamicCache.from_legacy_cache(past_key_values)
-
-    def to_legacy_cache(self, past_key_values: Cache) -> Tuple[Tuple[torch.Tensor]]:
-        return past_key_values.to_legacy_cache()
-
 
 class LlamaForCausalLM(LlamaPreTrainedModel):
     _tied_weights_keys = ["lm_head.weight"]
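
Note on the last two hunks: the removed from_legacy_cache / to_legacy_cache methods were thin pass-throughs, and the model now calls the conversion helpers on the cache classes directly. A minimal round-trip sketch, assuming the transformers.cache_utils API introduced with the Cache refactor (shapes illustrative, not code from this commit):

# Round-trip between the legacy tuple-of-tuples format and DynamicCache,
# with no helper methods on the model itself.
import torch
from transformers.cache_utils import DynamicCache

# Legacy format: one (key, value) pair per layer,
# each of shape (batch, num_heads, seq_len, head_dim).
legacy_past = tuple(
    (torch.randn(1, 32, 4, 128), torch.randn(1, 32, 4, 128)) for _ in range(2)
)

cache = DynamicCache.from_legacy_cache(legacy_past)  # classmethod on the cache
assert cache.get_seq_length() == 4

round_tripped = cache.to_legacy_cache()  # back to tuple-of-tuples
assert torch.equal(round_tripped[0][0], legacy_past[0][0])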
