Paged attention #425


Open · wants to merge 22 commits into base: main

Conversation

@liangan1 commented on Jun 24, 2024:

Related RFC

pytorch-bot commented on Jun 24, 2024:

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/425

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Jun 24, 2024. (This label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed.)
@liangan1 (Author) commented:

@jgong5

```python
HANDLED_FUNCTIONS = {}


class PagedTensor(object):
```
Collaborator commented on the code above:

Shall we make it a tensor subclass and inherit torch.Tensor here?

Author replied:

Done
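
For reference, here is a minimal sketch of the subclass approach discussed above. The `HANDLED_FUNCTIONS` registry and `__torch_function__` dispatch follow the standard pattern from PyTorch's extending-torch docs; the attribute names mirror this PR's constructor, but everything else (the `implements` decorator, the `as_subclass` wrapping) is an illustrative assumption, not necessarily the PR's final implementation:

```python
import torch

HANDLED_FUNCTIONS = {}


def implements(torch_function):
    """Register a paged-aware override for a torch function (sketch)."""
    def decorator(func):
        HANDLED_FUNCTIONS[torch_function] = func
        return func
    return decorator


class PagedTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, cache, block_tables, context_lens):
        # as_subclass reuses the cache's storage; nothing is copied.
        instance = cache.as_subclass(cls)
        instance.block_tables = block_tables
        instance.context_lens = context_lens
        return instance

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Route torch functions that have paged-aware implementations
        # (e.g. an attention kernel) through the registry; otherwise
        # fall back to the default torch.Tensor behavior.
        if func in HANDLED_FUNCTIONS:
            return HANDLED_FUNCTIONS[func](*args, **kwargs)
        return super().__torch_function__(func, types, args, kwargs)
```

With this shape, a paged attention kernel could be registered via, for example, `@implements(torch.nn.functional.scaled_dot_product_attention)`, so ordinary model code dispatches to the paged path without changes.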

```python
self,
cache: torch.Tensor,  # The cache tensor from the PagedAttentionCache object, shared across iterations.
block_tables: torch.Tensor,  # The block tables for each sequence in the batch, used to map logical blocks to physical blocks.
context_lens: torch.Tensor,  # The context lengths for each sequence in the batch.
```
Collaborator commented on context_lens:

Context length is a concept from text generation; it seems not generic enough to describe a tensor?

Author replied:

We now use size, which represents the real cache tensor shape (bs, num_key_value_heads, seq_len, head_dim) like the dynamic cache, in place of context.
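
To make the logical-to-physical block mapping concrete, here is a small runnable sketch; the block size, tensor shapes, and `lookup` helper are assumptions for illustration, not code from this PR:

```python
import torch

# Illustrative shapes only; the PR's actual cache layout may differ.
block_size = 16                            # tokens per physical block
num_blocks, num_heads, head_dim = 64, 8, 128

# Physical KV cache, shared across sequences and iterations.
cache = torch.zeros(num_blocks, num_heads, block_size, head_dim)

# block_tables[seq, logical_block] -> physical block index.
block_tables = torch.tensor([[3, 7, 1]])   # one sequence spanning 3 blocks

def lookup(seq: int, token_pos: int) -> torch.Tensor:
    """Resolve a logical token position to its slot in the physical cache."""
    logical_block, offset = divmod(token_pos, block_size)
    physical_block = block_tables[seq, logical_block]
    return cache[physical_block, :, offset]  # shape: (num_heads, head_dim)

print(lookup(0, 20).shape)  # token 20 -> physical block 7, offset 4
```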

"""Returns the maximum sequence length of the cached states. PagedAttentionCache does not have a maximum length."""
RuntimeError("PagedAttentionCache does not have a maximum sequence length.")

def update(
Collaborator commented on update():

This function will be called inside the model forward. With a Python implementation that has conditionals and loops here, would it work with torch.compile?

Author replied:

Good question. We need more effort to integrate this PR with Hugging Face to validate the end-to-end functionality under torch.compile. I suggest we review this PR in parallel; I will refine it if needed.
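
As a point of reference while that validation happens, one common way to keep data-dependent cache bookkeeping from breaking the compiled graph is `torch.compiler.disable` (available in recent PyTorch releases). The sketch below is an assumption about how that could look, with hypothetical `allocate` bookkeeping, not what this PR implements:

```python
import torch

class PagedAttentionCache:
    """Sketch only: the class name mirrors this PR, the logic is illustrative."""

    def __init__(self, cache: torch.Tensor, block_size: int):
        self.cache = cache             # (num_blocks, heads, block_size, head_dim)
        self.block_size = block_size
        self.free_blocks = list(range(cache.shape[0]))
        self.block_tables = {}         # seq_id -> list of physical block ids

    # Keep the Python-level bookkeeping (conditionals, loops, dict/list
    # mutation) out of the traced graph; torch.compile runs this part
    # eagerly instead of graph-breaking inside the model forward.
    @torch.compiler.disable
    def allocate(self, seq_id: int, num_tokens: int) -> list:
        table = self.block_tables.setdefault(seq_id, [])
        blocks_needed = -(-num_tokens // self.block_size) - len(table)
        for _ in range(max(blocks_needed, 0)):
            table.append(self.free_blocks.pop())
        return table
```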

yanbing-j pushed a commit to yanbing-j/ao that referenced this pull request Dec 9, 2024
@jcaip (Contributor) commented on Mar 19, 2025:

cc @liangan1 @HDCharles what's the status of this PR - do we need additional work to land?

@HDCharles (Contributor) left a review:

kv_cache.py should probably be in a prototype or experimental folder, or at the very least in kernel or something similar, rather than at the top level, which is reserved and kept clean.

A further question is the goal of this PR as far as usage. It looks like it only enables this for CPU, but 90% of our techniques are CUDA based; is CUDA support intended as a next step? I think it's fine to add a CPU-only technique, but I want clarity on the plan and who we expect to use it.

In other comments you mention wanting to integrate Hugging Face for e2e tests, but you can do that directly in torchao; see torchao/_models/llama, where a bunch of techniques are tested and benchmarked. Generally the first step toward getting someone to use a technique like this is a benchmark demo of how it helps; without that, it's a large burden on the user to figure out what it's for.

Finally, there's no information on how this is supposed to be used, just code enabling a feature. An e2e demo in the llama benchmark would make it easier to understand, but there should probably also be a .md file explaining what this does, the intended use case, and the basic API a user is expected to apply. The RFC link is useful in the PR but not to a random user stumbling across this kernel, who most likely isn't going to check the PR notes.

If you want to move this into an experimental/prototype folder, I think that's the minimum needed to land this, though it would be good to understand the path this is expected to take toward usage. I think this feature is interesting, but I don't know who is going to use it given that almost everything in the repo is CUDA related.

Labels: CLA Signed
5 participants