-
Notifications
You must be signed in to change notification settings - Fork 536
Add a simple sdpa #3037
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Add a simple sdpa #3037
Changes from all commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
3bf691b
Add a simple sdpa
cccclai d6c63d1
Update on "Add a simple sdpa"
cccclai 4bcbdce
Update on "Add a simple sdpa"
cccclai 00347ea
Update on "Add a simple sdpa"
cccclai 8e99c7a
Update on "Add a simple sdpa"
cccclai f5ec6cf
Update on "Add a simple sdpa"
cccclai 5465fb7
Update on "Add a simple sdpa"
cccclai File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,6 +9,7 @@ | |
import argparse | ||
import copy | ||
import logging | ||
import math | ||
import os | ||
import shlex | ||
|
||
|
@@ -143,6 +144,80 @@ def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module: | |
return module | ||
|
||
|
||
class SDPASimple(torch.nn.Module): | ||
|
||
def __init__( | ||
self, | ||
kv_cache: KVCache, | ||
dim: int, | ||
head_dim: int, | ||
n_rep: int, | ||
): | ||
super().__init__() | ||
self.kv_cache = kv_cache | ||
self.dim = dim | ||
self.head_dim = head_dim | ||
self.n_rep = n_rep | ||
|
||
def forward( | ||
self, | ||
input_pos: torch.Tensor, | ||
q: torch.Tensor, | ||
k: torch.Tensor, | ||
v: torch.Tensor, | ||
bsz, | ||
seqlen, | ||
mask, | ||
): | ||
q = q.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) | ||
k = k.transpose(1, 2) | ||
v = v.transpose(1, 2) | ||
|
||
k, v = self.kv_cache.update(input_pos, k, v) | ||
attn_mask = mask[None, None, input_pos] | ||
|
||
k = k.repeat_interleave(self.n_rep, dim=1) | ||
v = v.repeat_interleave(self.n_rep, dim=1) | ||
scale_factor = 1 / math.sqrt(q.size(-1)) | ||
attn_weight = q @ k.transpose(-2, -1) * scale_factor | ||
attn_weight += attn_mask | ||
attn_weight = torch.softmax(attn_weight, dim=-1) | ||
y = attn_weight @ v | ||
|
||
return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do you need contiguous? |
||
|
||
|
||
def replace_sdpa_with_simple_sdpa(module: torch.nn.Module): | ||
for name, child in module.named_children(): | ||
if isinstance(child, SDPA): | ||
setattr( | ||
module, | ||
name, | ||
SDPASimple(child.kv_cache, child.dim, child.head_dim, child.n_rep), | ||
) | ||
else: | ||
replace_sdpa_with_simple_sdpa(child) | ||
return module | ||
|
||
|
||
def replace_causal_mask(module: torch.nn.Module): | ||
for buffer_fqn_name, buffer in module.named_buffers(): | ||
buffer_name = buffer_fqn_name.split(".")[-1] | ||
if buffer_name == "mask": | ||
max_seq_len = buffer.shape[-1] | ||
mask = torch.full( | ||
(max_seq_len, max_seq_len), | ||
float("-inf"), | ||
device="cpu", | ||
) | ||
|
||
mask = torch.triu(mask, diagonal=1) | ||
module.register_buffer(buffer_name, mask) | ||
for _, child in module.named_children(): | ||
replace_causal_mask(child) | ||
return module | ||
|
||
|
||
def quantize( | ||
model: torch.nn.Module, | ||
qmode: str, | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest") | ||
|
||
oncall("executorch") | ||
|
||
python_unittest( | ||
name = "test_simple_sdpa", | ||
srcs = [ | ||
"test_simple_sdpa.py", | ||
], | ||
deps = [ | ||
"//caffe2:torch", | ||
"//executorch/examples/models/llama2:export_library", | ||
"//executorch/examples/models/llama2:llama_transformer", | ||
], | ||
) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import copy | ||
import unittest | ||
|
||
import torch | ||
from executorch.examples.models.llama2.export_llama_lib import SDPASimple | ||
from executorch.examples.models.llama2.llama_transformer import KVCache, SDPA | ||
|
||
|
||
class SDPATest(unittest.TestCase): | ||
def test_simple_sdpa(self): | ||
# Verify the correctness between the simple SDPA and the original SDPA module defined in llama_transformer.py | ||
max_batch_size = 1 | ||
max_seq_length = 128 | ||
n_heads = 8 | ||
head_dim = 8 | ||
dim = 64 | ||
n_rep = 1 | ||
bsz = 1 | ||
seqlen = 1 | ||
n_local_heads = n_heads | ||
kv_cache = KVCache( | ||
max_batch_size=max_batch_size, | ||
max_seq_length=max_seq_length, | ||
n_heads=n_heads, | ||
head_dim=head_dim, | ||
transpose_cache=True, | ||
) | ||
sdpa = SDPA( | ||
kv_cache=copy.deepcopy(kv_cache), dim=dim, head_dim=head_dim, n_rep=n_rep | ||
) | ||
input_pos = torch.tensor([0]) | ||
query = torch.randn(1, 1, n_local_heads, head_dim) | ||
key = torch.randn(1, 1, n_local_heads, head_dim) | ||
value = torch.randn(1, 1, n_local_heads, head_dim) | ||
mask = torch.randn(max_seq_length, max_seq_length) | ||
sdpa_output = sdpa( | ||
input_pos, query, key, value, bsz=bsz, seqlen=seqlen, mask=mask | ||
) | ||
|
||
simple_sdpa = SDPASimple( | ||
kv_cache=copy.deepcopy(kv_cache), dim=dim, head_dim=head_dim, n_rep=n_rep | ||
) | ||
simple_sdpa_output = simple_sdpa( | ||
input_pos, query, key, value, bsz=bsz, seqlen=seqlen, mask=mask | ||
) | ||
|
||
# Compare the output from output from two sdpa implementation | ||
self.assertTrue(torch.allclose(sdpa_output, simple_sdpa_output)) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we just keep sdpa as is by registering it as a custom qnn op that qnn delegate can directly consume?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's not just for qnn. Basically for any backend that doesn't support sdpa, it makes more sense to use this version instead.