Feature request
Currently, when use_cache=True, the forward pass of the model consumes 2x the memory actually required in the attention module.
For example, during generation in the Llama attention module, key_states has shape (bsz, heads, 1, head_size)
and past_key_value[0] has shape (bsz, heads, seq_length, head_size).
When we do the torch.cat(...) in line 337 below, two tensors are alive at the same time: the old past_key_value[0] of shape (bsz, heads, seq_length, head_size) and the new key_states of shape (bsz, heads, seq_length+1, head_size).
transformers/src/transformers/models/llama/modeling_llama.py
Lines 335 to 338 in 0afa507
```python
if past_key_value is not None:
    # reuse k, v, self_attention
    key_states = torch.cat([past_key_value[0], key_states], dim=2)
    value_states = torch.cat([past_key_value[1], value_states], dim=2)
```
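To see the effect in isolation, here is a minimal standalone sketch (not taken from the repo); the shapes and the use of CUDA peak-memory stats are just assumptions for the demonstration:

```python
# Sketch: peak memory roughly doubles while both the old cache tensor and the
# concatenated result are alive at the same time. Shapes are illustrative only.
import torch

bsz, heads, seq_length, head_size = 1, 12, 1024, 128
past_key = torch.randn(bsz, heads, seq_length, head_size, device="cuda")
new_key = torch.randn(bsz, heads, 1, head_size, device="cuda")

torch.cuda.reset_peak_memory_stats()
key_states = torch.cat([past_key, new_key], dim=2)  # old + new both alive here
print(f"peak: {torch.cuda.max_memory_allocated() / 2**20:.1f} MiB")

# Dropping the old reference right after the cat lets the allocator reclaim
# the old buffer (assuming nothing else still points to it).
del past_key
```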
Ideally, the tensor that past_key_value[0] points to should be freed right after the cat, since it is no longer needed (it is superseded by the newly created key_states). In the current implementation, with (bsz=1, heads=12, seq_length=1024, head_size=128), the memory consumed is bsz*heads*seq_length*head_size*layers*2, whereas ideally it should only use half of that. This can be achieved simply by freeing past_key_value[0] (and past_key_value[1]) once the cat operation finishes.
This is particularly noticeable when you increase bsz. Freeing the memory allows using a 2x larger batch size or sequence length. There are some edge cases where past_key_value[0] might be used again, so this behavior could be gated behind a flag; a sketch of what that could look like follows.
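A rough sketch of the flag-gated variant inside the attention forward; free_past_key_value is an invented name for illustration, not an existing argument, and the old buffers are only actually released if no other reference (e.g. in the caller) keeps them alive:

```python
if past_key_value is not None:
    old_key, old_value = past_key_value
    key_states = torch.cat([old_key, key_states], dim=2)
    value_states = torch.cat([old_value, value_states], dim=2)
    if free_past_key_value:  # hypothetical flag, off by default
        # Drop every local reference to the old cache tensors so their memory
        # can be reclaimed immediately, instead of lingering until the
        # variable is rebound further down. Only safe when nothing else
        # still needs the old tensors (the edge cases mentioned above).
        past_key_value = None
        del old_key, old_value
```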
Motivation
Reduce the memory consumed by 2x when use_cache=True. This allows us to increase the batch size by 2x, or the maximum sequence length of generated tokens by 2x, with the same memory.
Your contribution
I can contribute if this change is wanted. The change is small on the surface, since all it requires is freeing the tensors that past_key_value[0] and past_key_value[1] point to after line 338:
transformers/src/transformers/models/llama/modeling_llama.py
Lines 335 to 338 in 0afa507
```python
if past_key_value is not None:
    # reuse k, v, self_attention
    key_states = torch.cat([past_key_value[0], key_states], dim=2)
    value_states = torch.cat([past_key_value[1], value_states], dim=2)
```
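Concretely, the minimal (non-flagged) version of the change could look something like the sketch below; whether the tensors are actually released still depends on the caller (e.g. the generation loop) also dropping its own references to the old past_key_values:

```python
if past_key_value is not None:
    # reuse k, v, self_attention
    key_states = torch.cat([past_key_value[0], key_states], dim=2)
    value_states = torch.cat([past_key_value[1], value_states], dim=2)
    # Proposed addition (sketch): release the local reference to the old cache
    # tensors as soon as the cat is done.
    past_key_value = None
```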