Commit 61cabec

move mask as sdpa input instead of attribute (#3036)

Summary:
Pull Request resolved: #3036

sdpa (https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) takes the attention mask as an input, so refactor the SDPA module's forward inputs to more closely match the sdpa inputs.

ghstack-source-id: 222650466
exported-using-ghexport

Reviewed By: mergennachin

Differential Revision: D56119739

fbshipit-source-id: d9adda66e540abc518b7ffb6a5ebd2aab1626b3b
(cherry picked from commit b341223)
1 parent 0ad7043 commit 61cabec
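
The change below follows a simple pattern: a tensor that used to be captured as a module attribute at construction time is instead threaded through forward(), matching the attn_mask argument of torch.nn.functional.scaled_dot_product_attention. A minimal sketch of that pattern, assuming a boolean causal mask and toy tensor shapes (MaskAsInputSDPA is a hypothetical name, not the module in this diff):

import torch
import torch.nn as nn
import torch.nn.functional as F

# Minimal sketch (not the ExecuTorch module): the attention mask is no longer
# stored on the module at construction time but passed to forward(), mirroring
# the attn_mask input of torch.nn.functional.scaled_dot_product_attention.
class MaskAsInputSDPA(nn.Module):  # hypothetical name, for illustration only
    def forward(self, q, k, v, mask):
        # mask arrives per call rather than via self.mask
        return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

# Usage sketch with toy shapes
bsz, n_heads, seqlen, head_dim = 1, 4, 8, 16
q = torch.randn(bsz, n_heads, seqlen, head_dim)
k = torch.randn(bsz, n_heads, seqlen, head_dim)
v = torch.randn(bsz, n_heads, seqlen, head_dim)
causal_mask = torch.tril(torch.ones(seqlen, seqlen, dtype=torch.bool))
y = MaskAsInputSDPA()(q, k, v, causal_mask)
print(y.shape)  # torch.Size([1, 4, 8, 16])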

File tree

2 files changed (+12, −12)

examples/models/llama2/export_llama_lib.py (+2, −3)

@@ -96,12 +96,10 @@ class SDPACustom(torch.nn.Module):
     def __init__(
         self,
         kv_cache: KVCache,
-        mask,
         dim: int,
     ):
         super().__init__()
         self.kv_cache = kv_cache
-        self.mask = mask
         self.dim = dim
 
     def forward(
@@ -112,6 +110,7 @@ def forward(
         v: torch.Tensor,
         bsz,
         seqlen,
+        mask,
     ):
         output = torch.ops.llama.sdpa_with_kv_cache(
             q,
@@ -131,7 +130,7 @@ def _replace_sdpa_with_custom_op(module: torch.nn.Module):
             setattr(
                 module,
                 name,
-                SDPACustom(child.kv_cache, child.mask, child.dim),
+                SDPACustom(child.kv_cache, child.dim),
             )
         else:
             _replace_sdpa_with_custom_op(child)
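
For context, _replace_sdpa_with_custom_op above rebuilds each SDPA submodule as an SDPACustom without carrying the mask over. A hedged sketch of that module-swap idiom (named_children plus setattr), using placeholder OldBlock/NewBlock classes rather than the real SDPA/SDPACustom:

import torch.nn as nn

# Placeholder classes for illustration only; the real code swaps SDPA for SDPACustom.
class OldBlock(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

class NewBlock(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

def replace_blocks(module: nn.Module) -> None:
    # Walk the module tree and replace every OldBlock with a NewBlock rebuilt
    # from the child's attributes; no mask is carried over, since the mask is
    # now a forward() input.
    for name, child in module.named_children():
        if isinstance(child, OldBlock):
            setattr(module, name, NewBlock(child.dim))
        else:
            replace_blocks(child)

model = nn.Sequential(OldBlock(64), nn.Sequential(OldBlock(32)))
replace_blocks(model)
print(model)  # both OldBlock instances are now NewBlock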

examples/models/llama2/llama_transformer.py (+10, −9)

@@ -197,14 +197,14 @@ class SDPA(nn.Module):
     def __init__(
         self,
         kv_cache: KVCache,
-        mask,
         dim: int,
+        head_dim: int,
         n_rep: int,
     ):
         super().__init__()
         self.kv_cache = kv_cache
-        self.mask = mask
         self.dim = dim
+        self.head_dim = head_dim
         self.n_rep = n_rep
 
     def forward(
@@ -215,17 +215,18 @@ def forward(
         v: torch.Tensor,
         bsz,
         seqlen,
+        mask: torch.Tensor,
     ) -> torch.Tensor:
         q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
         k = k.transpose(1, 2)
         v = v.transpose(1, 2)
 
         k, v = self.kv_cache.update(input_pos, k, v)
-        mask = self.mask[None, None, input_pos]
+        attn_mask = mask[None, None, input_pos]
 
         k = k.repeat_interleave(self.n_rep, dim=1)
         v = v.repeat_interleave(self.n_rep, dim=1)
-        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+        y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)
 
         return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
 
@@ -271,10 +272,10 @@ def __init__(self, args: ModelArgs, layer_id: int):
             not args.use_sdpa_with_kv_cache_op,  # if we are using the custom op dont transpose the cache. Expect untransposed q k v
         )
         self.SDPA = SDPA(
-            self.kv_cache,
-            self.mask,
-            self.dim,
-            self.n_rep,
+            kv_cache=self.kv_cache,
+            dim=self.dim,
+            head_dim=self.head_dim,
+            n_rep=self.n_rep,
         )
 
     def forward(
@@ -298,7 +299,7 @@ def forward(
 
         if self.use_kv_cache:
            assert input_pos is not None
-            output = self.SDPA(input_pos, q, k, v, bsz, seqlen)
+            output = self.SDPA(input_pos, q, k, v, bsz, seqlen, self.mask)
             return self.wo(output)
 
         q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
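
SDPA.forward above slices the caller-supplied mask with input_pos before handing it to F.scaled_dot_product_attention. A small sketch of what attn_mask = mask[None, None, input_pos] does, assuming the mask Attention passes in is a boolean lower-triangular matrix of shape (max_seq_len, max_seq_len):

import torch

# Assumed mask layout: boolean lower-triangular causal mask of shape
# (max_seq_len, max_seq_len). The two None indices add broadcast batch/head
# dimensions and input_pos selects the row(s) for the position(s) being decoded,
# yielding a (1, 1, seqlen, max_seq_len) attn_mask.
max_seq_len = 8
full_mask = torch.tril(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool))
input_pos = torch.tensor([3])             # decoding the token at position 3
attn_mask = full_mask[None, None, input_pos]
print(attn_mask.shape)                    # torch.Size([1, 1, 1, 8])
print(attn_mask[0, 0, 0])                 # positions 0..3 attendable, 4..7 masked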
