@@ -279,7 +279,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 class LlamaAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
-    def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
+    def __init__(self, config: LlamaConfig, layer_idx: int):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
@@ -430,7 +430,7 @@ def forward(
         if not output_attentions:
             attn_weights = None
 
-        return attn_output, attn_weights, (past_key_value if use_cache else None)
+        return attn_output, attn_weights, past_key_value
 
 
 class LlamaFlashAttention2(LlamaAttention):
@@ -524,7 +524,7 @@ def forward(
         if not output_attentions:
             attn_weights = None
 
-        return attn_output, attn_weights, (past_key_value if use_cache else None)
+        return attn_output, attn_weights, past_key_value
 
     def _flash_attention_forward(
         self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
@@ -619,7 +619,7 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query
 
 
 class LlamaDecoderLayer(nn.Module):
-    def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
+    def __init__(self, config: LlamaConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = (
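Note on the `layer_idx` change above: with the cache refactor, every attention layer writes its key/value states into a shared cache object indexed by layer, so the index can no longer be optional at construction time. Below is a minimal sketch of that interaction, assuming `DynamicCache` from `transformers.cache_utils` (transformers >= 4.36); the tensor shapes and variable names are illustrative only, not part of this diff:

    import torch
    from transformers.cache_utils import DynamicCache

    cache = DynamicCache()
    num_layers, bsz, num_heads, seq_len, head_dim = 2, 1, 8, 4, 64

    for layer_idx in range(num_layers):
        key_states = torch.randn(bsz, num_heads, seq_len, head_dim)
        value_states = torch.randn(bsz, num_heads, seq_len, head_dim)
        # Each layer appends its states under its own index, which is why
        # LlamaAttention / LlamaDecoderLayer now take layer_idx as a required argument.
        key_states, value_states = cache.update(key_states, value_states, layer_idx)

    print(cache.get_seq_length())  # 4 tokens cached (for layer 0)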
@@ -943,7 +943,7 @@ def forward(
 
         next_cache = None
         if use_cache:
-            next_cache = self.to_legacy_cache(next_decoder_cache) if use_legacy_cache else next_decoder_cache
+            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
         if not return_dict:
             return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
         return BaseModelOutputWithPast(
@@ -953,12 +953,6 @@ def forward(
             attentions=all_self_attns,
         )
 
-    def from_legacy_cache(self, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]]) -> Cache:
-        return DynamicCache.from_legacy_cache(past_key_values)
-
-    def to_legacy_cache(self, past_key_values: Cache) -> Tuple[Tuple[torch.Tensor]]:
-        return past_key_values.to_legacy_cache()
-
 
 class LlamaForCausalLM(LlamaPreTrainedModel):
     _tied_weights_keys = ["lm_head.weight"]
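The removed `from_legacy_cache` / `to_legacy_cache` wrappers on the model are replaced by calling the cache object's own conversion methods directly, as the `next_decoder_cache.to_legacy_cache()` hunk shows. A rough, self-contained sketch of the round-trip this relies on, assuming transformers >= 4.36 where `cache_utils.DynamicCache` exposes both the classmethod and the instance method (shapes illustrative):

    import torch
    from transformers.cache_utils import Cache, DynamicCache

    # Legacy format: a tuple with one (key, value) pair of tensors per layer.
    legacy = tuple(
        (torch.zeros(1, 8, 4, 64), torch.zeros(1, 8, 4, 64))
        for _ in range(2)
    )

    use_legacy_cache = not isinstance(legacy, Cache)  # True for the tuple format
    cache = DynamicCache.from_legacy_cache(legacy)    # wrap once on the way in

    # ... decoder layers would grow `cache` via cache.update(k, v, layer_idx) ...

    # On the way out, convert back only if the caller passed the legacy format,
    # mirroring `next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache`.
    next_cache = cache.to_legacy_cache() if use_legacy_cache else cache
    print(type(next_cache), len(next_cache))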