@@ -63,6 +63,24 @@ def _get_device_capability():
6363 """Cached device capability check."""
6464 return torch .cuda .get_device_capability ()[0 ]
6565
66+
67+ def _validate_head_dims (head_dim : int , head_dim_v : int , compute_capability : int ) -> None :
68+ """Validate head dimension constraints based on compute capability."""
69+ is_deepseek_shape = head_dim == 192 and head_dim_v == 128
70+ is_standard_range = 8 <= head_dim <= 128 and 8 <= head_dim_v <= 128
71+
72+ if compute_capability == 9 :
73+ assert is_standard_range , (
74+ f"(head_dim, head_dim_v)=({ head_dim } , { head_dim_v } ) is not supported on SM90. "
75+ f"head_dim and head_dim_v must be between 8 and 128."
76+ )
77+ elif compute_capability in [10 , 11 ]:
78+ assert is_standard_range or is_deepseek_shape , (
79+ f"(head_dim, head_dim_v)=({ head_dim } , { head_dim_v } ) is not supported on SM100/SM110. "
80+ f"head_dim and head_dim_v must be between 8 and 128, or (192, 128) for DeepSeek."
81+ )
82+
83+
6684def maybe_contiguous (x ):
6785 return x .contiguous () if x is not None and x .stride (- 1 ) != 1 else x
6886
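A quick spot-check of the new validator's boundary behavior (the calls below are illustrative only; the shapes and capability values come straight from the function above):

```python
# Hypothetical spot-checks for _validate_head_dims.
_validate_head_dims(128, 128, compute_capability=9)   # standard range: OK on SM90
_validate_head_dims(192, 128, compute_capability=10)  # DeepSeek shape: OK on SM100
try:
    _validate_head_dims(192, 128, compute_capability=9)  # DeepSeek shape rejected on SM90
except AssertionError as e:
    print(e)  # (head_dim, head_dim_v)=(192, 128) is not supported on SM90. ...
# Any other capability value falls through silently here; callers are expected
# to have already asserted compute_capability in [9, 10, 11].
```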
@@ -211,11 +229,17 @@ def _flash_attn_fwd(
             learnable_sink,
         )
     ), "inputs must be on CUDA device"
+    compute_capability = (
+        _get_device_capability()
+        if _compute_capability is None
+        else _compute_capability
+    )
+    assert compute_capability in [9, 10, 11], "Unsupported compute capability. Supported: 9.x, 10.x, 11.x"
     assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv"
-    assert head_dim <= 256, "head_dim must be less than or equal to 256"
     alignment = 16 // q.element_size()
     assert head_dim % alignment == 0, f"head_dim must be divisible by {alignment}"
     assert head_dim_v % alignment == 0, f"head_dim_v must be divisible by {alignment}"
+    _validate_head_dims(head_dim, head_dim_v, compute_capability)
     if softmax_scale is None:
         softmax_scale = 1.0 / math.sqrt(head_dim)
     if softcap == 0.0:
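For context, `compute_capability` is now resolved before any shape checks so that `_validate_head_dims` can consume it. A minimal self-contained sketch of the resolution pattern, assuming the cache on `_get_device_capability` is `functools.lru_cache` (the decorator sits outside this hunk) and with a hypothetical `_resolve_capability` helper standing in for the inline expression:

```python
import functools

import torch


@functools.lru_cache(maxsize=None)  # assumption: the actual decorator is outside this hunk
def _get_device_capability():
    """Cached device capability check (major version only)."""
    return torch.cuda.get_device_capability()[0]  # e.g. (9, 0) on H100 -> 9


def _resolve_capability(_compute_capability=None):
    # Hypothetical helper mirroring the inline expression above:
    # an explicit override wins; otherwise query the (cached) current device.
    return _get_device_capability() if _compute_capability is None else _compute_capability
```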
@@ -247,14 +271,6 @@ def _flash_attn_fwd(
         _validate_tensor(lse, "lse", lse_shape, torch.float32, device)
 
     dtype = torch2cute_dtype_map[q.dtype]
-    compute_capability = (
-        _get_device_capability()
-        if _compute_capability is None
-        else _compute_capability
-    )
-
-    assert compute_capability in [9, 10, 11], "Unsupported compute capability. Supported: 9.x, 10.x, 11.x"
-
     use_block_sparsity = block_sparse_tensors is not None
 
     if mask_mod is None:
@@ -698,10 +714,10 @@ def _flash_attn_bwd(
         t is None or t.is_cuda for t in (q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k)
     ), "inputs must be on CUDA device"
     assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv"
-    assert head_dim <= 256, "head_dim must be less than or equal to 256"
     alignment = 16 // q.element_size()
     assert head_dim % alignment == 0, f"head_dim must be divisible by {alignment}"
     assert head_dim_v % alignment == 0, f"head_dim_v must be divisible by {alignment}"
+    _validate_head_dims(head_dim, head_dim_v, compute_capability)
     if softmax_scale is None:
         softmax_scale = 1.0 / math.sqrt(head_dim)
     qhead_per_kvhead = num_head // num_head_kv
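The surviving alignment asserts encode a 16-byte requirement in elements: `alignment = 16 // q.element_size()`. A worked example of what that means per dtype (the float8 line assumes a PyTorch build with float8 support):

```python
import torch

# 16 bytes expressed in elements of the query dtype:
#   fp16/bf16 (2 bytes/elem) -> head_dim must be a multiple of 8
#   float8    (1 byte/elem)  -> head_dim must be a multiple of 16
for dtype in (torch.float16, torch.bfloat16, torch.float8_e4m3fn):
    q = torch.empty(0, dtype=dtype)
    print(dtype, 16 // q.element_size())
# torch.float16 8
# torch.bfloat16 8
# torch.float8_e4m3fn 16
```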