@@ -55,8 +55,13 @@ def __init__(
 
     @staticmethod
     def can_implement(
-        dtype, dtype_partial, head_dim, m_block_size, k_block_size,
-        log_max_splits, num_threads,
+        dtype,
+        dtype_partial,
+        head_dim,
+        m_block_size,
+        k_block_size,
+        log_max_splits,
+        num_threads,
     ) -> bool:
         """Check if the kernel can be implemented with the given parameters."""
         if dtype not in [cutlass.Float16, cutlass.BFloat16, cutlass.Float32]:
@@ -83,8 +88,7 @@ def _setup_attributes(self):
         assert self.k_block_size % async_copy_elems == 0
 
         k_block_gmem = (
-            128 if self.k_block_size % 128 == 0 else
-            (64 if self.k_block_size % 64 == 0 else 32)
+            128 if self.k_block_size % 128 == 0 else (64 if self.k_block_size % 64 == 0 else 32)
         )
         gmem_threads_per_row = k_block_gmem // async_copy_elems
         assert self.num_threads % gmem_threads_per_row == 0
@@ -111,16 +115,25 @@ def _setup_attributes(self):
             num_bits_per_copy=async_copy_elems * self.dtype.width,
         )
         self.gmem_tiled_copy_O = cute.make_tiled_copy_tv(
-            atom_universal_copy, tOpartial_layout, vOpartial_layout  # 4 vals per store
+            atom_universal_copy,
+            tOpartial_layout,
+            vOpartial_layout,  # 4 vals per store
         )
 
         # LSE copy setup with async copy (alignment = 1)
         lse_copy_bits = Float32.width  # 1 element per copy, width is in bits
         m_block_smem = (
-            128 if self.m_block_size % 128 == 0 else
-            (64 if self.m_block_size % 64 == 0 else
-            (32 if self.m_block_size % 32 == 0 else
-            (16 if self.m_block_size % 16 == 0 else 8)))
+            128
+            if self.m_block_size % 128 == 0
+            else (
+                64
+                if self.m_block_size % 64 == 0
+                else (
+                    32
+                    if self.m_block_size % 32 == 0
+                    else (16 if self.m_block_size % 16 == 0 else 8)
+                )
+            )
         )
         gmem_threads_per_row_lse = m_block_smem
         assert self.num_threads % gmem_threads_per_row_lse == 0
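The nested conditionals above (and the `k_block_gmem` one in the previous hunk) simply pick the largest supported width that evenly divides the block size; the conditional form keeps the result a compile-time constant. Read as ordinary Python, the selection is equivalent to this small helper (an illustration only, not code from this file):

```python
def largest_divisor(block_size: int, candidates=(128, 64, 32, 16, 8)) -> int:
    """Return the largest candidate that evenly divides block_size,
    falling back to the smallest candidate (mirroring the final 'else 8')."""
    for width in candidates:
        if block_size % width == 0:
            return width
    return candidates[-1]

# m_block_smem for a few block sizes: 128 -> 128, 96 -> 32, 72 -> 8
assert largest_divisor(128) == 128
assert largest_divisor(96) == 32
assert largest_divisor(72) == 8
```

The earlier `k_block_gmem` expression is the same pattern with candidates (128, 64, 32).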
@@ -167,21 +180,17 @@ def _setup_attributes(self):
         else:
             smem_lse_swizzle = cute.make_swizzle(3, 2, 3)
         smem_layout_atom_lse = cute.make_composed_layout(
-            smem_lse_swizzle,
-            0,
-            cute.make_ordered_layout((8, m_block_smem), order=(1, 0))
+            smem_lse_swizzle, 0, cute.make_ordered_layout((8, m_block_smem), order=(1, 0))
         )
         self.smem_layout_lse = cute.tile_to_shape(
             smem_layout_atom_lse, (self.max_splits, self.m_block_size), (0, 1)
         )
 
         # O partial shared memory layout (simple layout for pipeline stages)
         self.smem_layout_o = cute.make_ordered_layout(
-            (self.m_block_size, self.k_block_size, self.stages),
-            order=(1, 0, 2)
+            (self.m_block_size, self.k_block_size, self.stages), order=(1, 0, 2)
         )
 
-
     @cute.jit
     def __call__(
         self,
@@ -200,38 +209,63 @@ def __call__(
             raise TypeError("O partial tensor must match dtype_partial")
         if const_expr(not (mO.element_type == self.dtype)):
             raise TypeError("O tensor must match dtype")
-        if const_expr(not mLSE_partial.element_type in [Float32]):
+        if const_expr(mLSE_partial.element_type not in [Float32]):
             raise TypeError("LSE partial tensor must be Float32")
-        if const_expr(mLSE is not None and not mLSE.element_type in [Float32]):
+        if const_expr(mLSE is not None and mLSE.element_type not in [Float32]):
             raise TypeError("LSE tensor must be Float32")
 
         # Shape validation - input tensors are in user format, need to be converted to kernel format
         if const_expr(len(mO_partial.shape) not in [4, 5]):
-            raise ValueError("O partial tensor must have 4 or 5 dimensions: (num_splits, batch, seqlen, nheads, headdim) or (num_splits, total_q, nheads, headdim)")
+            raise ValueError(
+                "O partial tensor must have 4 or 5 dimensions: (num_splits, batch, seqlen, nheads, headdim) or (num_splits, total_q, nheads, headdim)"
+            )
         if const_expr(len(mLSE_partial.shape) not in [3, 4]):
-            raise ValueError("LSE partial tensor must have 3 or 4 dimensions: (num_splits, batch, seqlen, nheads) or (num_splits, total_q, nheads)")
+            raise ValueError(
+                "LSE partial tensor must have 3 or 4 dimensions: (num_splits, batch, seqlen, nheads) or (num_splits, total_q, nheads)"
+            )
         if const_expr(len(mO.shape) not in [3, 4]):
-            raise ValueError("O tensor must have 3 or 4 dimensions: (batch, seqlen, nheads, headdim) or (total_q, nheads, headdim)")
+            raise ValueError(
+                "O tensor must have 3 or 4 dimensions: (batch, seqlen, nheads, headdim) or (total_q, nheads, headdim)"
+            )
         if const_expr(mLSE is not None and len(mLSE.shape) not in [2, 3]):
-            raise ValueError("LSE tensor must have 2 or 3 dimensions: (batch, seqlen, nheads) or (total_q, nheads)")
+            raise ValueError(
+                "LSE tensor must have 2 or 3 dimensions: (batch, seqlen, nheads) or (total_q, nheads)"
+            )
 
         # Assume all strides are divisible by 128 bits except the last stride
-        new_stride = lambda t: (*(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]), t.stride[-1])
-        mO_partial, mO = [cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) for t in (mO_partial, mO)]
+        new_stride = lambda t: (
+            *(cute.assume(s, divby=128 // t.element_type.width) for s in t.stride[:-1]),
+            t.stride[-1],
+        )
+        mO_partial, mO = [
+            cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t)))
+            for t in (mO_partial, mO)
+        ]
         # (num_splits, b, seqlen, h, d) -> (seqlen, d, num_splits, h, b)
         # or (num_splits, total_q, h, d) -> (total_q, d, num_splits, h)
-        O_partial_layout_transpose = [2, 4, 0, 3, 1] if const_expr(cu_seqlens is None) else [1, 3, 0, 2]
+        O_partial_layout_transpose = (
+            [2, 4, 0, 3, 1] if const_expr(cu_seqlens is None) else [1, 3, 0, 2]
+        )
         # (b, seqlen, h, d) -> (seqlen, d, h, b) or (total_q, h, d) -> (total_q, d, h)
-        mO_partial = cute.make_tensor(mO_partial.iterator, cute.select(mO_partial.layout, mode=O_partial_layout_transpose))
+        mO_partial = cute.make_tensor(
+            mO_partial.iterator, cute.select(mO_partial.layout, mode=O_partial_layout_transpose)
+        )
         O_layout_transpose = [1, 3, 2, 0] if const_expr(cu_seqlens is None) else [0, 2, 1]
         mO = cute.make_tensor(mO.iterator, cute.select(mO.layout, mode=O_layout_transpose))
         # (num_splits, b, seqlen, h) -> (seqlen, num_splits, h, b)
         # or (num_splits, total_q, h) -> (total_q, num_splits, h)
         LSE_partial_layout_transpose = [2, 0, 3, 1] if const_expr(cu_seqlens is None) else [1, 0, 2]
-        mLSE_partial = cute.make_tensor(mLSE_partial.iterator, cute.select(mLSE_partial.layout, mode=LSE_partial_layout_transpose))
+        mLSE_partial = cute.make_tensor(
+            mLSE_partial.iterator,
+            cute.select(mLSE_partial.layout, mode=LSE_partial_layout_transpose),
+        )
         # (b, seqlen, h) -> (seqlen, h, b) or (total_q, h) -> (total_q, h)
         LSE_layout_transpose = [1, 2, 0] if const_expr(cu_seqlens is None) else [0, 1]
-        mLSE = cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=LSE_layout_transpose)) if mLSE is not None else None
+        mLSE = (
+            cute.make_tensor(mLSE.iterator, cute.select(mLSE.layout, mode=LSE_layout_transpose))
+            if mLSE is not None
+            else None
+        )
 
         # Determine if we have variable length sequences
         varlen = const_expr(cu_seqlens is not None or seqused is not None)
@@ -243,9 +277,7 @@ class SharedStorage:
             sLSE: cute.struct.Align[
                 cute.struct.MemRange[Float32, cute.cosize(self.smem_layout_lse)], 128
             ]
-            sMaxValidSplit: cute.struct.Align[
-                cute.struct.MemRange[Int32, self.m_block_size], 128
-            ]
+            sMaxValidSplit: cute.struct.Align[cute.struct.MemRange[Int32, self.m_block_size], 128]
             sO: cute.struct.Align[
                 cute.struct.MemRange[self.dtype_partial, cute.cosize(self.smem_layout_o)], 128
             ]
@@ -255,7 +287,11 @@ class SharedStorage:
         # Grid dimensions: (ceil_div(seqlen, m_block), ceil_div(head_dim, k_block), num_head * batch)
         seqlen = mO_partial.shape[0]
         num_head = mO_partial.shape[3]
-        batch_size = mO_partial.shape[4] if const_expr(cu_seqlens is None) else Int32(cu_seqlens.shape[0] - 1)
+        batch_size = (
+            mO_partial.shape[4]
+            if const_expr(cu_seqlens is None)
+            else Int32(cu_seqlens.shape[0] - 1)
+        )
 
         # Create FastDivmodDivisor objects for efficient division
         seqlen_divmod = FastDivmodDivisor(seqlen)
@@ -330,22 +366,26 @@ def kernel(
 
         # Handle semaphore reset
         if const_expr(semaphore_to_reset is not None):
-            if (tidx == 0 and m_block == cute.arch.grid_dim()[0] - 1 and
-                k_block == cute.arch.grid_dim()[1] - 1 and
-                batch_idx == cute.arch.grid_dim()[2] - 1):
+            if (
+                tidx == 0
+                and m_block == cute.arch.grid_dim()[0] - 1
+                and k_block == cute.arch.grid_dim()[1] - 1
+                and batch_idx == cute.arch.grid_dim()[2] - 1
+            ):
                 semaphore_to_reset[0] = 0
 
         # Get number of splits
         num_splits = (
-            num_splits_dynamic_ptr[batch_idx] if const_expr(num_splits_dynamic_ptr is not None)
+            num_splits_dynamic_ptr[batch_idx]
+            if const_expr(num_splits_dynamic_ptr is not None)
             else mLSE_partial.shape[1]
         )
         # Handle variable length sequences using SeqlenInfo
         seqlen_info = SeqlenInfo.create(
             batch_idx=batch_idx,
             seqlen_static=mO_partial.shape[0],
             cu_seqlens=cu_seqlens,
-            seqused=seqused
+            seqused=seqused,
         )
         seqlen, offset = seqlen_info.seqlen, seqlen_info.offset
 
@@ -354,8 +394,9 @@ def kernel(
         max_idx = seqlen * num_head
 
         # Early exit for single split if dynamic
-        if (const_expr(num_splits_dynamic_ptr is None) or num_splits > 1) and (const_expr(not varlen) or m_block * self.m_block_size < max_idx):
-
+        if (const_expr(num_splits_dynamic_ptr is None) or num_splits > 1) and (
+            const_expr(not varlen) or m_block * self.m_block_size < max_idx
+        ):
             # ===============================
             # Step 1: Load LSE_partial from gmem to shared memory
             # ===============================
@@ -390,7 +431,11 @@ def kernel(
                    for s in cutlass.range(cute.size(tLSEcLSE, mode=[1]), unroll_full=True):
                        si = tLSEcLSE[0, s, 0][0]  # Get split coordinate
                        if si < num_splits:
-                            cute.copy(gmem_thr_copy_LSE, mLSE_partial_cur_copy[None, si], tLSEsLSE[None, s, m])
+                            cute.copy(
+                                gmem_thr_copy_LSE,
+                                mLSE_partial_cur_copy[None, si],
+                                tLSEsLSE[None, s, m],
+                            )
                        else:
                            tLSEsLSE[None, s, m].fill(-Float32.inf)
            # Don't need to zero out the rest of the LSEs, as we will not write the output to gmem
@@ -424,7 +469,9 @@ def kernel(
                 else:
                     tOhidx[m] = idx // seqlen
                     tOmidx[m] = idx - tOhidx[m] * seqlen
-                tOrOptr[m] = utils.elem_pointer_i64(mO_partial_cur, (tOmidx[m], k_block * self.k_block_size, 0, tOhidx[m])).toint()
+                tOrOptr[m] = utils.elem_pointer_i64(
+                    mO_partial_cur, (tOmidx[m], k_block * self.k_block_size, 0, tOhidx[m])
+                ).toint()
                 if idx >= max_idx:
                     tOhidx[m] = -1
 
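The `else` branch above recovers the head index and row index from a flattened index over `seqlen * num_head` (the flattened form is `idx = head * seqlen + row`). A quick scalar check of that round trip, with arbitrary sizes chosen only for illustration:

```python
seqlen, num_head = 100, 8
for head in range(num_head):
    for row in range(seqlen):
        idx = head * seqlen + row
        # Mirrors tOhidx[m] = idx // seqlen and tOmidx[m] = idx - tOhidx[m] * seqlen
        assert idx // seqlen == head
        assert idx - (idx // seqlen) * seqlen == row
```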
@@ -483,7 +530,9 @@ def kernel(
                # Find max LSE value across splits
                threads_per_col = const_expr(self.smem_threads_per_col_lse)
                lse_max = utils.warp_reduce(
-                    ts2rrLSE[None, None, m].load().reduce(cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0),
+                    ts2rrLSE[None, None, m]
+                    .load()
+                    .reduce(cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0),
                    op=cute.arch.fmax,
                    width=threads_per_col,
                )
@@ -496,7 +545,9 @@ def kernel(
                # if cute.arch.thread_idx()[0] < 32: cute.printf(max_valid_idx)
                max_valid_split[m] = utils.warp_reduce(max_valid_idx, max, width=threads_per_col)
                # Compute exp scales and sum
-                lse_max_cur = 0.0 if lse_max == -Float32.inf else lse_max  # In case all local LSEs are -inf
+                lse_max_cur = (
+                    0.0 if lse_max == -Float32.inf else lse_max
+                )  # In case all local LSEs are -inf
                LOG2_E = math.log2(math.e)
                lse_sum_cur = 0.0
                for s in cutlass.range(cute.size(ts2rrLSE, mode=[1]), unroll_full=True):
@@ -506,7 +557,9 @@ def kernel(
                lse_sum_cur = utils.warp_reduce(lse_sum_cur, operator.add, width=threads_per_col)
                lse_sum[m] = utils.logf(lse_sum_cur) + lse_max
                # Normalize scales
-                inv_sum = 0.0 if (lse_sum_cur == 0.0 or lse_sum_cur != lse_sum_cur) else 1.0 / lse_sum_cur
+                inv_sum = (
+                    0.0 if (lse_sum_cur == 0.0 or lse_sum_cur != lse_sum_cur) else 1.0 / lse_sum_cur
+                )
                ts2rrLSE[None, None, m].store(ts2rrLSE[None, None, m].load() * inv_sum)
            # Store the scales exp(lse - lse_logsum) back to smem
            cute.copy(s2r_tiled_copy_LSE, ts2rrLSE, ts2rsLSE)
@@ -584,7 +637,10 @@ def kernel(
                # Accumulate scaled partial results
                for m in cutlass.range(num_rows, unroll_full=True):
                    if tOhidx[m] >= 0 and scale[m] > 0.0:
-                        tOrO[None, m, None].store(tOrO[None, m, None].load() + scale[m] * tOrO_partial[None, m, None].load().to(Float32))
+                        tOrO[None, m, None].store(
+                            tOrO[None, m, None].load()
+                            + scale[m] * tOrO_partial[None, m, None].load().to(Float32)
+                        )
 
            # ===============================
            # Step 7: Write final O to gmem
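The arithmetic in the hunks above is the standard log-sum-exp combine across splits: the global LSE is the log of the sum of per-split exponentials (computed against the running max for stability), each split's scale is exp(lse_i - lse), and the final output is the scale-weighted sum of the partial outputs. A scalar reference of that math for a single query row, in plain Python rather than the CuTe DSL (natural exp instead of the kernel's exp2/LOG2_E form, same result):

```python
import math

def combine_splits(lse_partial, o_partial):
    """Reference combine for one row: lse_partial[i] is split i's log-sum-exp,
    o_partial[i] is split i's partial output vector."""
    lse_max = max(lse_partial)
    lse_max = 0.0 if lse_max == -math.inf else lse_max  # all splits empty
    sum_exp = sum(math.exp(lse - lse_max) for lse in lse_partial if lse != -math.inf)
    lse = math.log(sum_exp) + lse_max if sum_exp > 0.0 else -math.inf
    inv_sum = 0.0 if sum_exp == 0.0 else 1.0 / sum_exp
    # scale_i = exp(lse_i - lse) = exp(lse_i - lse_max) / sum_exp
    scales = [
        math.exp(lse_i - lse_max) * inv_sum if lse_i != -math.inf else 0.0
        for lse_i in lse_partial
    ]
    out = [
        sum(s * o[j] for s, o in zip(scales, o_partial))
        for j in range(len(o_partial[0]))
    ]
    return lse, out

# Two splits with equal LSEs get weight 0.5 each, so the outputs are averaged.
lse, out = combine_splits([0.0, 0.0], [[1.0, 2.0], [3.0, 4.0]])
assert out == [2.0, 3.0] and abs(lse - math.log(2.0)) < 1e-12
```

The kernel performs the same reduction per row with warp shuffles, keeps the scales in shared memory, and applies them while accumulating the partial outputs in Step 6.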
@@ -605,7 +661,9 @@ def kernel(
            # Write final results
            for m in cutlass.range(num_rows, unroll_full=True):
                if tOhidx[m] >= 0:
-                    mO_cur_copy = cute.tiled_divide(mO_cur[tOmidx[m], None, tOhidx[m]], (elems_per_store,))
+                    mO_cur_copy = cute.tiled_divide(
+                        mO_cur[tOmidx[m], None, tOhidx[m]], (elems_per_store,)
+                    )
                    for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True):
                        k_idx = tOcO[0, 0, k][1] // elems_per_store
                        if const_expr(self.is_even_k) or tOpO[k]:
@@ -631,7 +689,9 @@ def load_O_partial(
                    o_gmem_ptr = cute.make_ptr(
                        tOsO_partial.element_type, tOrOptr[m], cute.AddressSpace.gmem, assumed_align=16
                    )
-                    mO_partial_cur = cute.make_tensor(o_gmem_ptr, cute.slice_(mO_cur_partial_layout, (0, None, None, 0)))
+                    mO_partial_cur = cute.make_tensor(
+                        o_gmem_ptr, cute.slice_(mO_cur_partial_layout, (0, None, None, 0))
+                    )
                    mO_partial_cur_copy = cute.tiled_divide(mO_partial_cur, (elems_per_load,))
                    for k in cutlass.range(cute.size(tOcO, mode=[2]), unroll_full=True):
                        k_idx = tOcO[0, 0, k][1] // elems_per_load
@@ -640,5 +700,5 @@ def load_O_partial(
                            gmem_tiled_copy_O_partial,
                            # mO_partial_cur_copy[None, k_idx, split],
                            utils.coord_offset_i64(mO_partial_cur_copy, split, dim=2)[None, k_idx],
-                            tOsO_partial_cur[None, m, k]
+                            tOsO_partial_cur[None, m, k],
                        )