1717from dataclasses import dataclass
1818import functools
1919import math
20+ import os
2021from typing import List , Literal , Optional , Tuple , Union , overload
2122
2223import torch
@@ -233,10 +234,13 @@ def _normalize_dsv4_topk_lens(
233234 q_len_per_request : int ,
234235 sum_seq_q : int ,
235236 name : str ,
237+ device : torch .device ,
236238 cum_seq_lens_q : Optional [torch .Tensor ] = None ,
237239) -> torch .Tensor :
238240 if topk_lens .dtype != torch .int32 :
239241 raise ValueError (f"{ name } must have dtype torch.int32, got { topk_lens .dtype } " )
242+ if topk_lens .device != device :
243+ raise ValueError (f"{ name } must be on device { device } , got { topk_lens .device } " )
240244 if topk_lens .ndim != 1 :
241245 raise ValueError (f"Expected flattened { name } .ndim == 1, got { topk_lens .ndim } " )
242246 if topk_lens .size (0 ) != sum_seq_q :
@@ -249,12 +253,16 @@ def _normalize_dsv4_topk_lens(
249253 cum_seq_lens_q ,
250254 (batch_size + 1 ,),
251255 torch .int32 ,
252- cum_seq_lens_q . device ,
256+ device ,
253257 "cum_seq_lens_q" ,
254258 )
255259 return topk_lens
256260
257261
def _validate_dsv4_sync_checks() -> bool:
    """Return whether the opt-in DeepSeek V4 input validation is enabled.

    The decision is read from the ``FLASHINFER_VALIDATE_INPUTS`` environment
    variable on every call, so changing the variable at runtime takes effect
    on the next call.

    The value is compared case-insensitively after stripping surrounding
    whitespace. An unset variable, an empty string, ``0``, ``false``, ``no``
    and ``off`` all mean "disabled"; any other value means "enabled".

    Returns:
        ``True`` when the extra validation should run, ``False`` otherwise.
    """
    raw_flag = os.environ.get("FLASHINFER_VALIDATE_INPUTS", "0")
    # Normalize first so that spellings such as "False", "OFF" or " 0 " are
    # not mistaken for an opt-in.
    return raw_flag.strip().lower() not in ("", "0", "false", "no", "off")
