
Forward pass consumes 2x more memory than required in the attention module when use_cache=True #25930

@RahulSChand

Description


Feature request

Currently, the forward pass of the model consumes 2x the memory actually required in the attention module when use_cache=True.

For example, during generation in the LLaMA attention module, key_states has shape (bsz, heads, 1, head_size) and past_key_value[0] has shape (bsz, heads, seq_length, head_size).

When we do torch.cat(...) in line 337 below, two copies are alive at once: the old past_key_value[0] of shape (bsz, heads, seq_length, head_size) and the newly created key_states of shape (bsz, heads, seq_length+1, head_size).

```python
if past_key_value is not None:
    # reuse k, v, self_attention
    key_states = torch.cat([past_key_value[0], key_states], dim=2)
    value_states = torch.cat([past_key_value[1], value_states], dim=2)
```
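
To make the spike concrete, here is a minimal, hypothetical repro (not from modeling_llama.py; names and shapes are illustrative) that measures the peak CUDA memory around a cat mimicking the cache update:

```python
# Hypothetical repro sketch: measure the peak-memory spike caused by
# concatenating a 1-token key onto an existing cache (requires a CUDA device).
import torch

bsz, heads, seq_length, head_size = 1, 12, 1024, 128

past_key = torch.randn(bsz, heads, seq_length, head_size, device="cuda")
new_key = torch.randn(bsz, heads, 1, head_size, device="cuda")

torch.cuda.reset_peak_memory_stats()
before = torch.cuda.memory_allocated()

# While past_key is still referenced, both the old cache and the new,
# one-token-longer copy are alive at the same time.
key_states = torch.cat([past_key, new_key], dim=2)

extra = torch.cuda.max_memory_allocated() - before
print(f"extra memory beyond the old cache: {extra / 2**20:.1f} MiB")
# extra is roughly the full size of past_key again, i.e. the 2x described above.
```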

Ideally, the tensor that past_key_value[0] points to should be freed as soon as the cat finishes, since it is no longer used (it is replaced by the newly created key_states). In the current implementation, with (bsz=1, heads=12, seq_length=1024, head_size=128), the memory consumed is bsz*heads*seq_length*head_size*layers*2, whereas it should ideally use only half of that. This can be achieved by simply freeing past_key_value[0] after the cat operation finishes.
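
As a sketch of why freeing is enough (assuming nothing else, such as the caller's past_key_values tuple, still references the old tensor), dropping the last Python reference lets PyTorch's caching allocator reclaim the block immediately:

```python
# Sketch: dropping the last reference to the old cache releases its memory
# back to PyTorch's caching allocator. Assumes no other live reference exists.
import torch

past = torch.randn(1, 12, 1024, 128, device="cuda")
step = torch.randn(1, 12, 1, 128, device="cuda")

merged = torch.cat([past, step], dim=2)
print(torch.cuda.memory_allocated())  # old cache + merged are both counted

del past  # free the (bsz, heads, seq_length, head_size) block
print(torch.cuda.memory_allocated())  # drops by the size of the old cache
```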

This is particularly noticeable when you increase bsz. Freeing the memory allows using a 2x larger batch size or sequence length. There are some edge cases where past_key_value[0] might be used again, so there could be a flag to switch this behavior on/off.

Motivation

Reduce memory consumed by 2x when use_cache=True. This allows us to increase the batch size by 2x, or the max sequence length of generated tokens by 2x, with the same memory.

Your contribution

I can contribute if this is a change that is wanted. The change is small on the surface, since all it requires is freeing the tensors that past_key_value[0] & past_key_value[1] point to after line 338 (a sketch of what this could look like follows the snippet below):

```python
if past_key_value is not None:
    # reuse k, v, self_attention
    key_states = torch.cat([past_key_value[0], key_states], dim=2)
    value_states = torch.cat([past_key_value[1], value_states], dim=2)
```
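
A minimal sketch of the proposed change, assuming the attention module holds the only live references to the old cache tensors (the free_past_cache flag is hypothetical, not an existing argument):

```python
if past_key_value is not None:
    # reuse k, v, self_attention
    key_states = torch.cat([past_key_value[0], key_states], dim=2)
    value_states = torch.cat([past_key_value[1], value_states], dim=2)
    if free_past_cache:  # hypothetical opt-in flag
        # Drop this function's references to the old cache tensors so the
        # allocator can reclaim them. Only effective if no caller still
        # holds them, which is the edge case mentioned above.
        past_key_value = None
```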
