@@ -68,6 +68,24 @@ def _get_device_capability():
 def _get_use_clc_scheduler_default():
     return os.environ.get("FA4_CLC", "0") == "1"
 
+
+def _validate_head_dims(head_dim: int, head_dim_v: int, compute_capability: int) -> None:
+    """Validate head dimension constraints based on compute capability."""
+    is_deepseek_shape = head_dim == 192 and head_dim_v == 128
+    is_standard_range = 8 <= head_dim <= 128 and 8 <= head_dim_v <= 128
+
+    if compute_capability == 9:
+        assert is_standard_range, (
+            f"(head_dim, head_dim_v)=({head_dim}, {head_dim_v}) is not supported on SM90. "
+            f"head_dim and head_dim_v must be between 8 and 128."
+        )
+    elif compute_capability in [10, 11]:
+        assert is_standard_range or is_deepseek_shape, (
+            f"(head_dim, head_dim_v)=({head_dim}, {head_dim_v}) is not supported on SM100/SM110. "
+            f"head_dim and head_dim_v must be between 8 and 128, or (192, 128) for DeepSeek."
+        )
+
+
 def maybe_contiguous(x):
     return x.contiguous() if x is not None and x.stride(-1) != 1 else x
 
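For reference, a minimal sketch of how the new `_validate_head_dims` helper behaves for a few representative shapes (illustrative standalone calls, assuming the helper defined in the hunk above is in scope; compute_capability 9 corresponds to SM90, 10/11 to SM100/SM110):

```python
# Illustrative calls only; assumes _validate_head_dims from the hunk above is importable.
_validate_head_dims(64, 64, 9)      # ok: both dims within [8, 128]
_validate_head_dims(128, 128, 10)   # ok: upper edge of the standard range
_validate_head_dims(192, 128, 11)   # ok: DeepSeek shape, allowed on SM100/SM110
_validate_head_dims(192, 128, 9)    # AssertionError: DeepSeek shape is rejected on SM90
_validate_head_dims(256, 256, 10)   # AssertionError: above 128 and not the DeepSeek shape
```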
@@ -218,11 +236,17 @@ def _flash_attn_fwd(
             learnable_sink,
         )
     ), "inputs must be on CUDA device"
+    compute_capability = (
+        _get_device_capability()
+        if _compute_capability is None
+        else _compute_capability
+    )
+    assert compute_capability in [9, 10, 11], "Unsupported compute capability. Supported: 9.x, 10.x, 11.x"
     assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv"
-    assert head_dim <= 256, "head_dim must be less than or equal to 256"
     alignment = 16 // q.element_size()
     assert head_dim % alignment == 0, f"head_dim must be divisible by {alignment}"
     assert head_dim_v % alignment == 0, f"head_dim_v must be divisible by {alignment}"
+    _validate_head_dims(head_dim, head_dim_v, compute_capability)
     if softmax_scale is None:
         softmax_scale = 1.0 / math.sqrt(head_dim)
     if softcap == 0.0:
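As a side note on the surrounding checks this hunk keeps: the divisor in the `head_dim % alignment` assertions comes from a 16-byte alignment requirement divided by the element size of `q`. A small illustrative sketch (the dtype chosen here is an example for the arithmetic, not a statement of which dtypes the kernel accepts):

```python
import torch

# alignment = 16 // q.element_size(), as in the checks above.
q = torch.empty(2, 8, 128, dtype=torch.bfloat16)  # 2-byte elements
alignment = 16 // q.element_size()                # 16 // 2 == 8
assert 128 % alignment == 0   # head_dim = 128 satisfies the divisibility check
assert 96 % alignment == 0    # so does 96
print(alignment)              # 8
```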
@@ -254,14 +278,6 @@ def _flash_attn_fwd(
         _validate_tensor(lse, "lse", lse_shape, torch.float32, device)
 
     dtype = torch2cute_dtype_map[q.dtype]
-    compute_capability = (
-        _get_device_capability()
-        if _compute_capability is None
-        else _compute_capability
-    )
-
-    assert compute_capability in [9, 10, 11], "Unsupported compute capability. Supported: 9.x, 10.x, 11.x"
-
     use_block_sparsity = block_sparse_tensors is not None
 
     if mask_mod is None:
@@ -712,10 +728,10 @@ def _flash_attn_bwd(
         t is None or t.is_cuda for t in (q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k)
     ), "inputs must be on CUDA device"
     assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv"
-    assert head_dim <= 256, "head_dim must be less than or equal to 256"
     alignment = 16 // q.element_size()
     assert head_dim % alignment == 0, f"head_dim must be divisible by {alignment}"
     assert head_dim_v % alignment == 0, f"head_dim_v must be divisible by {alignment}"
+    _validate_head_dims(head_dim, head_dim_v, compute_capability)
     if softmax_scale is None:
         softmax_scale = 1.0 / math.sqrt(head_dim)
     qhead_per_kvhead = num_head // num_head_kv
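For context on the GQA bookkeeping kept around the new call in the backward path, a toy illustration of the head-grouping arithmetic (numbers are made up):

```python
# Illustrative numbers only: grouped-query attention head bookkeeping.
num_head, num_head_kv = 32, 8
assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv"
qhead_per_kvhead = num_head // num_head_kv  # each KV head is shared by 4 query heads
print(qhead_per_kvhead)  # 4
```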