Paged attention #425
base: main
Conversation
Co-authored-by: Jiong Gong <[email protected]>
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/425
Note: Links to docs will display an error until the docs builds have been completed. This comment was automatically generated by Dr. CI and updates every 15 minutes.
torchao/kv_cache.py
Outdated
HANDLED_FUNCTIONS = {}


class PagedTensor(object):
Shall we make it a tensor subclass and inherit torch.Tensor here?
Done
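For reference, a minimal sketch of the tensor-subclass pattern being discussed, using the __torch_function__ protocol with the HANDLED_FUNCTIONS registry shown above; the constructor fields mirror the cache/block_tables/context_lens arguments that appear later in this review, but the code is an illustration, not the PR's implementation:

import torch

HANDLED_FUNCTIONS = {}

class PagedTensor(torch.Tensor):
    """Sketch: a wrapper tensor subclass holding paged KV-cache metadata."""

    @staticmethod
    def __new__(cls, cache, block_tables, context_lens):
        # Wrapper subclass: carries metadata, no storage of its own.
        return torch.Tensor._make_wrapper_subclass(
            cls, cache.shape, dtype=cache.dtype, device=cache.device
        )

    def __init__(self, cache, block_tables, context_lens):
        self.cache = cache
        self.block_tables = block_tables
        self.context_lens = context_lens

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Route ops with a registered paged implementation; fall back otherwise.
        if func in HANDLED_FUNCTIONS:
            return HANDLED_FUNCTIONS[func](*args, **kwargs)
        return super().__torch_function__(func, types, args, kwargs)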
torchao/kv_cache.py
Outdated
self,
cache: torch.Tensor,  # The cache tensor from the PagedAttentionCache object, which is shared across iterations.
block_tables: torch.Tensor,  # The block tables for each sequence in the batch, used to map logical blocks to physical blocks.
context_lens: torch.Tensor,  # The context lengths for each sequence in the batch.
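To illustrate the logical-to-physical mapping described in the block_tables comment above, a sketch with an assumed block size of 16 and a hypothetical helper name; none of this is taken from the PR:

import torch

BLOCK_SIZE = 16  # assumed number of tokens stored per physical block

def lookup_physical_slot(block_tables: torch.Tensor, seq_idx: int, token_pos: int):
    """Map a logical token position in sequence seq_idx to (physical_block, offset).

    block_tables has shape (batch, max_blocks_per_seq); row i lists the
    physical block ids allocated to sequence i, in logical order.
    """
    logical_block = token_pos // BLOCK_SIZE
    offset = token_pos % BLOCK_SIZE
    physical_block = int(block_tables[seq_idx, logical_block])
    return physical_block, offset

# Example: sequence 0 owns physical blocks [7, 2]; token 20 falls in
# logical block 1 -> physical block 2, at offset 4 within that block.
tables = torch.tensor([[7, 2]])
assert lookup_physical_slot(tables, 0, 20) == (2, 4)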
Context length is a concept from text generation; it seems not generic enough to describe a tensor?
We have used `size`, which represents the real cache tensor size (bs, num_key_value_heads, seq_lens, head_dim) like the dynamic cache, to replace `context_lens`.
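For comparison, a sketch of reporting that size in the DynamicCache-like convention; the class and attribute names below are assumptions, not the PR's code:

import torch

class PagedAttentionCacheSketch:
    """Illustrative only: attribute names are assumptions."""

    def __init__(self, batch_size, num_key_value_heads, seq_lens, head_dim):
        self.batch_size = batch_size
        self.num_key_value_heads = num_key_value_heads
        self.seq_lens = seq_lens
        self.head_dim = head_dim

    def size(self) -> torch.Size:
        # Logical cache size, where seq_lens counts valid cached tokens,
        # not the length of the physical paged buffer.
        return torch.Size((self.batch_size, self.num_key_value_heads,
                           self.seq_lens, self.head_dim))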
"""Returns the maximum sequence length of the cached states. PagedAttentionCache does not have a maximum length.""" | ||
RuntimeError("PagedAttentionCache does not have a maximum sequence length.") | ||
|
||
def update( |
This function will be called inside the model forward. With a Python implementation with conditionals and loops here, would it work with torch.compile?
Good question. We need more effort to integrate this PR with Hugging Face to validate the end-to-end functionality with torch.compile. I suggest we review this PR in parallel; I will refine it if needed.
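To make the torch.compile concern concrete, a sketch (shapes and names assumed, not taken from the PR) of an update step that scatters one new token per sequence into its physical block; the per-sequence Python loop and data-dependent int(...) lookups are the kind of code that typically triggers graph breaks under torch.compile unless rewritten with batched tensor indexing:

import torch

def update_key_cache(key_cache, block_tables, context_lens, new_keys, block_size=16):
    """Sketch: append one new token's key states per sequence.

    key_cache:    (num_blocks, num_heads, block_size, head_dim) shared pool
    block_tables: (batch, max_blocks_per_seq) physical block ids per sequence
    context_lens: (batch,) number of tokens already cached per sequence
    new_keys:     (batch, num_heads, head_dim) keys for the new token
    """
    for seq_idx in range(new_keys.shape[0]):
        pos = int(context_lens[seq_idx])                  # data-dependent scalar read
        block = int(block_tables[seq_idx, pos // block_size])
        key_cache[block, :, pos % block_size] = new_keys[seq_idx]
    context_lens += 1  # all sequences advance by one token
    return key_cache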
cc @liangan1 @HDCharles what's the status of this PR - do we need additional work to land?
kv_cache.py should probably be in a prototype or experimental folder, or at the very least in kernel or something similar, rather than at the top level, which is reserved and kept clean.
A further question is what the goal of this PR is as far as usage. It looks like it only enables this for CPU, but 90% of our techniques are CUDA based; is CUDA support intended as a next step? I think it's fine to add a CPU-only technique, but I just want clarity on the plan and who we expect to use it.
In other comments you mention wanting to integrate Hugging Face for e2e tests, but you can do that directly in torchao; see torchao/_models/llama, where a bunch of techniques are being tested and benchmarked. Generally the first step towards getting someone to use a technique like this would be a benchmark demo of how it helps; without that, it's a large burden on the user to figure out what it's for.
Finally, there's no information on how this is supposed to be used, just code enabling a feature. Something like an e2e demo in the llama benchmark would make it easier for someone to understand how to use it, but there should probably also be a .md file explaining what this does, what the intended use case is, and the basic API a user is expected to apply. The RFC link is useful in the PR but not to a random user stumbling across this kernel who most likely isn't going to check the PR notes.
If you want to move this into an experimental/prototype folder, I think that's the minimum needed to land this, though it would be good to understand the path this is expected to take towards usage, since I think this feature is interesting but I don't know who is going to use it given that almost everything in the repo is CUDA related.
Related RFC