Cache class working with generate #1
Changes from all commits: d53eea6, 78bce79, 3d6316c, 3c9e09a, b1e65ac, 4783ca0, b9e95d4, cddb420, a269f15, 7ee08ad, 3629821, 468daad, 622b83b, 9e97516, 7182159
```diff
@@ -279,11 +279,11 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 class LlamaAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
-    def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
+    def __init__(self, config: LlamaConfig, layer_idx: int):
         super().__init__()
         self.config = config
-        self.attention_dropout = config.attention_dropout
         self.layer_idx = layer_idx
+        self.attention_dropout = config.attention_dropout
         self.hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
         self.head_dim = self.hidden_size // self.num_heads
```
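With the cache held in a single object, `layer_idx` becomes a required constructor argument: each attention layer uses it to read and write its own slot of the shared cache. A minimal sketch of that pattern, assuming the `DynamicCache` API this diff builds toward (layer count and tensor shapes below are illustrative):

```python
# Sketch only: per-layer cache updates keyed by layer_idx, assuming the
# DynamicCache API referenced in this diff.
import torch
from transformers import DynamicCache

cache = DynamicCache()
num_layers, bsz, num_heads, seq_len, head_dim = 2, 1, 4, 3, 8  # illustrative sizes

for layer_idx in range(num_layers):
    new_keys = torch.randn(bsz, num_heads, seq_len, head_dim)
    new_values = torch.randn(bsz, num_heads, seq_len, head_dim)
    # update() appends the new states to this layer's slot and returns the full
    # keys/values for this layer to attend over.
    keys, values = cache.update(new_keys, new_values, layer_idx)

print(cache.get_seq_length())  # 3: tokens cached so far
```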
```diff
@@ -430,7 +430,7 @@ def forward(
         if not output_attentions:
             attn_weights = None
 
-        return attn_output, attn_weights, (past_key_value if use_cache else None)
+        return attn_output, attn_weights, past_key_value
 
 
 class LlamaFlashAttention2(LlamaAttention):
```
```diff
@@ -524,7 +524,7 @@ def forward(
         if not output_attentions:
             attn_weights = None
 
-        return attn_output, attn_weights, (past_key_value if use_cache else None)
+        return attn_output, attn_weights, past_key_value
 
     def _flash_attention_forward(
         self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
```
```diff
@@ -619,7 +619,7 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query
 
 
 class LlamaDecoderLayer(nn.Module):
-    def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
+    def __init__(self, config: LlamaConfig, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size
         self.self_attn = (
```
```diff
@@ -795,6 +795,9 @@ def _init_weights(self, module):
             more detail.
         return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+        use_legacy_cache (`bool`, *optional*):
+            If set to `True` (default), will return `past_key_values` in the legacy format described in the input
+            above. Otherwise, will return a subclass of `Cache`.
     """
```
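A hedged usage sketch of the flag documented above. `use_legacy_cache` is the forward argument proposed in this PR (not necessarily the final upstream API), and the checkpoint name is a placeholder:

```python
# Sketch of the two cache formats selected by use_legacy_cache, as proposed in this diff.
# The flag and the checkpoint below are assumptions for illustration.
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "meta-llama/Llama-2-7b-hf"  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained(checkpoint)
tok = AutoTokenizer.from_pretrained(checkpoint)
inputs = tok("Hello", return_tensors="pt")

# Legacy format (default): tuple of per-layer (key, value) tensor pairs.
out = model(**inputs, use_cache=True, use_legacy_cache=True)
print(type(out.past_key_values))        # tuple
print(out.past_key_values[0][0].shape)  # (batch, num_heads, past_len, head_dim)

# Cache object: a DynamicCache (subclass of Cache) that can be passed back in directly.
out = model(**inputs, use_cache=True, use_legacy_cache=False)
print(type(out.past_key_values))        # DynamicCache
```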
```diff
@@ -866,7 +869,7 @@ def forward(
         past_key_values_length = 0
         if use_cache:
             if not isinstance(past_key_values, Cache):
-                past_key_values = self.from_legacy_cache(past_key_values)
+                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
             past_key_values_length = past_key_values.get_seq_length()
 
         if position_ids is None:
```
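The conversion now calls the `DynamicCache.from_legacy_cache` classmethod directly instead of the model-level helper removed further down. A small round-trip sketch, assuming the `from_legacy_cache` / `to_legacy_cache` / `get_seq_length` methods used in this diff, with illustrative tensor shapes:

```python
# Round-trip sketch: legacy tuple-of-tuples -> DynamicCache -> legacy tuples.
import torch
from transformers import DynamicCache

# Legacy format: one (key, value) pair per layer, each (batch, num_heads, seq_len, head_dim).
legacy = tuple((torch.randn(1, 4, 5, 8), torch.randn(1, 4, 5, 8)) for _ in range(2))

cache = DynamicCache.from_legacy_cache(legacy)
print(cache.get_seq_length())  # 5, used above as past_key_values_length

back = cache.to_legacy_cache()
print(torch.equal(back[0][0], legacy[0][0]))  # True: the conversion is lossless
```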
```diff
@@ -943,7 +946,7 @@ def forward(
 
         next_cache = None
         if use_cache:
-            next_cache = self.to_legacy_cache(next_decoder_cache) if use_legacy_cache else next_decoder_cache
+            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
         if not return_dict:
             return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
         return BaseModelOutputWithPast(
```
```diff
@@ -953,12 +956,6 @@ def forward(
             attentions=all_self_attns,
         )
 
-    def from_legacy_cache(self, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]]) -> Cache:
-        return DynamicCache.from_legacy_cache(past_key_values)
-
-    def to_legacy_cache(self, past_key_values: Cache) -> Tuple[Tuple[torch.Tensor]]:
-        return past_key_values.to_legacy_cache()
-
 
 class LlamaForCausalLM(LlamaPreTrainedModel):
     _tied_weights_keys = ["lm_head.weight"]
```