Skip to content

Commit e68e8b0

Browse files
committed
[Cute] Update remaining make_fragment_like to make_rmem_tensor_like
1 parent a2e922d commit e68e8b0

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

flash_attn/cute/flash_bwd_sm90.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1344,7 +1344,7 @@ def mma_one_m_block(
13441344
acc_dP = mma_dov_fn(A_idx=smem_idx_Q, wg_wait=1)
13451345

13461346
if const_expr(self.score_mod_bwd is not None):
1347-
acc_S_pre = cute.make_fragment_like(acc_S)
1347+
acc_S_pre = cute.make_rmem_tensor_like(acc_S)
13481348
cute.autovec_copy(acc_S, acc_S_pre)
13491349

13501350
if const_expr(self.score_mod is not None):
@@ -1498,7 +1498,7 @@ def epilogue_dKV(
14981498
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
14991499

15001500
if const_expr(self.qhead_per_kvhead == 1):
1501-
rdV = cute.make_fragment_like(acc_dV, self.dtype)
1501+
rdV = cute.make_rmem_tensor_like(acc_dV, self.dtype)
15021502
rdV.store(acc_dV.load().to(self.dtype))
15031503
rdK = utils.cvt_f16(acc_dK, self.dtype)
15041504

flash_attn/cute/paged_kv.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ def load_KV(self, n_block: Int32, sX: cute.Tensor, K_or_V: str):
153153
seqlenk_row_limit = self.seqlen_k - n_block * self.n_block_size if n_block >= 0 else 0
154154
for m in cutlass.range_constexpr(cute.size(tXsX, mode=[1])):
155155
row_valid = tXcX[0, m, 0][0] < seqlenk_row_limit
156-
should_load = cute.make_fragment_like(tXsX[None, m, 0], cute.Boolean)
156+
should_load = cute.make_rmem_tensor_like(tXsX[None, m, 0], cute.Boolean)
157157
should_load.fill(row_valid)
158158

159159
page = self.tPrPage[m]

0 commit comments

Comments
 (0)