
Commit db53912 — bug fix: use flash-attn3

1 parent: 8110cb8
1 file changed, 1 addition and 1 deletion

File changed:

  • records/track_10min_16mb/2026-04-05_11L_LatentMaskTTT_GPTQ_ProductKey_Brotli/train_gpt.py

@@ -16,7 +16,7 @@
 import torch.nn.functional as F
 from torch import Tensor, nn
 from torch.nn.parallel import DistributedDataParallel as DDP
-from flash_attn.flash_attn_interface import flash_attn_func as _fa3_func
+from flash_attn_interface import flash_attn_func as _fa3_func
 def flash_attn_3_func(q, k, v, causal=True):
     return _fa3_func(q, k, v, causal=causal)
 class Hyperparameters:
