
Commit b12a901

FightingZhen authored and MekkCyber committed
[performance_optim] define flash attention mask on NPU device directly (huggingface#37698)
Co-authored-by: Mohamed Mekkouri <[email protected]>
1 parent 6bb94d3 commit b12a901

1 file changed, +2 −2 lines changed

src/transformers/integrations/npu_flash_attention.py

Lines changed: 2 additions & 2 deletions
@@ -171,7 +171,7 @@ def npu_flash_attn_func(
         head_num = q.shape[2]
         output = torch_npu.npu_fusion_attention(q, k, v, head_num, "BSND", keep_prob=keep_prob, scale=softmax_scale)[0]
     else:
-        attn_mask_npu = torch.triu(torch.ones([2048, 2048]), diagonal=1).bool().to(q.device)
+        attn_mask_npu = torch.triu(torch.ones([2048, 2048], device=q.device), diagonal=1).bool()
         head_num = q.shape[2]
         output = torch_npu.npu_fusion_attention(
             q,
@@ -222,7 +222,7 @@ def npu_flash_attn_varlen_func(
             actual_seq_kvlen=tuple(cu_seqlens_k[1:].cpu().numpy().tolist()),
         )[0]
     else:
-        attn_mask_npu = torch.triu(torch.ones([2048, 2048]), diagonal=1).bool().to(q.device)
+        attn_mask_npu = torch.triu(torch.ones([2048, 2048], device=q.device), diagonal=1).bool()
         head_num = q.shape[1]
         output = torch_npu.npu_fusion_attention(
             q,
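For context, the sketch below (plain PyTorch, no torch_npu dependency; the helper name and the non-NPU device string are illustrative, not from the commit) shows why constructing the mask with device= is preferable: the tensor is allocated directly on the accelerator, whereas the old pattern first builds it on the CPU and then pays for an extra host allocation and host-to-device copy via .to(q.device).

    import torch

    def make_causal_mask(size: int, device: torch.device) -> torch.Tensor:
        # Old pattern (as in the removed lines): allocate on CPU, then copy to the device.
        #   mask = torch.triu(torch.ones([size, size]), diagonal=1).bool().to(device)
        # New pattern (as in the added lines): allocate directly on the target device.
        return torch.triu(torch.ones([size, size], device=device), diagonal=1).bool()

    # Example usage; on Ascend hardware the device would be "npu" instead of "cuda"/"cpu".
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    mask = make_causal_mask(2048, device)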