264+
265+
258266def _check_dsv4_sparse_mla_inputs (
259267 query : torch .Tensor ,
260268 swa_kv_cache : torch .Tensor ,
@@ -277,13 +285,16 @@ def _check_dsv4_sparse_mla_inputs(
277285 Optional [torch .Tensor ],
278286]:
279287 is_varlen_q = cum_seq_lens_q is not None
288+ out_shape : Tuple [int , ...]
289+ sparse_indices_prefix_shape : Tuple [int , ...]
280290 if is_varlen_q :
281291 if query .ndim != 3 :
282292 raise ValueError (
283293 "Expected query.ndim == 3 when cum_seq_lens_q is provided, "
284294 f"got { query .ndim } "
285295 )
286- assert cum_seq_lens_q is not None
296+ if cum_seq_lens_q is None :
297+ raise ValueError ("cum_seq_lens_q is required for varlen query input" )
287298 if cum_seq_lens_q .dtype != torch .int32 :
288299 raise ValueError (
289300 f"cum_seq_lens_q must have dtype torch.int32, got { cum_seq_lens_q .dtype } "
@@ -297,6 +308,11 @@ def _check_dsv4_sparse_mla_inputs(
297308 raise ValueError (
298309 f"Expected cum_seq_lens_q.numel() >= 2, got { cum_seq_lens_q .numel ()} "
299310 )
311+ if cum_seq_lens_q .device != query .device :
312+ raise ValueError (
313+ f"cum_seq_lens_q must be on query device { query .device } , "
314+ f"got { cum_seq_lens_q .device } "
315+ )
300316 sum_seq_q , num_heads , head_dim = query .shape
301317 if max_q_len is None :
302318 max_q_len = int ((cum_seq_lens_q [1 :] - cum_seq_lens_q [:- 1 ]).max ().item ())
@@ -331,7 +347,11 @@ def _check_dsv4_sparse_mla_inputs(
331347 if num_heads not in (64 , 128 ):
332348 raise ValueError (f"Expected 64 or 128 query heads, got { num_heads } " )
333349
334- if sparse_indices is None or compressed_kv_cache is None or sparse_topk_lens is None :
350+ if (
351+ sparse_indices is None
352+ or compressed_kv_cache is None
353+ or sparse_topk_lens is None
354+ ):
335355 raise ValueError (
336356 "sparse_indices, compressed_kv_cache, and sparse_topk_lens are required"
337357 )
@@ -346,8 +366,7 @@ def _check_dsv4_sparse_mla_inputs(
346366 )
347367 if sparse_indices .ndim != 2 :
348368 raise ValueError (
349- "Expected flattened sparse_indices.ndim == 2, got "
350- f"{ sparse_indices .ndim } "
369+ f"Expected flattened sparse_indices.ndim == 2, got { sparse_indices .ndim } "
351370 )
352371 if sparse_indices .shape [:- 1 ] != sparse_indices_prefix_shape :
353372 raise ValueError (
@@ -396,8 +415,9 @@ def _check_dsv4_sparse_mla_inputs(
396415 q_len_per_request ,
397416 sum_seq_q ,
398417 "sparse_topk_lens" ,
418+ query .device ,
399419 cum_seq_lens_q ,
400- ). to ( query . device )
420+ )
401421 if normalized_sparse_lens .numel () > 0 :
402422 min_sparse_len = int (normalized_sparse_lens .min ().item ())
403423 max_sparse_len = int (normalized_sparse_lens .max ().item ())
@@ -419,9 +439,6 @@ def _check_dsv4_sparse_mla_inputs(
419439 sinks , (num_heads ,), torch .float32 , query .device , "sinks"
420440 )
421441
422- if cum_seq_lens_q is not None :
423- cum_seq_lens_q = cum_seq_lens_q .to (query .device )
424-
425442 return (
426443 swa_kv_cache ,
427444 compressed_kv_cache ,
@@ -515,10 +532,12 @@ def trtllm_batch_decode_sparse_mla_dsv4(
515532 if enable_pdl is None :
516533 enable_pdl = device_support_pdl (query .device )
517534 if isinstance (bmm1_scale , torch .Tensor ):
518- assert bmm1_scale .dtype == torch .float32
535+ if bmm1_scale .dtype != torch .float32 :
536+ raise TypeError ("bmm1_scale tensor must have dtype torch.float32" )
519537 bmm1_scale = bmm1_scale * log2e
520538 if isinstance (bmm2_scale , torch .Tensor ):
521- assert bmm2_scale .dtype == torch .float32
539+ if bmm2_scale .dtype != torch .float32 :
540+ raise TypeError ("bmm2_scale tensor must have dtype torch.float32" )
522541
523542 (
524543 swa_kv_cache ,
@@ -529,37 +548,30 @@ def trtllm_batch_decode_sparse_mla_dsv4(
529548 query_flat ,
530549 expected_out_shape ,
531550 cum_seq_lens_q ,
532- ) = (
533- _check_dsv4_sparse_mla_inputs (
534- query ,
535- swa_kv_cache ,
536- sparse_indices ,
537- compressed_kv_cache ,
538- sparse_topk_lens ,
539- out ,
540- sinks ,
541- kv_layout ,
542- cum_seq_lens_q ,
543- max_q_len ,
544- )
551+ ) = _check_dsv4_sparse_mla_inputs (
552+ query ,
553+ swa_kv_cache ,
554+ sparse_indices ,
555+ compressed_kv_cache ,
556+ sparse_topk_lens ,
557+ out ,
558+ sinks ,
559+ kv_layout ,
560+ cum_seq_lens_q ,
561+ max_q_len ,
545562 )
546563
547564 if out is None :
548565 out = torch .empty (expected_out_shape , dtype = torch .bfloat16 , device = query .device )
549566
550- if seq_lens is None :
551- raise ValueError (
552- "seq_lens is required for DeepSeek V4 sparse MLA because TRTLLM-GEN "
553- "uses it to mask the fixed SWA-128 tile"
554- )
555567 check_shape_dtype_device (
556568 seq_lens , (batch_size ,), torch .int32 , query .device , "seq_lens"
557569 )
558570 if cum_seq_lens_q is None :
559571 q_lens = seq_lens .new_full ((batch_size ,), q_len_per_request )
560572 else :
561573 q_lens = cum_seq_lens_q [1 :] - cum_seq_lens_q [:- 1 ]
562- if torch .any (seq_lens < q_lens ).item ():
574+ if _validate_dsv4_sync_checks () and torch .any (seq_lens < q_lens ).item ():
563575 raise ValueError (
564576 "seq_lens must be greater than or equal to the per-request query "
565577 "lengths so TRTLLM-GEN can derive the SWA-128 valid window"
@@ -1120,10 +1132,12 @@ def trtllm_batch_decode_with_kv_cache_mla(
11201132 "trtllm-gen" if get_compute_capability (query .device )[0 ] == 10 else "xqa"
11211133 )
11221134 if isinstance (bmm1_scale , torch .Tensor ):
1123- assert bmm1_scale .dtype == torch .float32
1135+ if bmm1_scale .dtype != torch .float32 :
1136+ raise TypeError ("bmm1_scale tensor must have dtype torch.float32" )
11241137 bmm1_scale = bmm1_scale * log2e
11251138 if isinstance (bmm2_scale , torch .Tensor ):
1126- assert bmm2_scale .dtype == torch .float32
1139+ if bmm2_scale .dtype != torch .float32 :
1140+ raise TypeError ("bmm2_scale tensor must have dtype torch.float32" )
11271141 if backend == "xqa" :
11281142 if not is_sm12x_supported (query .device ):
11291143 raise ValueError (
0 commit comments