feat: add get flashinfer-trace interface .fi_trace#2931
Conversation
|
Note Reviews pausedIt looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the Use the following commands to manage reviews:
Use the checkboxes below for quick actions:
📝 WalkthroughWalkthroughThis change adds a TraceTemplate-based tracing system, fi_trace generation and registration, attaches trace templates via an extended Changes
Sequence Diagram(s)sequenceDiagram
participant Client
participant API as flashinfer_api wrapper
participant Template as TraceTemplate / dispatcher
participant FiTrace as fi_trace builder
participant FS as Filesystem
Client->>API: call decorated function (possibly trace=callable)
API->>Template: resolve trace_template (dispatch or static)
Template-->>FiTrace: build fi_trace_fn (bind reference, axes, inputs/outputs)
API->>FiTrace: if tracing enabled -> invoke fi_trace_fn(**bound_args)
FiTrace->>FS: write <name>.json (if save_dir or env enabled)
FiTrace-->>API: return trace dict
API-->>Client: execute original function and return result
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes Possibly related PRs
Suggested reviewers
Poem
🚥 Pre-merge checks | ✅ 2 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Code Review
This pull request adds the @flashinfer_api decorator to multiple classes and functions across the library, including attention and decode wrappers as well as GEMM execution utilities, to enable API logging. The review feedback points out that applying this decorator to subclasses whose base classes are already decorated results in redundant log entries. Additionally, nested calls between decorated functions may lead to duplicate logging, suggesting that the logging logic should handle re-entrancy or that certain decorators should be removed to reduce overhead.
| @@ -209,6 +209,7 @@ class BatchAttentionWithAttentionSinkWrapper(BatchPrefillWithPagedKVCacheWrapper | |||
| a convenient interface for using attention sinks during prefill or decode attention. | |||
| """ | |||
|
|
|||
| @flashinfer_api | |||
There was a problem hiding this comment.
Adding @flashinfer_api to BatchAttentionWithAttentionSinkWrapper.__init__ will result in double logging during initialization. This class inherits from BatchPrefillWithPagedKVCacheWrapper, whose __init__ method is already decorated with @flashinfer_api. Since the decorator uses the class name of the instance (args[0]), both the subclass and base class decorators will log an entry for BatchAttentionWithAttentionSinkWrapper.__init__. This redundancy clutters the logs and adds unnecessary overhead. Consider removing the decorator from the subclass if the base class logging is sufficient for your tracing needs.
| @@ -1546,6 +1546,7 @@ class CUDAGraphBatchDecodeWithPagedKVCacheWrapper(BatchDecodeWithPagedKVCacheWra | |||
| :class:`BatchDecodeWithPagedKVCacheWrapper` | |||
| """ | |||
|
|
|||
| @flashinfer_api | |||
There was a problem hiding this comment.
Similar to the issue in BatchAttentionWithAttentionSinkWrapper, decorating CUDAGraphBatchDecodeWithPagedKVCacheWrapper.__init__ leads to redundant log entries because its base class BatchDecodeWithPagedKVCacheWrapper.__init__ is already decorated. Both will log as CUDAGraphBatchDecodeWithPagedKVCacheWrapper.__init__ due to how the decorator resolves the class name from the instance.
| @@ -116,6 +116,7 @@ def gemm_runner(): | |||
| ) | |||
|
|
|||
|
|
|||
| @flashinfer_api | |||
There was a problem hiding this comment.
Decorating trtllm_low_latency_gemm will cause double logging when it is called internally by other decorated APIs, such as mm_fp8 in flashinfer/gemm/gemm_base.py. While it is important to trace this function when called directly, the current logging implementation will produce redundant entries for nested calls. This should ideally be addressed in the logging decorator's logic to handle re-entrancy, but for now, be aware of the log duplication.
There was a problem hiding this comment.
Actionable comments posted: 2
Note
Due to the large number of review comments, Critical severity comments were prioritized as inline comments.
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
flashinfer/gdn_prefill.py (1)
86-100: 🛠️ Refactor suggestion | 🟠 MajorAdd backend capability gating on this SM90-only API.
chunk_gated_delta_ruledocuments an SM90 requirement but is not decorated with@backend_requirement. Please add the backend/capability gate alongside@flashinfer_api(...)so unsupported devices fail fast with a clear message.As per coding guidelines:
Use@backend_requirementdecorator on APIs that have compute capability requirements and provide is_compute_capability_supported(cc) and is_backend_supported() methods.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gdn_prefill.py` around lines 86 - 100, The function chunk_gated_delta_rule is SM90-only but lacks the backend capability guard; add the `@backend_requirement`(...) decorator alongside `@flashinfer_api`(trace=gdn_prefill_trace) to check is_backend_supported() and is_compute_capability_supported(cc) for SM90 and return a clear fail-fast message for unsupported devices. Use the decorator to declare the required compute capability (SM90) and backend, referencing chunk_gated_delta_rule so the check runs before execution and produces a helpful error if the device is not supported.flashinfer/trtllm_low_latency_gemm.py (1)
119-125: 🛠️ Refactor suggestion | 🟠 MajorAdd
@backend_requirementfor this Blackwell-only entrypoint.
trtllm_low_latency_gemmis documented as Blackwell-only, but the API is not gated with@backend_requirement. Please add the explicit capability/backend guard so callers get deterministic early validation.As per coding guidelines:
Use@backend_requirementdecorator on APIs that have compute capability requirements and provide is_compute_capability_supported(cc) and is_backend_supported() methods.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trtllm_low_latency_gemm.py` around lines 119 - 125, Add the `@backend_requirement` decorator to the trtllm_low_latency_gemm entrypoint to gate it to Blackwell-only execution: place `@backend_requirement`(...) immediately above the trtllm_low_latency_gemm definition and provide checks that call the module's support helpers (e.g., is_compute_capability_supported and is_backend_supported) or small wrapper functions that return True only for Blackwell compute capability/backend; ensure the decorator references the correct check functions so callers receive deterministic early validation for Blackwell-only usage of trtllm_low_latency_gemm.
🟠 Major comments (23)
flashinfer/trace/example/fi_trace_out/gemm_mxfp8_N4096_K4096.json-32-38 (1)
32-38:⚠️ Potential issue | 🟠 MajorGEMM reference uses an incompatible transpose with declared shapes.
With
A: [M, K]andB: [K, N](Line 32–38), Line 66 should computeA @ B, notA @ B.T. Current reference is dimensionally inconsistent.Suggested fix
- return torch.matmul(A_scaled, B_scaled.T).to(torch.bfloat16)\n" + return torch.matmul(A_scaled, B_scaled).to(torch.bfloat16)\n"Also applies to: 66-66
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/example/fi_trace_out/gemm_mxfp8_N4096_K4096.json` around lines 32 - 38, The GEMM reference is using an incompatible transpose for tensor B given the declared shapes "A": [M,K] and "B": [K,N]; update the computation that currently multiplies A by B.T to multiply A by B instead so the operation becomes A @ B (ensure the result shape is [M,N]), and verify any accompanying description/metadata (e.g., keys "A", "B" and dtype "float8_e4m3fn") and comments reflect no transpose on B.flashinfer/trace/example/fi_trace_out/top_k_sampling_v128256.json-38-43 (1)
38-43:⚠️ Potential issue | 🟠 MajorOutput dtype is inconsistent with the traced API contract.
Line 42 declares
samplesasint64, but this API path returnsint32by default (whenindicesis not provided). The reference in Line 46 also allocatesint64, so both schema and reference are misaligned.Suggested fix
- "dtype": "int64", + "dtype": "int32",- samples = torch.empty(batch_size, dtype=torch.int64, device=device)\n + samples = torch.empty(batch_size, dtype=torch.int32, device=device)\nAlso applies to: 46-46
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/example/fi_trace_out/top_k_sampling_v128256.json` around lines 38 - 43, The schema and reference for the "samples" field in top_k_sampling_v128256.json incorrectly use dtype "int64" while the API returns int32 by default; change the "samples" dtype from "int64" to "int32" in the JSON schema and update the corresponding reference allocation that currently creates int64 to allocate int32 instead (look for the "samples" field and any reference example/allocation near the "shape": ["batch_size"] and the later allocation on Line 46 to ensure both schema and example match int32).flashinfer/trace/example/fi_trace_out/fused_add_rmsnorm_h5120.json-49-58 (1)
49-58:⚠️ Potential issue | 🟠 MajorReference return signature does not match declared outputs.
Line 49–56 declares two outputs (
output,residual), but Line 58’s reference returns only one tensor. This makes the trace definition internally inconsistent for validators/consumers.Suggested fix
- return y.to(hidden_states.dtype)\n" + return y.to(hidden_states.dtype), x.to(hidden_states.dtype)\n"🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/example/fi_trace_out/fused_add_rmsnorm_h5120.json` around lines 49 - 58, The reference function _fused_add_rmsnorm_reference currently returns only the normalized output tensor but the trace declares two outputs ("output" and "residual"), so update the reference to return both values to match the schema: compute y as now and also produce the updated residual (residual + hidden_states in float32, cast back to residual.dtype) and return (y, updated_residual) (or alternatively change the trace outputs to a single "output" if the residual should not be returned); ensure names/ordering match the declared outputs.flashinfer/trace/example/fi_trace_out/gemm_bf16_N256_K7168.json-30-36 (1)
30-36:⚠️ Potential issue | 🟠 MajorBF16 GEMM reference is inconsistent with input shape declaration.
Given
Bshape[K, N](Line 30–36), Line 48 should not transposeBfor matmul. The current expression conflicts with the stated tensor contract.Suggested fix
- "reference": "def _mm_reference(A, B):\n return torch.matmul(A, B.T)\n" + "reference": "def _mm_reference(A, B):\n return torch.matmul(A, B)\n"Also applies to: 48-48
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/example/fi_trace_out/gemm_bf16_N256_K7168.json` around lines 30 - 36, The JSON metadata declares tensor "B" with shape ["K","N"] (physical column-major [K, N]) but the matmul expression erroneously transposes B; update the matmul expression that currently uses B.T (or otherwise transposes "B") so it uses "B" directly to match the declared [K,N] contract, and ensure any accompanying description/comment is adjusted to reflect no transpose is applied.flashinfer/trace/example/fi_trace_out/gemm_fp4_N2048_K7168_block_size16.json-29-43 (1)
29-43:⚠️ Potential issue | 🟠 MajorFP4 GEMM schema and reference are shape-inconsistent.
Line 29–43 declares unpacked shapes (
[M, K],[K, N]) while Line 71 treatsA/Bas packed bytes and reconstructs logical dims by multiplying by 2. On top of that, the finalB_scaled.Tintroduces another dimension mismatch.Please make schema and reference consistent in one direction:
- keep packed semantics and declare packed shapes, or
- keep unpacked shapes and remove nibble-unpack logic.
Also, the final GEMM should not transposeB_scaledunder the current shape declarations.Also applies to: 71-71
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/example/fi_trace_out/gemm_fp4_N2048_K7168_block_size16.json` around lines 29 - 43, The schema declares A and B as packed uint8 tensors but later code treats them as unpacked nibbles and reconstructs logical dims (the nibble-unpack logic) and then does B_scaled.T which creates a mismatch; pick one consistent approach and fix both schema and code: either (A) declare A/B shapes as packed (bytes) and keep the nibble-unpack/reconstruction code that expands to logical shapes but remove the final transpose of B_scaled (or transpose before unpacking) so GEMM uses matching [M,K] and [K,N], or (B) declare A/B as unpacked shapes ([M,K], [K,N]) and remove the nibble-unpack/reconstruction entirely; update the "description" fields (fp4 e2m1fn_x2 packed as uint8) and references to B_scaled and its transpose to match the chosen convention (adjust usage of B_scaled.T accordingly).flashinfer/gdn_decode.py-349-350 (1)
349-350: 🛠️ Refactor suggestion | 🟠 MajorAdd
@backend_requirementon these SM-constrained public APIs.At Line 349 and Line 490, these APIs are decorated for tracing but still lack explicit backend capability guards at the API boundary.
As per coding guidelines: Use
@backend_requirementdecorator on APIs that have compute capability requirements and provideis_compute_capability_supported(cc)andis_backend_supported()methods.Also applies to: 490-491
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gdn_decode.py` around lines 349 - 350, The public API function gated_delta_rule_decode (decorated with `@flashinfer_api`(trace=gated_delta_rule_decode_trace)) is SM-constrained and must be guarded by the backend capability decorator; add `@backend_requirement` above its definition and implement the decorator to call the module's is_compute_capability_supported(cc) and is_backend_supported() helpers. Do the same for the other SM-constrained public API in this file that is currently decorated with `@flashinfer_api` around the later section (the second gated rule decode API at the other occurrence) so both API entrypoints check compute capability and backend support before proceeding.flashinfer/gdn_decode.py-36-53 (1)
36-53:⚠️ Potential issue | 🟠 MajorDecouple trace-template import from
flashinfer_apifallback.If trace template import fails but
flashinfer_apiis available, the current combinedtryblock still falls back to a no-op decorator, silently disabling API logging/tracing behavior.♻️ Suggested fix
-try: - from .api_logging import flashinfer_api - from .trace.templates.gdn import ( - gated_delta_rule_decode_trace, - gdn_mtp_trace, - ) - _FLASHINFER_AVAILABLE = True -except ImportError: - _FLASHINFER_AVAILABLE = False - gated_delta_rule_decode_trace = None # type: ignore[assignment] - gdn_mtp_trace = None # type: ignore[assignment] - - # Fallback decorator for standalone usage (accepts trace= kwarg) - def flashinfer_api(func=None, *, trace=None): # type: ignore[misc] - if func is None: - return lambda f: f - return func +try: + from .api_logging import flashinfer_api + _FLASHINFER_AVAILABLE = True +except ImportError: + _FLASHINFER_AVAILABLE = False + def flashinfer_api(func=None, *, trace=None): # type: ignore[misc] + if func is None: + return lambda f: f + return func + +try: + from .trace.templates.gdn import ( + gated_delta_rule_decode_trace, + gdn_mtp_trace, + ) +except ImportError: + gated_delta_rule_decode_trace = None # type: ignore[assignment] + gdn_mtp_trace = None # type: ignore[assignment]🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gdn_decode.py` around lines 36 - 53, The combined try/except hides a missing trace-template import by replacing flashinfer_api with a no-op; split imports so flashinfer_api is imported in its own try/except and sets _FLASHINFER_AVAILABLE, then separately attempt to import gated_delta_rule_decode_trace and gdn_mtp_trace and only set them to None on failure—define the fallback flashinfer_api decorator only when the flashinfer_api import itself fails so trace import failures do not disable API logging/tracing.flashinfer/trace/templates/sampling.py-24-41 (1)
24-41:⚠️ Potential issue | 🟠 MajorThe sampling references are not reproducible as written.
These references call
torch.multinomial, but the template schema does not carry any RNG input, seed, or pre-generated random variate. The same trace payload can therefore emit differentsamplesacross runs, which makes the generated definitions unstable as reference artifacts. Please encode the randomness in the trace inputs or make the reference deterministic.Also applies to: 79-103, 141-173
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/templates/sampling.py` around lines 24 - 41, The _top_k_sampling_reference function (and the other sampling reference blocks at lines 79-103 and 141-173) currently calls torch.multinomial which uses nondeterministic RNG not captured by the trace; change the reference signatures to accept explicit randomness (e.g., a per-sample uniform variates tensor or an RNG seed/tensor) and use those inputs to deterministically draw samples: after filtering/renormalizing the probabilities in _top_k_sampling_reference, compute the cumulative distribution and select the token whose cdf first exceeds the provided uniform variate for that batch (instead of torch.multinomial), and apply the same pattern to the other sampling reference functions so the randomness is fully encoded in trace inputs.flashinfer/api_logging.py-1497-1503 (1)
1497-1503:⚠️ Potential issue | 🟠 MajorDon’t silently disable
.fi_traceon attachment errors.These
except Exception: passblocks turn template/build failures into invisible feature loss: a broken trace template can quietly remove.fi_trace, andFLASHINFER_TRACE_DUMP=1can fail to write anything without surfacing why. Please preserve the failure in a stubfi_traceor emit a warning instead of dropping it on the floor.Also applies to: 1516-1517
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/api_logging.py` around lines 1497 - 1503, The current try/except around calling fi_trace_fn (guarded by _is_trace_dump_enabled and using _sig.bind(...)) swallows all exceptions and silently disables .fi_trace; change this to catch Exception as e, log or warn about the attachment/templating error (include the exception), and install a stub fi_trace function that preserves the attribute but emits the warning (or raises) when invoked so the feature failure is visible; apply the same change to the analogous block that appears for the second attachment (the other try/except using _sig.bind and fi_trace_fn).flashinfer/trace/example/fi_trace_out/gqa_paged_decode_h32_kv8_d128_ps64.json-112-112 (1)
112-112:⚠️ Potential issue | 🟠 MajorThe decode reference mixes page IDs with token indices.
Line 112 declares
kv_indicesas page IDs, but the reference flattens the cache to[num_pages * page_size, ...]and indexes that flattened tensor directly with those IDs. Withpage_size=64, each selected page contributes only its first token tok_b/v_b, so the logits and outputs are wrong. Please fix the upstream template to index pages first, then flatten to tokens, and regenerate this artifact.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/example/fi_trace_out/gqa_paged_decode_h32_kv8_d128_ps64.json` at line 112, The reference implementation _gqa_paged_decode_reference incorrectly treats kv_indices as token indices against k_flat/v_flat; instead treat kv_indices as page IDs: use kv_indptr to select pages, index k_cache/v_cache by page IDs to get per-page tensors, then reshape/flatten each selected page into tokens (or index within page using page_size) before computing k_b and v_b; update the logic that computes k_flat/v_flat (or remove flattening) so you first select pages via kv_indices[page_start:page_end] -> page_ids, then gather k_cache[page_ids] and v_cache[page_ids] and reshape to token dimension prior to matmuls, then regenerate the artifact.flashinfer/trace/example/fi_trace_out/gqa_paged_prefill_h32_kv8_d128_ps16.json-119-119 (1)
119-119:⚠️ Potential issue | 🟠 MajorSchema/reference mismatch for
kv_indices.Line 119 says
kv_indicesare page IDs, but the reference indexesk_cache.reshape(-1, ...)/v_cache.reshape(-1, ...)with those IDs and setsnum_kv_tokensfrom the number of pages. Withpage_size=16, that dropspage_size - 1tokens from every selected page, so the causal window and outputs are wrong for paged inputs. Please fix the source template to gather pages first, then flatten their token dimension, and regenerate this artifact.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/example/fi_trace_out/gqa_paged_prefill_h32_kv8_d128_ps16.json` at line 119, The reference treats kv_indices as page IDs but indexes k_cache/v_cache after reshaping into pages (k_flat/v_flat), which incorrectly drops the per-page token dimension; in _gqa_paged_prefill_reference gather the full pages first (use page_ids = kv_indices[kv_start:kv_end] to index k_cache and v_cache by page dimension), then flatten the page-token axis so k_b and v_b include all page_size tokens (adjust k_flat/v_flat usage or index k_cache/v_cache directly), set num_kv_tokens = page_ids.shape[0] * page_size, and update loops that compute max_kv, logits, attn, and output to iterate over the flattened token sequence accordingly (refer to symbols kv_indices, k_cache, v_cache, k_flat, v_flat, page_size, num_kv_tokens).flashinfer/trace/example/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps64.json-123-123 (1)
123-123:⚠️ Potential issue | 🟠 MajorThis
ps64reference still assumes one token per page.Line 123 indexes
kv_indicesintockv_cache/kpe_cachewithout flattening the selectedpage_sizedimension first. In this filepage_sizeis 64, soKc/Kpremain page tensors instead of[L, D]token matrices, and the subsequent decode matmuls no longer implement the declared operator. Please fix the upstream template for multi-token pages and regenerate this example.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/example/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps64.json` at line 123, The reference _mla_paged_decode_reference assumes one-token pages but for page_size=64 you must flatten the per-page token dimension after selecting pages: when building Kc/Kp from Kc_all/Kp_all (currently from ckv_cache.squeeze(1)/kpe_cache.squeeze(1)) do Kc_all[tok_idx] and Kp_all[tok_idx] then reshape/flatten the result to [L, head_dim_ckv] and [L, head_dim_kpe] respectively (e.g., .reshape(-1, head_dim_ckv) / .reshape(-1, head_dim_kpe]) before computing logits and softmax) so the decode matmuls use token-level matrices; update _mla_paged_decode_reference to flatten selected pages accordingly and regenerate the example.flashinfer/fi_trace.py-273-280 (1)
273-280:⚠️ Potential issue | 🟠 MajorThe public helper never actually falls back to the legacy registry.
This module keeps
_REGISTRY,register_fi_trace(), andbuild_fi_trace_fn()for backwards compatibility, butfi_trace()only checksactual_func.fi_trace. Any legacy caller that registered a spec by qualname will still hit theNo fi_trace spec is registeredpath.Possible fix
actual_func = getattr(func_or_method, "__func__", func_or_method) trace_fn = getattr(actual_func, "fi_trace", None) if trace_fn is None: - qualname = getattr(actual_func, "__qualname__", repr(actual_func)) - raise ValueError( - f"No fi_trace spec is registered for '{qualname}'. " - "Only `@flashinfer_api`(trace=...)-decorated functions support fi_trace." - ) + qualname = getattr(actual_func, "__qualname__", None) + spec = _REGISTRY.get(qualname) if qualname is not None else None + if spec is not None: + trace_fn = build_fi_trace_fn(spec) + else: + qualname = qualname or repr(actual_func) + raise ValueError( + f"No fi_trace spec is registered for '{qualname}'. " + "Only `@flashinfer_api`(trace=...)-decorated functions support fi_trace." + ) return trace_fn(save_dir=save_dir, **kwargs)🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/fi_trace.py` around lines 273 - 280, The public helper fi_trace currently only checks the bound attribute actual_func.fi_trace and never looks up the legacy registry, so entries registered via register_fi_trace/_REGISTRY or built via build_fi_trace_fn are ignored; update the code after obtaining qualname to fall back to the legacy registry by looking up _REGISTRY[qualname] or calling build_fi_trace_fn(qualname) (using the same qualname computed from actual_func.__qualname__ or repr(actual_func)) and use that trace_fn when present before raising the ValueError so legacy-registered specs are honored.tests/test_fi_trace.py-357-362 (1)
357-362:⚠️ Potential issue | 🟠 MajorThese use-case tests allocate model-sized tensors even though
fi_traceonly inspects metadata.The
num_pages=8192decode case materializes about 512 MiB of KV cache, and the MLA example adds another ~288 MiB, just to read.shapeand.dtype. That is likely to slow or OOM CI without adding coverage. Please shrink these fixtures or move the model-scale examples out of the unit suite.Also applies to: 418-424
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/test_fi_trace.py` around lines 357 - 362, The test allocates model-sized tensors (num_pages, page_size, q, k_cache, v_cache) even though fi_trace only reads metadata; reduce memory by shrinking num_pages and page_size to small values (e.g., single- or double-digit sizes) or replace large concrete tensors with lightweight stand-ins (small shaped tensors or meta-device tensors) in the test vectors q, k_cache, v_cache used by test_fi_trace functions; apply the same change to the other occurrence around lines 418-424 to avoid CI OOMs while preserving the shape/dtype intent for fi_trace.flashinfer/trace/example/fi_trace_out/gdn_mtp_qk4_v8_d128.json-170-170 (1)
170-170:⚠️ Potential issue | 🟠 Major
final_statenever reflects the updates computed in the loop.
state_HVKis mutated for each token, but the function returnsinitial_state.clone()without writing any updated state back into it. The example therefore emits the original state pool while the schema saysfinal_stateis the updated recurrent state.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/example/fi_trace_out/gdn_mtp_qk4_v8_d128.json` at line 170, The function returns initial_state.clone() even though state_HVK is updated per token; fix by writing the updated per-pool state_HVK back into the state pool using the same indexing flow you used to read it: after finishing the token loop for a batch item (or whenever you update state_HVK), assign final_state[state_idx] = state_HVK.transpose(-1, -2) (or update initial_state in-place) so that final_state (returned) contains the mutated states; ensure you use initial_state_indices/state_idx to map back and preserve dtype/device the same way intermediate_states_buffer is handled.flashinfer/trace/templates/gemm.py-111-215 (1)
111-215:⚠️ Potential issue | 🟠 MajorThe non-BF16 templates emit shapes with undefined or mismatched axes.
mm_fp8_traceusesK_div_block_size/block_size,mm_mxfp8_traceusesK_div_32, andmm_fp4_traceusesK_div_block_size/N_div_block_size, but none of those derived dimensions are declared inaxesor tied back toK/Nwith constraints.mm_fp4_tracealso labels packed uint8 operands as logical[M, K]and[K, N], so the discovered axis values will be off on the packed dimension. The resulting JSON is not self-contained for these ops.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/templates/gemm.py` around lines 111 - 215, The templates mm_fp8_trace, mm_mxfp8_trace, and mm_fp4_trace declare derived dimensions (K_div_block_size, block_size, K_div_32, N_div_block_size) in their Tensor shapes but never define them in axes or relate them back to K/N; also mm_fp4 inputs are described as logical [M,K]/[K,N] while the stored packed uint8 layout changes the packed dimension. Fix by adding explicit axes entries for each derived dimension in the axes dict (e.g., "block_size", "K_div_block_size", "K_div_32", "N_div_block_size" or a packed axis like "K_packed") and document the arithmetic relationship (K_div_block_size = K // block_size, K_div_32 = K // 32, N_div_block_size = N // block_size or K_packed = packed_length_of(K) for fp4); then update the corresponding Tensor shapes in mm_fp8_trace, mm_mxfp8_trace, and mm_fp4_trace to reference those axes (and adjust A/B shapes for fp4 to use the packed axis instead of logical K) so the JSON is self-contained and axis relationships are explicit.flashinfer/trace/example/fi_trace_out/gqa_paged_decode_h32_kv8_d128_ps16.json-112-112 (1)
112-112:⚠️ Potential issue | 🟠 MajorThe paged decode reference is indexing page IDs as if they were token IDs.
kv_indicesis documented here as a page-ID array, butk_cache.reshape(-1, ...)/v_cache.reshape(-1, ...)followed by...[token_ids]only selects one row per page and drops the remainingpage_size - 1tokens. If this reference is used for verification, any multi-token page will compare against the wrong attention result. Based on learnings, when native paged KV layout is used, page indices are not supposed to be flattened into token indices.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/example/fi_trace_out/gqa_paged_decode_h32_kv8_d128_ps16.json` at line 112, The reference implementation in _gqa_paged_decode_reference incorrectly treats kv_indices as token IDs by indexing into k_flat/v_flat (created by k_cache.reshape(-1,...)), which drops tokens within multi-token pages; instead, treat kv_indices as page IDs: extract pages via k_cache[pages] and v_cache[pages] (where pages = kv_indices[page_start:page_end].to(torch.long)), then combine the page_size dimension (e.g., .reshape(-1, num_kv_heads, head_dim)) so all tokens in each page are included before computing logits/attention; update k_b, v_b, and any downstream uses to reflect this page->token expansion while keeping q_b and gqa_ratio logic unchanged.flashinfer/trace/templates/gemm.py-22-85 (1)
22-85:⚠️ Potential issue | 🟠 MajorFix B tensor handling in quantized GEMM references and resolve undefined symbolic dimensions.
The quantized GEMM references have multiple critical issues:
Matrix multiply semantics: All references multiply with
B.Tdespite describing B with physical shape[K, N]. This is mathematically incorrect:[M, K] @ [K, N].T = [M, K] @ [N, K]has mismatched inner dimensions. The references should either remove the transpose or update schemas to describe B as[N, K].FP8 block layout:
_mm_fp8_reference()reshapes[K//block_size, N, block_size]directly to[K, N]without permuting first. TRT-LLM block layout requires permutation before reshape to reconstruct the original matrix correctly (i.e.,.reshape(K_div_bs, block_size, N).permute(1, 0, 2).reshape(K, N)).FP4 decoding:
_unpack_fp4()extracts raw nibble values (0–15) via bitwise masking and casts to float32 without decoding the e2m1fn format. The reference cannot serve as a correctness oracle without proper FP4 value lookup or conversion.Undefined symbolic axes: The FP8, MXFP8, and FP4 templates reference symbolic dimensions (
K_div_block_size,K_div_32,N_div_block_size,block_size) not declared in theiraxesdictionaries, preventing proper schema validation.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/templates/gemm.py` around lines 22 - 85, The GEMM refs incorrectly transpose B and misuse block layouts: update _mm_reference, _mm_fp8_reference, _mm_mxfp8_reference, and _mm_fp4_reference so matmul uses A @ B (not A @ B.T) if B is intended as [K, N], or alternatively document/reshape B to [N, K] consistently; in _mm_fp8_reference before reshaping apply the TRT-LLM permutation (reshape to [K_div_bs, block_size, N] then permute(1,0,2) then reshape to [K, N]) instead of direct reshape; replace _unpack_fp4 with proper e2m1fn decoding (use a lookup/decode table to map 4-bit nibble values to float32) rather than raw nibble casts so FP4 semantics are correct; and add the missing symbolic axis declarations for K_div_bs/K_div_32, N_div_block_size and block_size in the template axes metadata so schema validation can resolve those symbols (referencing the functions _mm_fp8_reference, _mm_mxfp8_reference, _mm_fp4_reference and helper _unpack_fp4 to locate the changes).flashinfer/trace/templates/attention.py-113-116 (1)
113-116:⚠️ Potential issue | 🟠 MajorAdd the grouped-query head constraints to the GQA templates.
The GQA references rely on
num_qo_heads // num_kv_headsbeing a valid grouping factor, but the schema currently accepts shapes wherenum_qo_heads < num_kv_headsor the ratio is non-integral. In those cases the reference either divides by zero or walkskv_hpast the last KV head. Please addnum_qo_heads >= num_kv_headsandnum_qo_heads % num_kv_heads == 0here, and mirror the same invariant ingqa_paged_prefill_traceandgqa_ragged_prefill_trace.Possible fix
constraints=[ "len_indptr == batch_size + 1", "num_kv_indices == kv_indptr[-1].item()", + "num_qo_heads >= num_kv_heads", + "num_qo_heads % num_kv_heads == 0", ],🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/templates/attention.py` around lines 113 - 116, The schema for the GQA templates currently allows invalid head groupings; update the constraints list in the attention template (the constraints array where "len_indptr == batch_size + 1" and "num_kv_indices == kv_indptr[-1].item()" are defined) to also require "num_qo_heads >= num_kv_heads" and "num_qo_heads % num_kv_heads == 0", and apply the same two invariants to the corresponding constraint lists in gqa_paged_prefill_trace and gqa_ragged_prefill_trace so the grouped-query computation (which uses num_qo_heads // num_kv_heads and kv head indexing) never divides by zero or indexes past the last KV head.flashinfer/trace/templates/gdn.py-164-168 (1)
164-168:⚠️ Potential issue | 🟠 MajorEnforce
seq_len == 1in the decode template.
_gdn_decode_referencedepends onsqueeze(1)removing the time axis. Ifseq_lenis anything else, it starts repeating along the sequence dimension instead of the head dimension and the reference becomes invalid. The description already says decode is single-token, so make that a hard constraint.Possible fix
constraints=[ + "seq_len == 1", "num_v_heads >= num_q_heads", "num_v_heads % num_q_heads == 0", "num_k_heads == num_q_heads", ],🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/templates/gdn.py` around lines 164 - 168, Add a hard constraint enforcing single-token decoding by adding "seq_len == 1" to the decode template's constraints list so the template and its consumer _gdn_decode_reference (which relies on squeeze(1) removing the time axis) never run with seq_len > 1; update the constraints array (the one containing "num_v_heads >= num_q_heads", "num_v_heads % num_q_heads == 0", "num_k_heads == num_q_heads") to include "seq_len == 1".flashinfer/trace/templates/gdn.py-327-330 (1)
327-330:⚠️ Potential issue | 🟠 MajorPrefill is missing the GVA head-shape invariants used by the reference.
The prefill reference expands Q/K with
num_v_heads // num_q_headsandnum_v_heads // num_k_heads, so it needs the same head relationship guarantees as decode/MTP. Right now the schema accepts shapes that can truncate the repeat factor or produce an output whose head axis no longer matches the declarednum_v_heads.Possible fix
constraints=[ "len_cu_seqlens == num_seqs + 1", "total_seq_len == cu_seqlens[-1].item()", + "num_v_heads >= num_q_heads", + "num_v_heads % num_q_heads == 0", + "num_k_heads == num_q_heads", ],🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/templates/gdn.py` around lines 327 - 330, The schema is missing invariants that guarantee the GVA head expansion used in prefill; add constraints to the same constraints list to require divisibility so expansions don't truncate heads: include "num_v_heads % num_q_heads == 0" and "num_v_heads % num_k_heads == 0" (referring to the symbols num_v_heads, num_q_heads, num_k_heads) so the prefill expansion of Q/K by num_v_heads // num_q_heads and num_v_heads // num_k_heads preserves the declared num_v_heads head axis.flashinfer/trace/templates/attention.py-42-43 (1)
42-43:⚠️ Potential issue | 🟠 MajorDon't treat page ids as flattened token ids.
After
reshape(-1, ...), indexing with rawkv_indicesonly fetches one flattened slot per page and ignores the otherpage_size - 1entries. That makes the paged GQA reference wrong for anypage_size > 1, and the same pattern repeats in_gqa_paged_prefill_reference,_mla_paged_decode_reference, and_mla_paged_prefill_reference. Either materialize full pages and trim the last page with explicit length metadata, or constrain these paged templates topage_size == 1.Possible direction
- k_flat = k_cache.reshape(-1, num_kv_heads, head_dim).to(torch.float32) - v_flat = v_cache.reshape(-1, num_kv_heads, head_dim).to(torch.float32) + # Materialize selected pages first; then flatten tokens within those pages. + # The last page still needs an explicit length input to trim padding correctly. ... - token_ids = kv_indices[page_start:page_end].to(torch.long) - k_b = k_flat[token_ids] # [T, num_kv_heads, head_dim] - v_b = v_flat[token_ids] + page_ids = kv_indices[page_start:page_end].to(torch.long) + k_b = k_cache[page_ids].reshape(-1, num_kv_heads, head_dim).to(torch.float32) + v_b = v_cache[page_ids].reshape(-1, num_kv_heads, head_dim).to(torch.float32)Also applies to: 51-53
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/templates/attention.py` around lines 42 - 43, The code flattens page slots with k_cache.reshape(-1, num_kv_heads, head_dim) (k_flat/v_flat) and then indexes with kv_indices, which treats page ids as single flattened token ids and therefore drops the other page_size-1 entries; update the paged templates (_gqa_paged_prefill_reference, _mla_paged_decode_reference, _mla_paged_prefill_reference) to either (A) materialize full page slices before flattening (i.e., expand/reshape to include page_size dimension, gather full pages using kv_indices, then trim the final partial page using explicit length metadata) or (B) enforce/validate page_size == 1 at the start of these functions and raise an error if otherwise; ensure all uses of k_flat, v_flat and kv_indices are adjusted accordingly so each page returns all its key/value slots rather than a single flattened slot.flashinfer/trace/templates/gdn.py-377-410 (1)
377-410:⚠️ Potential issue | 🟠 MajorWrite the updated slot back before returning
final_state.
state_HVKis updated for every token, but nothing persists it into the returned pool. As written,final_state = initial_state.clone()returns the original state unchanged, which contradicts the template contract and breaks stateful verification.Possible fix
output = torch.zeros( (B, T, num_v_heads, head_size), dtype=torch.bfloat16, device=device ) + final_state = initial_state.clone().float() cache_intermediate = intermediate_states_buffer is not None for b_idx in range(B): state_idx = int(initial_state_indices[b_idx].item()) - state_HVK = initial_state[state_idx].clone().float().transpose(-1, -2) # [H,V,K] -> [H,K,V] + state_HVK = final_state[state_idx].transpose(-1, -2).clone() # [H,V,K] -> [H,K,V] for t in range(T): ... if cache_intermediate: intermediate_states_buffer[state_idx, t] = state_HVK.transpose(-1, -2) # [H,K,V] -> [H,V,K] + + final_state[state_idx] = state_HVK.transpose(-1, -2) - final_state = initial_state.clone() return output, final_state🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/templates/gdn.py` around lines 377 - 410, Summary: final_state is returned unchanged because updated state_HVK is never written back to the pool. Fix: clone initial_state into final_state before the outer loop (or otherwise create a mutable final_state), and after processing each batch element (after the inner t loop where state_HVK holds the final per-head slot), write the updated slot back via final_state[state_idx] = state_HVK.transpose(-1, -2) so the returned final_state reflects the updates; reference symbols: initial_state, final_state, state_HVK, state_idx, initial_state_indices.
🟡 Minor comments (4)
flashinfer/trace/template.py-371-376 (1)
371-376:⚠️ Potential issue | 🟡 MinorSilent exception swallowing may hide bugs during axis extraction.
Catching bare
Exceptionand passing silently can mask unexpected errors (e.g.,TypeError,AttributeError) that indicate template misconfiguration or API misuse. Consider logging at debug level or being more specific about expected exceptions.🔧 Proposed fix to add debug logging
+import logging + +_logger = logging.getLogger(__name__) + # In fi_trace function: for axis_name, extractor in axis_extractors.items(): try: val = extractor(kwargs) if val is not None: axis_values[axis_name] = val - except Exception: - pass + except Exception as e: + _logger.debug("Failed to extract axis %r: %s", axis_name, e)🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/template.py` around lines 371 - 376, The try/except around calling extractor(kwargs) silently swallows all exceptions (using bare except), which can hide bugs; update the except to capture the exception as e and emit a debug-level log including axis_name, extractor, and the exception/traceback (or narrow the except to expected errors like KeyError/IndexError/ValueError if applicable) before continuing, ensuring axis_values and axis_name remain unchanged on failure; reuse the module's logger instance (or create one if none exists) so the failure context is recorded for debugging.flashinfer/trace/templates/norm.py-56-88 (1)
56-88:⚠️ Potential issue | 🟡 MinorReference implementation return value mismatch with template outputs.
The
_fused_add_rmsnorm_referencefunction returns onlyy(single tensor), butfused_add_rmsnorm_tracedefines two outputs:outputandresidual. The reference should return both to match the template schema.🐛 Proposed fix to return both outputs
`@torch.no_grad`() def _fused_add_rmsnorm_reference(hidden_states, residual, weight): """Fused Add + RMSNorm. Epsilon is fixed at 1e-6.""" EPS = 1e-6 - x = hidden_states.to(torch.float32) + residual.to(torch.float32) + residual_updated = hidden_states + residual + x = residual_updated.to(torch.float32) inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + EPS) y = (x * inv_rms) * weight.to(torch.float32) - return y.to(hidden_states.dtype) + return y.to(hidden_states.dtype), residual_updated🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/templates/norm.py` around lines 56 - 88, The reference function _fused_add_rmsnorm_reference currently returns only y but the TraceTemplate fused_add_rmsnorm_trace declares two outputs ("output" and updated "residual"); update _fused_add_rmsnorm_reference to return a tuple (output, residual_out) where residual_out reflects the in-place semantics described (residual += hidden_states) — e.g., compute residual_out = residual.to(torch.float32) + hidden_states.to(torch.float32) (or perform an in-place add if appropriate), cast both y and residual_out back to the original hidden_states dtype, and return them in the same order as the template outputs.flashinfer/trace/example/__main__.py-1-1 (1)
1-1:⚠️ Potential issue | 🟡 MinorReplace wildcard import with explicit module reference.
Line 1 uses
from .example import *, which triggers Ruff F403 and obscures what is actually imported. Sinceexample.pydefines no__all__and is structured as a side-effect module (not an export container), usefrom . import exampleinstead to make the intent explicit.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/example/__main__.py` at line 1, Replace the wildcard import in __main__.py: remove "from .example import *" and import the module explicitly (use "from . import example") so the code references names via the example module; update any direct references that relied on the star-import to be prefixed with "example." to keep intent explicit and satisfy Ruff F403.flashinfer/trace/example/example.py-129-130 (1)
129-130:⚠️ Potential issue | 🟡 MinorAvoid silent
except Exception: passin the example runner.These blocks hide unexpected failures and can make the generated trace set look complete when it is not.
♻️ Suggested fix
-except Exception: - pass # Requires Blackwell (SM100+) +except Exception as e: + print(f"[skip] mm_mxfp8 example not run: {e}") # Requires Blackwell (SM100+) -except Exception: - pass # Requires Blackwell (SM100+) +except Exception as e: + print(f"[skip] mm_fp4 example not run: {e}") # Requires Blackwell (SM100+) -except Exception: - pass # May require specific GPU/TRT-LLM support +except Exception as e: + print(f"[skip] trtllm_fp8_block_scale_moe example not run: {e}")Also applies to: 140-141, 276-277
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/example/example.py` around lines 129 - 130, Replace each silent "except Exception: pass" in the example runner (the three occurrences that suppress errors for the Blackwell/SM100+ path) with targeted handling: catch only the expected exception (e.g., ImportError or ModuleNotFoundError when Blackwell is absent) or, if you must continue on error, log the full exception with logging.exception or traceback.print_exc including contextual information about which trace/step failed; do not swallow unexpected exceptions—re-raise them after logging so real failures are visible.
🧹 Nitpick comments (2)
flashinfer/trace/template.py (1)
474-474: Consider using spread operator for list construction.Per static analysis, using spread syntax is more idiomatic.
♻️ Suggested change
- all_tags = [f"fi_api:{fi_api}"] + template.tags + all_tags = [f"fi_api:{fi_api}", *template.tags]🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/template.py` at line 474, Replace the list concatenation used to build all_tags with Python list unpacking for readability: instead of creating all_tags via [f"fi_api:{fi_api}"] + template.tags, construct it using the spread/unpacking form to include f"fi_api:{fi_api}" and all elements from template.tags (referencing variables all_tags, fi_api, and template.tags in template.py).flashinfer/trace/__init__.py (1)
23-25: Consider sorting__all__and reconsidering private export.Per static analysis,
__all__should be sorted. Additionally,_TRACE_DUMP_DIRhas a private naming convention (underscore prefix) but is exported publicly—consider renaming toTRACE_DUMP_DIRif it's meant for external use, or documenting why it's exposed.♻️ Suggested sorted `__all__`
-__all__ = ["TraceTemplate", "Var", "Const", "Tensor", "Scalar", "_TRACE_DUMP_DIR"] +__all__ = ["Const", "Scalar", "Tensor", "TraceTemplate", "Var", "_TRACE_DUMP_DIR"]🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/__init__.py` around lines 23 - 25, The __all__ list is unsorted and exposes a name with a leading underscore (_TRACE_DUMP_DIR) which conflicts with its private naming; update the __all__ declaration so entries are alphabetically sorted (e.g., Const, Scalar, Tensor, TraceTemplate, Var) and decide whether _TRACE_DUMP_DIR is meant to be public—if so rename it to TRACE_DUMP_DIR in template.py and here and export that, otherwise remove it from __all__ (or add a comment/docstring explaining why the underscored name is intentionally exported) so exports and naming are consistent; adjust imports/usage accordingly (TraceTemplate, Var, Const, Tensor, Scalar, and the chosen dump-dir symbol).
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 834f4a1a-3013-4d83-80ed-7022baffd452
📥 Commits
Reviewing files that changed from the base of the PR and between 2ca0d38 and 845363657f51cfee87e869c42133fc79c35d78d7.
📒 Files selected for processing (49)
flashinfer/__init__.pyflashinfer/api_logging.pyflashinfer/attention.pyflashinfer/decode.pyflashinfer/fi_trace.pyflashinfer/fused_moe/core.pyflashinfer/gdn_decode.pyflashinfer/gdn_prefill.pyflashinfer/gemm/gemm_base.pyflashinfer/mla/_core.pyflashinfer/mla/cute_dsl/mla_decode.pyflashinfer/norm/__init__.pyflashinfer/prefill.pyflashinfer/sampling.pyflashinfer/trace/__init__.pyflashinfer/trace/example/__main__.pyflashinfer/trace/example/example.pyflashinfer/trace/example/fi_trace_out/fused_add_rmsnorm_h5120.jsonflashinfer/trace/example/fi_trace_out/gdn_decode_qk4_v8_d128.jsonflashinfer/trace/example/fi_trace_out/gdn_mtp_qk4_v8_d128.jsonflashinfer/trace/example/fi_trace_out/gemm_bf16_N256_K7168.jsonflashinfer/trace/example/fi_trace_out/gemm_bf16_N4096_K4096.jsonflashinfer/trace/example/fi_trace_out/gemm_fp4_N2048_K7168_block_size16.jsonflashinfer/trace/example/fi_trace_out/gemm_fp8_N1536_K7168.jsonflashinfer/trace/example/fi_trace_out/gemm_mxfp8_N4096_K4096.jsonflashinfer/trace/example/fi_trace_out/gqa_paged_decode_h32_kv8_d128_ps16.jsonflashinfer/trace/example/fi_trace_out/gqa_paged_decode_h32_kv8_d128_ps64.jsonflashinfer/trace/example/fi_trace_out/gqa_paged_prefill_h32_kv8_d128_ps16.jsonflashinfer/trace/example/fi_trace_out/gqa_ragged_h32_kv8_d128.jsonflashinfer/trace/example/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps1.jsonflashinfer/trace/example/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps64.jsonflashinfer/trace/example/fi_trace_out/moe_fp8_block_scale_default_routing_topk8_e32_h7168_i2048.jsonflashinfer/trace/example/fi_trace_out/rmsnorm_h4096.jsonflashinfer/trace/example/fi_trace_out/rmsnorm_h7168.jsonflashinfer/trace/example/fi_trace_out/top_k_sampling_v128256.jsonflashinfer/trace/example/fi_trace_out/top_k_top_p_sampling_v128256.jsonflashinfer/trace/example/fi_trace_out/top_k_top_p_sampling_v151936.jsonflashinfer/trace/example/fi_trace_out/top_p_sampling_v128256.jsonflashinfer/trace/example/fi_trace_out/top_p_sampling_v151936.jsonflashinfer/trace/template.pyflashinfer/trace/templates/__init__.pyflashinfer/trace/templates/attention.pyflashinfer/trace/templates/gdn.pyflashinfer/trace/templates/gemm.pyflashinfer/trace/templates/moe.pyflashinfer/trace/templates/norm.pyflashinfer/trace/templates/sampling.pyflashinfer/trtllm_low_latency_gemm.pytests/test_fi_trace.py
| "dtype": "bfloat16" | ||
| } | ||
| }, | ||
| "reference": "def _mm_fp8_reference(A, B):\n \"\"\"Dequantize FP8 block-scale inputs and compute C = A @ B.T.\n\n B is in TRT-LLM block layout [K//block_size, N, block_size] and is\n reshaped to [K, N] before the matmul.\n \"\"\"\n K_div_bs, N, block_size = B.shape\n B_fp32 = B.reshape(K_div_bs * block_size, N).to(torch.float32)\n A_fp32 = A.to(torch.float32)\n return torch.matmul(A_fp32, B_fp32.T).to(torch.bfloat16)\n" |
There was a problem hiding this comment.
Fix the embedded reference matmul transpose.
At Line 50, B_fp32 is reshaped to [K, N], so A_fp32 @ B_fp32.T is invalid when K != N (here 7168 != 1536). The reference should multiply with B_fp32 (or reshape differently if transposed semantics are intended).
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/trace/example/fi_trace_out/gemm_fp8_N1536_K7168.json` at line 50,
The helper _mm_fp8_reference currently reshapes B into B_fp32 =
B.reshape(K_div_bs * block_size, N) (i.e., [K, N]) but then computes
torch.matmul(A_fp32, B_fp32.T), which is wrong when K != N; change the matmul to
use B_fp32 (torch.matmul(A_fp32, B_fp32)) so the multiplication matches the
reshaped [K, N] layout, or alternatively reshape B to [N, K] if you truly need
B_fp32.T semantics—fix the call in _mm_fp8_reference referencing B_fp32 and
A_fp32 accordingly.
| H = 7168 | ||
| I = 2048 | ||
| BLOCK = 128 |
There was a problem hiding this comment.
Hardcoded H/I makes reference execution shape-fragile.
_fp8_moe_run_experts is wired to H=7168 and I=2048, but template axes are shape-driven. This will fail or produce invalid behavior for other valid MoE shapes.
💡 Proposed fix
-H = 7168
-I = 2048
BLOCK = 128
@@
def _fp8_moe_run_experts(
@@
- T = hidden_states.shape[0]
+ T, H = hidden_states.shape
+ I = gemm2_weights.shape[2]
+ gemm1_out = gemm1_weights.shape[1]
+ if gemm1_out != 2 * I:
+ raise ValueError(
+ f"Invalid gemm1_out_size={gemm1_out}, expected 2 * intermediate_size={2 * I}"
+ )
@@
- A_scale_expanded = (
- A_scale_TH.unsqueeze(-1).repeat(1, 1, BLOCK).reshape(T, H).contiguous()
- )
+ A_scale_expanded = A_scale_TH.unsqueeze(-1).repeat(1, 1, BLOCK).reshape(T, H).contiguous()
@@
- X1, X2 = G1[:, :I], G1[:, I:]
+ X1, X2 = G1[:, :I], G1[:, I:]Also applies to: 48-58, 72-73, 86-87
🧰 Tools
🪛 Ruff (0.15.7)
[error] 26-26: Ambiguous variable name: I
(E741)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/trace/templates/moe.py` around lines 25 - 27, Hardcoded constants
H=7168 and I=2048 (and uses of BLOCK) make _fp8_moe_run_experts and related
templates shape-fragile; change these to compute H and I from the template/axis
sizes at runtime and use a derived BLOCK (e.g., based on H/I or
template.block_size) instead of literal numbers. Locate the constants H, I,
BLOCK and replace them with expressions that read the relevant template axes or
tensor shapes (reference the template used by _fp8_moe_run_experts and other
occurrences) so all occurrences use dynamically computed sizes rather than
hardcoded values.
There was a problem hiding this comment.
Actionable comments posted: 6
🧹 Nitpick comments (2)
tests/trace/fi_trace_out/gdn_mtp_qk4_v8_d128.json (1)
135-146: Consider specifying dtype forintermediate_states_buffer.The
dtypeis set to"unknown"while all other tensors have explicit dtypes. Since this buffer stores intermediate states similar toinitial_stateandfinal_state(bothfloat32), consider using"float32"for consistency—or add documentation explaining why the dtype is indeterminate.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/trace/fi_trace_out/gdn_mtp_qk4_v8_d128.json` around lines 135 - 146, The schema entry "intermediate_states_buffer" currently has dtype "unknown"; change it to a concrete dtype (e.g., "float32") to match the similar tensors "initial_state" and "final_state", or if dtype truly varies, add a clear description explaining why it's indeterminate and what types are allowed; update the "intermediate_states_buffer" dtype field and its description accordingly to ensure consistency and clarity.tests/trace/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps64.json (1)
85-97: Use concrete integer dtypes for index tensors.
kv_indptrandkv_indicesare currently"dtype": "unknown". This weakens schema validation and downstream codegen/consumers. Prefer explicit integer types (typicallyint32orint64) for both.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/trace/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps64.json` around lines 85 - 97, The schema uses "dtype": "unknown" for the index tensors kv_indptr (shape "len_indptr") and kv_indices (shape "num_kv_indices"); change both to a concrete integer dtype (prefer int32, or int64 if you need 64-bit indices) so downstream validation and codegen can rely on a fixed integer type—update the "dtype" entries for kv_indptr and kv_indices accordingly.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tests/trace/example.py`:
- Around line 1-294: The file is a standalone script so pytest won't collect it;
convert it into a proper pytest test by moving the top-level side-effect code
into a single test function (e.g., def test_generate_fi_trace_jsons(tmp_path):)
while preserving the early environment setup (os.environ.setdefault(...) and
SAVE_DIR) before importing flashinfer, and use the tmp_path fixture to override
FLASHINFER_TRACE_DUMP_DIR/SAVE_DIR so outputs go to a test-isolated directory;
keep all calls to flashinfer functions and wrappers (e.g., flashinfer.rmsnorm,
flashinfer.fused_add_rmsnorm, flashinfer.top_k_sampling_from_probs,
flashinfer.mm_bf16, flashinfer.gdn_decode.gated_delta_rule_decode,
BatchDecodeWithPagedKVCacheWrapper, BatchPrefillWithPagedKVCacheWrapper,
BatchPrefillWithRaggedKVCacheWrapper, BatchMLAPagedAttentionWrapper,
flashinfer.fused_moe.trtllm_fp8_block_scale_moe, etc.) inside that test, and
remove or adapt prints/assert the expected JSON files exist via
SAVE_DIR.glob("*.json") to make the test assertions deterministic for CI.
In `@tests/trace/fi_trace_out/gdn_decode_qk4_v8_d128.json`:
- Around line 120-124: The "scale" field is documented as having a default
(1/sqrt(head_size)) but isn't marked optional; update the JSON schema entry for
"scale" so consumers know it may be omitted—e.g., add an optional/nullable flag
or remove it from any "required" list and set "optional": true (or equivalent)
next to the "scale" property to reflect the default behavior.
- Line 148: The reference function _gdn_decode_reference uses math.sqrt and
F.softplus but the serialized source string has no imports, causing NameError
when exec/eval runs; fix by either injecting math and torch.nn.functional as F
into the exec/eval globals where _gdn_decode_reference is executed (ensure names
"math" and "F" are present) or prepend/import lines ("import math" and "import
torch.nn.functional as F") to the serialized reference string so
_gdn_decode_reference has the required symbols at runtime.
In `@tests/trace/fi_trace_out/gdn_mtp_qk4_v8_d128.json`:
- Around line 159-168: The doc string for "final_state" references an undefined
parameter disable_state_update; either add a boolean input named
disable_state_update to the inputs section (e.g., description: "If true,
recurrent state updates are disabled and final_state remains unchanged") or
remove the mention "Unchanged if disable_state_update=True" from the
"final_state" description; update the "final_state" description or inputs
accordingly so the documentation no longer refers to an undefined symbol.
- Line 170: The reference function _gdn_mtp_reference updates per-batch states
in state_HVK but then returns final_state = initial_state.clone(), discarding
updates; fix by creating final_state = initial_state.clone() before the batch
loop and after processing each batch element (using state_idx =
int(initial_state_indices[b_idx].item())) write the updated state back with
final_state[state_idx] = state_HVK.transpose(-1, -2) (matching the stored
[H,V,K] layout); ensure types remain consistent (match .float()/.to dtype as
needed) and then return output, final_state.
In `@tests/trace/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps64.json`:
- Line 123: In _mla_paged_decode_reference the use of ckv_cache.squeeze(1) and
kpe_cache.squeeze(1) is wrong for paged tensors (shape [num_pages, page_size,
head_dim_*]) and leaves a 3D tensor so Kc_all[tok_idx] yields [L, page_size,
head_dim]; replace squeeze(1) with a flattening reshape (e.g. reshape(num_pages
* page_size, head_dim_ckv) / reshape(..., head_dim_kpe) or view(-1, head_dim_*))
so Kc_all and Kp_all become 2D token-major tensors before indexing, and ensure
kv_indptr and kv_indices are cast to explicit integer dtype (torch.long/int64)
before use to remove schema ambiguity.
---
Nitpick comments:
In `@tests/trace/fi_trace_out/gdn_mtp_qk4_v8_d128.json`:
- Around line 135-146: The schema entry "intermediate_states_buffer" currently
has dtype "unknown"; change it to a concrete dtype (e.g., "float32") to match
the similar tensors "initial_state" and "final_state", or if dtype truly varies,
add a clear description explaining why it's indeterminate and what types are
allowed; update the "intermediate_states_buffer" dtype field and its description
accordingly to ensure consistency and clarity.
In `@tests/trace/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps64.json`:
- Around line 85-97: The schema uses "dtype": "unknown" for the index tensors
kv_indptr (shape "len_indptr") and kv_indices (shape "num_kv_indices"); change
both to a concrete integer dtype (prefer int32, or int64 if you need 64-bit
indices) so downstream validation and codegen can rely on a fixed integer
type—update the "dtype" entries for kv_indptr and kv_indices accordingly.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 9d5f1191-ca90-4d41-be7d-6533b17213b1
📥 Commits
Reviewing files that changed from the base of the PR and between 845363657f51cfee87e869c42133fc79c35d78d7 and c5296b71eea86be83a63525183c5c31db0cf600a.
📒 Files selected for processing (29)
flashinfer/decode.pyflashinfer/fused_moe/core.pyflashinfer/gdn_decode.pyflashinfer/gemm/gemm_base.pyflashinfer/norm/__init__.pyflashinfer/prefill.pytests/trace/example.pytests/trace/fi_trace_out/fused_add_rmsnorm_h5120.jsontests/trace/fi_trace_out/gdn_decode_qk4_v8_d128.jsontests/trace/fi_trace_out/gdn_mtp_qk4_v8_d128.jsontests/trace/fi_trace_out/gemm_bf16_N256_K7168.jsontests/trace/fi_trace_out/gemm_bf16_N4096_K4096.jsontests/trace/fi_trace_out/gemm_fp4_N2048_K7168_block_size16.jsontests/trace/fi_trace_out/gemm_fp8_N1536_K7168.jsontests/trace/fi_trace_out/gemm_mxfp8_N4096_K4096.jsontests/trace/fi_trace_out/gqa_paged_decode_h32_kv8_d128_ps16.jsontests/trace/fi_trace_out/gqa_paged_decode_h32_kv8_d128_ps64.jsontests/trace/fi_trace_out/gqa_paged_prefill_h32_kv8_d128_ps16.jsontests/trace/fi_trace_out/gqa_ragged_h32_kv8_d128.jsontests/trace/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps1.jsontests/trace/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps64.jsontests/trace/fi_trace_out/moe_fp8_block_scale_default_routing_topk8_e32_h7168_i2048.jsontests/trace/fi_trace_out/rmsnorm_h4096.jsontests/trace/fi_trace_out/rmsnorm_h7168.jsontests/trace/fi_trace_out/top_k_sampling_v128256.jsontests/trace/fi_trace_out/top_k_top_p_sampling_v128256.jsontests/trace/fi_trace_out/top_k_top_p_sampling_v151936.jsontests/trace/fi_trace_out/top_p_sampling_v128256.jsontests/trace/fi_trace_out/top_p_sampling_v151936.json
✅ Files skipped from review due to trivial changes (20)
- flashinfer/norm/init.py
- tests/trace/fi_trace_out/gemm_bf16_N4096_K4096.json
- tests/trace/fi_trace_out/rmsnorm_h7168.json
- tests/trace/fi_trace_out/gemm_bf16_N256_K7168.json
- tests/trace/fi_trace_out/fused_add_rmsnorm_h5120.json
- tests/trace/fi_trace_out/top_k_top_p_sampling_v151936.json
- tests/trace/fi_trace_out/rmsnorm_h4096.json
- tests/trace/fi_trace_out/gqa_paged_decode_h32_kv8_d128_ps64.json
- tests/trace/fi_trace_out/top_k_sampling_v128256.json
- tests/trace/fi_trace_out/gemm_fp4_N2048_K7168_block_size16.json
- tests/trace/fi_trace_out/gemm_fp8_N1536_K7168.json
- tests/trace/fi_trace_out/gqa_paged_prefill_h32_kv8_d128_ps16.json
- tests/trace/fi_trace_out/top_p_sampling_v151936.json
- tests/trace/fi_trace_out/gqa_ragged_h32_kv8_d128.json
- tests/trace/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps1.json
- tests/trace/fi_trace_out/top_p_sampling_v128256.json
- tests/trace/fi_trace_out/moe_fp8_block_scale_default_routing_topk8_e32_h7168_i2048.json
- tests/trace/fi_trace_out/gqa_paged_decode_h32_kv8_d128_ps16.json
- flashinfer/prefill.py
- tests/trace/fi_trace_out/gemm_mxfp8_N4096_K4096.json
🚧 Files skipped from review as they are similar to previous changes (3)
- flashinfer/decode.py
- flashinfer/gdn_decode.py
- flashinfer/gemm/gemm_base.py
| q_r = torch.randn(256, num_qo, head_dim, dtype=torch.bfloat16, device=device) | ||
| k_r = torch.randn(512, num_kv, head_dim, dtype=torch.bfloat16, device=device) | ||
| v_r = torch.randn(512, num_kv, head_dim, dtype=torch.bfloat16, device=device) | ||
| rag.run(q_r, k_r, v_r) | ||
|
|
||
| # ── MLA paged decode (DeepSeek-V3 TP=8, h=16/ckv=512/kpe=64) ───────────────── | ||
| mla_b, mla_h, ckv, kpe = 128, 16, 512, 64 | ||
|
|
||
| for mla_ps, mla_np in ((64, 32), (1, 2048)): | ||
| total_mla = mla_b * mla_np | ||
| mla_qo_indptr = torch.arange(mla_b + 1, dtype=torch.int32, device=device) | ||
| mla_kv_indptr = torch.arange(mla_b + 1, dtype=torch.int32, device=device) * mla_np | ||
| mla_kv_indices = torch.arange(total_mla, dtype=torch.int32, device=device) | ||
| mla_kv_len = torch.full((mla_b,), mla_np * mla_ps, dtype=torch.int32, device=device) | ||
|
|
||
| ws_mla = torch.empty(WORKSPACE, dtype=torch.uint8, device=device) | ||
| mla = BatchMLAPagedAttentionWrapper(ws_mla) | ||
| mla.plan( | ||
| mla_qo_indptr, mla_kv_indptr, mla_kv_indices, mla_kv_len, | ||
| mla_h, ckv, kpe, mla_ps, | ||
| causal=False, sm_scale=1.0 / (ckv ** 0.5), | ||
| q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16, | ||
| ) | ||
| q_nope = torch.randn(mla_b, mla_h, ckv, dtype=torch.bfloat16, device=device) | ||
| q_pe = torch.randn(mla_b, mla_h, kpe, dtype=torch.bfloat16, device=device) | ||
| ckv_cache = torch.randn(total_mla, mla_ps, ckv, dtype=torch.bfloat16, device=device) | ||
| kpe_cache = torch.randn(total_mla, mla_ps, kpe, dtype=torch.bfloat16, device=device) | ||
| mla.run(q_nope, q_pe, ckv_cache, kpe_cache) | ||
|
|
||
| # ── GDN decode (Qwen3-Next TP=4, qk=4/v=8/d=128) ──────────────────────────── | ||
| B, H, HV, K = 4, 4, 8, 128 | ||
| q = torch.randn(B, 1, H, K, dtype=torch.bfloat16, device=device) | ||
| k = torch.randn(B, 1, H, K, dtype=torch.bfloat16, device=device) | ||
| v = torch.randn(B, 1, HV, K, dtype=torch.bfloat16, device=device) | ||
| state = torch.zeros(B, HV, K, K, dtype=torch.float32, device=device) | ||
| A_log = torch.zeros(HV, dtype=torch.float32, device=device) | ||
| a = torch.zeros(B, 1, HV, dtype=torch.bfloat16, device=device) | ||
| dt_bias = torch.zeros(HV, dtype=torch.float32, device=device) | ||
| b_ = torch.zeros(B, 1, HV, dtype=torch.bfloat16, device=device) | ||
| flashinfer.gdn_decode.gated_delta_rule_decode(q, k, v, state, A_log, a, dt_bias, b_) | ||
|
|
||
| # ── GDN MTP (Qwen3-Next TP=4, spec_len=4) ──────────────────────────────────── | ||
| T_mtp, pool_size = 4, 8 | ||
| q_m = torch.randn(B, T_mtp, H, K, dtype=torch.bfloat16, device=device) | ||
| k_m = torch.randn(B, T_mtp, H, K, dtype=torch.bfloat16, device=device) | ||
| v_m = torch.randn(B, T_mtp, HV, K, dtype=torch.bfloat16, device=device) | ||
| init_state = torch.zeros(pool_size, HV, K, K, dtype=torch.float32, device=device) | ||
| init_idx = torch.arange(B, dtype=torch.int32, device=device) | ||
| A_log_m = torch.zeros(HV, dtype=torch.float32, device=device) | ||
| a_m = torch.zeros(B, T_mtp, HV, dtype=torch.bfloat16, device=device) | ||
| dt_bias_m = torch.zeros(HV, dtype=torch.float32, device=device) | ||
| b_m = torch.zeros(B, T_mtp, HV, dtype=torch.bfloat16, device=device) | ||
| flashinfer.gdn_decode.gated_delta_rule_mtp( | ||
| q_m, k_m, v_m, init_state, init_idx, A_log_m, a_m, dt_bias_m, b_m | ||
| ) | ||
|
|
||
| # ── MoE FP8 (DeepSeek-V3 EP=8: 256 experts, 32 local, h=7168, i=2048, top_k=8) | ||
| try: | ||
| T_moe, H_moe, I_moe, E_tot, E_loc, BS = 128, 7168, 2048, 256, 32, 128 | ||
| routing_logits = torch.randn(T_moe, E_tot, dtype=torch.float32, device=device) | ||
| routing_bias = torch.zeros(E_tot, dtype=torch.bfloat16, device=device) | ||
| hs = torch.zeros(T_moe, H_moe, dtype=torch.float8_e4m3fn, device=device) | ||
| hs_scale = torch.ones(H_moe // BS, T_moe, dtype=torch.float32, device=device) | ||
| w1 = torch.zeros(E_loc, 2 * I_moe, H_moe, dtype=torch.float8_e4m3fn, device=device) | ||
| w1s = torch.ones(E_loc, (2 * I_moe) // BS, H_moe // BS, dtype=torch.float32, device=device) | ||
| w2 = torch.zeros(E_loc, H_moe, I_moe, dtype=torch.float8_e4m3fn, device=device) | ||
| w2s = torch.ones(E_loc, H_moe // BS, I_moe // BS, dtype=torch.float32, device=device) | ||
| flashinfer.fused_moe.trtllm_fp8_block_scale_moe( | ||
| routing_logits, routing_bias, | ||
| hs, hs_scale, | ||
| w1, w1s, | ||
| w2, w2s, | ||
| num_experts=E_tot, | ||
| top_k=8, | ||
| n_group=8, | ||
| topk_group=3, | ||
| intermediate_size=I_moe, | ||
| local_expert_offset=0, | ||
| local_num_experts=E_loc, | ||
| routed_scaling_factor=2.5, | ||
| ) | ||
| except Exception: | ||
| pass # May require specific GPU/TRT-LLM support | ||
|
|
||
| # ── Summary ─────────────────────────────────────────────────────────────────── | ||
| files = sorted(SAVE_DIR.glob("*.json")) | ||
| print(f"\nWrote {len(files)} definition files:\n") | ||
| for f in files: | ||
| defn = json.loads(f.read_text()) | ||
| print(f" {f.name}") | ||
| print(f" op_type : {defn['op_type']}") | ||
| print(f" fi_api : {next(t for t in defn['tags'] if t.startswith('fi_api:'))}") | ||
| const_axes = { | ||
| k: v["value"] | ||
| for k, v in defn["axes"].items() | ||
| if v["type"] == "const" and "value" in v | ||
| } | ||
| if const_axes: | ||
| print(f" axes : {const_axes}") | ||
| print() |
There was a problem hiding this comment.
🧩 Analysis chain
🏁 Script executed:
# Check if the file exists and get its basic info
find tests -name "example.py" -type f | head -5Repository: flashinfer-ai/flashinfer
Length of output: 90
🏁 Script executed:
# Check the structure of tests directory
find tests -type f -name "*.py" | head -20Repository: flashinfer-ai/flashinfer
Length of output: 859
🏁 Script executed:
# Look for test_ prefixed functions in the current example.py
if [ -f "tests/trace/example.py" ]; then
grep -n "^def test_" tests/trace/example.py || echo "No test_ functions found"
fiRepository: flashinfer-ai/flashinfer
Length of output: 92
🏁 Script executed:
# Check if there's a main block and understand the current structure
if [ -f "tests/trace/example.py" ]; then
grep -n "if __name__" tests/trace/example.py
fiRepository: flashinfer-ai/flashinfer
Length of output: 50
🏁 Script executed:
# Look at other test files to understand the pytest pattern used
find tests -name "test_*.py" -type f | head -5Repository: flashinfer-ai/flashinfer
Length of output: 270
🏁 Script executed:
# Check the actual content and structure of example.py
tail -20 tests/trace/example.pyRepository: flashinfer-ai/flashinfer
Length of output: 741
🏁 Script executed:
# Look at a properly structured test file to understand the pattern
head -50 tests/attention/test_alibi.pyRepository: flashinfer-ai/flashinfer
Length of output: 1692
🏁 Script executed:
# Check if tests/trace directory has any other test files
ls -la tests/trace/Repository: flashinfer-ai/flashinfer
Length of output: 309
🏁 Script executed:
# Look for pytest configuration to understand test discovery
find . -name "pytest.ini" -o -name "pyproject.toml" -o -name "setup.cfg" | head -3 | xargs grep -l "testpaths\|python_files" 2>/dev/null || echo "No specific pytest config found for test patterns"Repository: flashinfer-ai/flashinfer
Length of output: 117
Refactor as a proper pytest test with test_ prefix to ensure CI coverage.
This file is a runnable script with side effects but lacks a test_ function, so pytest won't collect it during CI. Per coding guidelines for tests/**/*.py, test functions must be prefixed with test_ and structured in subdirectories matching kernel categories.
Suggested structure
+def run_trace_example(save_dir: Path) -> list[Path]:
+ # existing body here
+ return sorted(save_dir.glob("*.json"))
+
+def test_fi_trace_example_generates_defs(tmp_path, monkeypatch):
+ monkeypatch.setenv("FLASHINFER_TRACE_DUMP", "1")
+ monkeypatch.setenv("FLASHINFER_TRACE_DUMP_DIR", str(tmp_path))
+ files = run_trace_example(tmp_path)
+ assert files, "Expected fi_trace JSON files to be generated"
+
+if __name__ == "__main__":
+ run_trace_example(SAVE_DIR)🧰 Tools
🪛 Ruff (0.15.7)
[warning] 104-104: Comment contains ambiguous × (MULTIPLICATION SIGN). Did you mean x (LATIN SMALL LETTER X)?
(RUF003)
[warning] 104-104: Comment contains ambiguous × (MULTIPLICATION SIGN). Did you mean x (LATIN SMALL LETTER X)?
(RUF003)
[warning] 114-114: Comment contains ambiguous × (MULTIPLICATION SIGN). Did you mean x (LATIN SMALL LETTER X)?
(RUF003)
[warning] 121-121: Comment contains ambiguous × (MULTIPLICATION SIGN). Did you mean x (LATIN SMALL LETTER X)?
(RUF003)
[warning] 121-121: Comment contains ambiguous × (MULTIPLICATION SIGN). Did you mean x (LATIN SMALL LETTER X)?
(RUF003)
[error] 129-130: try-except-pass detected, consider logging the exception
(S110)
[warning] 129-129: Do not catch blind exception: Exception
(BLE001)
[warning] 132-132: Comment contains ambiguous × (MULTIPLICATION SIGN). Did you mean x (LATIN SMALL LETTER X)?
(RUF003)
[warning] 132-132: Comment contains ambiguous × (MULTIPLICATION SIGN). Did you mean x (LATIN SMALL LETTER X)?
(RUF003)
[error] 140-141: try-except-pass detected, consider logging the exception
(S110)
[warning] 140-140: Do not catch blind exception: Exception
(BLE001)
[error] 276-277: try-except-pass detected, consider logging the exception
(S110)
[warning] 276-276: Do not catch blind exception: Exception
(BLE001)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/trace/example.py` around lines 1 - 294, The file is a standalone script
so pytest won't collect it; convert it into a proper pytest test by moving the
top-level side-effect code into a single test function (e.g., def
test_generate_fi_trace_jsons(tmp_path):) while preserving the early environment
setup (os.environ.setdefault(...) and SAVE_DIR) before importing flashinfer, and
use the tmp_path fixture to override FLASHINFER_TRACE_DUMP_DIR/SAVE_DIR so
outputs go to a test-isolated directory; keep all calls to flashinfer functions
and wrappers (e.g., flashinfer.rmsnorm, flashinfer.fused_add_rmsnorm,
flashinfer.top_k_sampling_from_probs, flashinfer.mm_bf16,
flashinfer.gdn_decode.gated_delta_rule_decode,
BatchDecodeWithPagedKVCacheWrapper, BatchPrefillWithPagedKVCacheWrapper,
BatchPrefillWithRaggedKVCacheWrapper, BatchMLAPagedAttentionWrapper,
flashinfer.fused_moe.trtllm_fp8_block_scale_moe, etc.) inside that test, and
remove or adapt prints/assert the expected JSON files exist via
SAVE_DIR.glob("*.json") to make the test assertions deterministic for CI.
| "scale": { | ||
| "shape": null, | ||
| "dtype": "float32", | ||
| "description": "Scale factor. Default is 1/sqrt(head_size)." | ||
| } |
There was a problem hiding this comment.
Mark scale as optional to match the declared default behavior.
Line 123 says a default is applied (1/sqrt(head_size)), but scale is not marked optional. This can make schema consumers treat it as required.
🛠️ Proposed fix
"scale": {
"shape": null,
"dtype": "float32",
+ "optional": true,
"description": "Scale factor. Default is 1/sqrt(head_size)."
}📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| "scale": { | |
| "shape": null, | |
| "dtype": "float32", | |
| "description": "Scale factor. Default is 1/sqrt(head_size)." | |
| } | |
| "scale": { | |
| "shape": null, | |
| "dtype": "float32", | |
| "optional": true, | |
| "description": "Scale factor. Default is 1/sqrt(head_size)." | |
| } |
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/trace/fi_trace_out/gdn_decode_qk4_v8_d128.json` around lines 120 - 124,
The "scale" field is documented as having a default (1/sqrt(head_size)) but
isn't marked optional; update the JSON schema entry for "scale" so consumers
know it may be omitted—e.g., add an optional/nullable flag or remove it from any
"required" list and set "optional": true (or equivalent) next to the "scale"
property to reflect the default behavior.
| "description": "Updated recurrent state in k-last layout [B, H, V, K]." | ||
| } | ||
| }, | ||
| "reference": "@torch.no_grad()\ndef _gdn_decode_reference(q, k, v, state, A_log, a, dt_bias, b, scale):\n \"\"\"\n Gated Delta Net decode reference implementation (k-last layout).\n\n State layout: [B, H, V, K] (k-last, K dimension at the end)\n\n Gate computation:\n g = exp(-exp(A_log) * softplus(a + dt_bias))\n beta = sigmoid(b)\n\n Delta rule update:\n state_new = g * state_old + k^T @ (beta * v + (1-beta) * k @ state_old) - k^T @ (k @ state_old)\n output = scale * q @ state_new\n \"\"\"\n B, T, num_q_heads, K = q.shape\n _, _, num_k_heads, _ = k.shape\n _, _, num_v_heads, V = v.shape\n num_heads = num_v_heads\n device = q.device\n\n if scale is None or scale == 0.0:\n scale = 1.0 / math.sqrt(K)\n\n x = a.float() + dt_bias.float() # [B, 1, HV]\n g = torch.exp(-torch.exp(A_log.float()) * F.softplus(x)) # [B, 1, HV]\n beta = torch.sigmoid(b.float()) # [B, 1, HV]\n\n q_f32 = q.squeeze(1).float()\n k_f32 = k.squeeze(1).float()\n v_f32 = v.squeeze(1).float()\n g_f32 = g.squeeze(1).float()\n beta_f32 = beta.squeeze(1).float()\n\n if state is not None:\n state_f32 = state.float()\n else:\n state_f32 = torch.zeros(B, num_heads, V, K, dtype=torch.float32, device=device)\n\n q_exp = q_f32.repeat_interleave(num_v_heads // num_q_heads, dim=1)\n k_exp = k_f32.repeat_interleave(num_v_heads // num_k_heads, dim=1)\n\n new_state = torch.zeros_like(state_f32)\n output = torch.zeros(B, num_heads, V, dtype=torch.float32, device=device)\n\n for b_idx in range(B):\n for h_idx in range(num_heads):\n q_h = q_exp[b_idx, h_idx]\n k_h = k_exp[b_idx, h_idx]\n v_h = v_f32[b_idx, h_idx]\n h_state = state_f32[b_idx, h_idx].clone().transpose(-1, -2) # [V,K] -> [K,V]\n g_val = g_f32[b_idx, h_idx]\n beta_val = beta_f32[b_idx, h_idx]\n\n old_state = g_val * h_state\n old_v = k_h @ old_state\n new_v = beta_val * v_h + (1 - beta_val) * old_v\n state_remove = k_h.unsqueeze(1) @ old_v.unsqueeze(0)\n state_update = k_h.unsqueeze(1) @ new_v.unsqueeze(0)\n h_state = old_state - state_remove + state_update\n\n output[b_idx, h_idx] = scale * (q_h @ h_state)\n new_state[b_idx, h_idx] = h_state.transpose(-1, -2) # [K,V] -> [V,K]\n\n output = output.unsqueeze(1).to(torch.bfloat16)\n return output, new_state\n" |
There was a problem hiding this comment.
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
set -euo pipefail
# 1) Find where trace "reference" snippets are parsed/executed.
rg -n -C3 --type=py 'reference|exec\(|compile\('
# 2) Verify whether execution globals/locals include math and torch.nn.functional as F.
rg -n -C3 --type=py 'import math|torch\.nn\.functional as F|globals\(|locals\('
# 3) Find other trace fixtures that use math./F. in embedded references
# to confirm the expected convention is consistent.
rg -n --type=json 'math\.|F\.softplus|F\.' testsRepository: flashinfer-ai/flashinfer
Length of output: 50381
🏁 Script executed:
#!/bin/bash
# Find files that load or process trace JSON files
find tests -type f -name "*.py" | head -20
rg -l "trace.*json|gdn_decode" tests --type py | head -10Repository: flashinfer-ai/flashinfer
Length of output: 998
🏁 Script executed:
#!/bin/bash
# Search for code that parses/loads trace JSON files
rg -n "load.*trace|json\.load|reference" tests --type py -A 3 | grep -E "trace|reference|exec|compile" | head -40Repository: flashinfer-ai/flashinfer
Length of output: 3912
🏁 Script executed:
#!/bin/bash
# Look for test files in the trace directory
ls -la tests/trace*.py tests/*trace*.py 2>/dev/null || find tests -name "*trace*" -type fRepository: flashinfer-ai/flashinfer
Length of output: 163
🏁 Script executed:
#!/bin/bash
# Read the main trace test file to understand how reference is executed
wc -l tests/test_fi_trace.py
head -100 tests/test_fi_trace.pyRepository: flashinfer-ai/flashinfer
Length of output: 3745
🏁 Script executed:
#!/bin/bash
# Search for exec/compile patterns in trace test file
rg -n "exec|compile|reference|math|import F|import math" tests/test_fi_trace.py -A 5 -B 2Repository: flashinfer-ai/flashinfer
Length of output: 50
🏁 Script executed:
#!/bin/bash
# Check the actual trace JSON file
head -30 tests/trace/fi_trace_out/gdn_decode_qk4_v8_d128.jsonRepository: flashinfer-ai/flashinfer
Length of output: 940
🏁 Script executed:
#!/bin/bash
# Search for code that loads/executes these trace JSON files
rg -rn "gdn_decode|fi_trace_out" tests --type pyRepository: flashinfer-ai/flashinfer
Length of output: 2731
🏁 Script executed:
#!/bin/bash
# Search more broadly for "reference" field execution in Python test files
rg -rn "reference.*exec|exec.*reference|\[.reference.\]" tests --type py -A 5Repository: flashinfer-ai/flashinfer
Length of output: 1018
🏁 Script executed:
#!/bin/bash
# Search for any code that executes or uses the "reference" field from JSON traces
rg -rn "\.reference|reference.*exec|eval.*reference" tests --type py -A 3 -B 3Repository: flashinfer-ai/flashinfer
Length of output: 2062
🏁 Script executed:
#!/bin/bash
# Look at the gdn test file that's mentioned
wc -l tests/gdn/test_decode_delta_rule.py
head -50 tests/gdn/test_decode_delta_rule.pyRepository: flashinfer-ai/flashinfer
Length of output: 1537
🏁 Script executed:
#!/bin/bash
# Search for how reference implementations are called in the gdn tests
rg -n "_gdn_decode_reference|_reference" tests/gdn --type py -B 2 -A 2Repository: flashinfer-ai/flashinfer
Length of output: 50
🏁 Script executed:
#!/bin/bash
# Look for any code that loads JSON traces and executes reference field
find tests -name "*.py" -type f -exec rg -l "reference" {} \; | head -5
rg -rn "loads\(|compile\(|exec\(" tests/trace --type pyRepository: flashinfer-ai/flashinfer
Length of output: 324
🏁 Script executed:
#!/bin/bash
# Check if there's a trace execution utility in the flashinfer source
find . -path ./tests -prune -o -name "*.py" -type f -print | xargs rg -l "fi_trace|trace.*reference" 2>/dev/null | head -10Repository: flashinfer-ai/flashinfer
Length of output: 284
🏁 Script executed:
#!/bin/bash
# Search more broadly for any eval/exec of code strings in Python files
rg -rn "eval\(|exec\(|compile\(" . --type py --iglob "*/flashinfer/*" | head -20Repository: flashinfer-ai/flashinfer
Length of output: 50
🏁 Script executed:
#!/bin/bash
# Examine the trace template and related files
wc -l flashinfer/trace/template.py flashinfer/fi_trace.py
head -80 flashinfer/trace/template.pyRepository: flashinfer-ai/flashinfer
Length of output: 3025
🏁 Script executed:
#!/bin/bash
# Search for reference field usage in trace modules
rg -n "reference" flashinfer/trace --type py -B 2 -A 5Repository: flashinfer-ai/flashinfer
Length of output: 30715
🏁 Script executed:
#!/bin/bash
# Check fi_trace.py for reference handling
rg -n "reference" flashinfer/fi_trace.py -B 3 -A 5Repository: flashinfer-ai/flashinfer
Length of output: 50
🏁 Script executed:
#!/bin/bash
# Look at the template.py file more carefully to see where reference is serialized
sed -n '480,510p' flashinfer/trace/template.pyRepository: flashinfer-ai/flashinfer
Length of output: 1460
🏁 Script executed:
#!/antml:thinking
The key finding is that `reference` is serialized as source code using `inspect.getsource()`. But I need to check if/where it's executed. Let me check if the trace system actually EXECUTES the reference field or just stores it.
</thinking>
# Now check the GDN reference function definition
cat flashinfer/trace/templates/gdn.py | head -100Repository: flashinfer-ai/flashinfer
Length of output: 136
🏁 Script executed:
#!/bin/bash
# Search for where the "reference" field from JSON is actually executed
rg -rn "json.*reference|loads.*reference|exec.*reference" . --type py --iglob "*/flashinfer/*" -A 5 -B 2Repository: flashinfer-ai/flashinfer
Length of output: 50
🏁 Script executed:
#!/bin/bash
# Check the actual gdn.py template to see the reference function definition
head -80 flashinfer/trace/templates/gdn.pyRepository: flashinfer-ai/flashinfer
Length of output: 2773
🏁 Script executed:
#!/bin/bash
# Look for any code that loads trace JSON and executes the reference
find . -name "*.py" -path "*/flashinfer/*" -o -path "*/tests/*" | xargs grep -l "reference.*exec\|exec.*reference" 2>/dev/nullRepository: flashinfer-ai/flashinfer
Length of output: 164
🏁 Script executed:
#!/bin/bash
# Check if there are any test files that load and run the reference field from JSON traces
find tests -name "*.py" -type f | xargs grep -l "\.json" | head -5Repository: flashinfer-ai/flashinfer
Length of output: 204
🏁 Script executed:
#!/bin/bash
# Search for where reference strings might be executed from JSON
rg -rn "reference.*eval|reference.*exec|exec.*inspect|compile.*reference" tests flashinfer --type py -B 3 -A 3Repository: flashinfer-ai/flashinfer
Length of output: 1200
🏁 Script executed:
#!/bin/bash
# Check test_artifacts.py to see if it loads/executes traces
wc -l tests/test_artifacts.py
head -100 tests/test_artifacts.pyRepository: flashinfer-ai/flashinfer
Length of output: 4681
🏁 Script executed:
#!/bin/bash
# Look for any test that loads the JSON traces and uses them
grep -r "fi_trace_out\|gdn_decode_qk4" tests --type py -B 2 -A 5Repository: flashinfer-ai/flashinfer
Length of output: 185
🏁 Script executed:
#!/bin/bash
# Search for references to trace JSON files
rg "fi_trace_out|gdn_decode_qk4" tests -B 2 -A 5Repository: flashinfer-ai/flashinfer
Length of output: 1661
🏁 Script executed:
#!/bin/bash
# Look for any test files that actually load and execute the reference from JSON
rg "reference.*code|reference.*string|compile\(|exec\(" flashinfer/trace --type py -B 3 -A 5Repository: flashinfer-ai/flashinfer
Length of output: 1775
🏁 Script executed:
#!/bin/bash
# Directly check the JSON file to confirm what the reference field contains
python3 << 'EOF'
import json
with open('tests/trace/fi_trace_out/gdn_decode_qk4_v8_d128.json') as f:
data = json.load(f)
if 'reference' in data:
ref = data['reference']
print("First 500 chars of reference field:")
print(ref[:500])
print("\n...\n")
print("Last 500 chars:")
print(ref[-500:])
# Check for imports
if 'import math' in ref:
print("\n✓ Contains: import math")
if 'import torch' in ref:
print("✓ Contains: import torch")
if 'torch.nn.functional as F' in ref or 'torch.nn.functional' in ref:
print("✓ Contains: torch.nn.functional reference")
if 'F.softplus' in ref:
print("✓ Uses: F.softplus")
if 'math.sqrt' in ref:
print("✓ Uses: math.sqrt")
else:
print("No 'reference' field in JSON")
EOFRepository: flashinfer-ai/flashinfer
Length of output: 1164
🏁 Script executed:
#!/bin/bash
# Search for any code that actually executes the reference field from JSON
rg "reference.*exec|exec.*reference|compile.*reference|eval.*reference" flashinfer tests --type py -B 5 -A 5Repository: flashinfer-ai/flashinfer
Length of output: 1875
🏁 Script executed:
#!/bin/bash
# Check if the reference field is ONLY stored for documentation or if it's ever executed
rg "loads\(|load\(" flashinfer/trace --type py -B 3 -A 10 | head -50Repository: flashinfer-ai/flashinfer
Length of output: 50
🏁 Script executed:
#!/bin/bash
# Look for any test that loads and executes trace JSON references
find tests -name "*.py" -type f -exec grep -l "loads\|exec\|compile" {} \; | xargs grep -l "reference\|trace.*json" 2>/dev/nullRepository: flashinfer-ai/flashinfer
Length of output: 416
🏁 Script executed:
#!/bin/bash
# Check if reference is ever loaded and executed from the JSON
rg -rn "reference" flashinfer/trace/template.py -B 3 -A 10 | grep -A 10 "reference"Repository: flashinfer-ai/flashinfer
Length of output: 50
🏁 Script executed:
#!/bin/bash
# Search more directly for any eval/exec in the entire flashinfer package
find flashinfer -name "*.py" -type f | xargs grep -n "eval\|exec\|compile" | head -30Repository: flashinfer-ai/flashinfer
Length of output: 2853
🏁 Script executed:
#!/bin/bash
# Verify the actual imports in the gdn.py template file
head -30 flashinfer/trace/templates/gdn.py | grep -E "^import|^from"Repository: flashinfer-ai/flashinfer
Length of output: 189
🏁 Script executed:
#!/bin/bash
# Confirm the imports are NOT in the extracted reference string from JSON
python3 << 'EOF'
import json
with open('tests/trace/fi_trace_out/gdn_decode_qk4_v8_d128.json') as f:
data = json.load(f)
if 'reference' in data:
ref = data['reference']
print("Checking for imports in the reference field:")
if 'import math' in ref:
print("✓ Contains: import math")
else:
print("✗ MISSING: import math")
if 'import torch' in ref:
print("✓ Contains: import torch")
else:
print("✗ MISSING: import torch")
if 'torch.nn.functional' in ref or 'import F' in ref:
print("✓ Contains: torch.nn.functional or F import")
else:
print("✗ MISSING: torch.nn.functional as F import")
print("\nFirst 300 chars of reference:")
print(ref[:300])
EOFRepository: flashinfer-ai/flashinfer
Length of output: 534
🏁 Script executed:
#!/bin/bash
# Check if there are any plans or code to execute the reference field in the future
rg -rn "reference.*exec|load.*reference|reference.*run" flashinfer --type py -B 3 -A 5Repository: flashinfer-ai/flashinfer
Length of output: 50
The embedded reference code lacks required imports and will fail if executed.
Line 148's reference field contains math.sqrt and F.softplus but is serialized as a function source string without the required imports (import math and import torch.nn.functional as F). If this reference string is ever executed via exec() or eval(), it will raise NameError for unbound symbols.
Ensure that any trace execution context either:
- Injects
mathandtorch.nn.functional(aliased asF) into the execution globals, or - Includes the necessary import statements in the serialized reference string.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/trace/fi_trace_out/gdn_decode_qk4_v8_d128.json` at line 148, The
reference function _gdn_decode_reference uses math.sqrt and F.softplus but the
serialized source string has no imports, causing NameError when exec/eval runs;
fix by either injecting math and torch.nn.functional as F into the exec/eval
globals where _gdn_decode_reference is executed (ensure names "math" and "F" are
present) or prepend/import lines ("import math" and "import torch.nn.functional
as F") to the serialized reference string so _gdn_decode_reference has the
required symbols at runtime.
| "final_state": { | ||
| "shape": [ | ||
| "pool_size", | ||
| "num_v_heads", | ||
| "head_size", | ||
| "head_size" | ||
| ], | ||
| "dtype": "float32", | ||
| "description": "Updated recurrent state pool in k-last layout [pool_size, H, V, K]. Unchanged if disable_state_update=True." | ||
| } |
There was a problem hiding this comment.
Documentation references undefined parameter disable_state_update.
Line 167 states "Unchanged if disable_state_update=True" but disable_state_update is not defined in the inputs section. Either add this parameter to inputs if it's required, or remove the reference from the description.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/trace/fi_trace_out/gdn_mtp_qk4_v8_d128.json` around lines 159 - 168,
The doc string for "final_state" references an undefined parameter
disable_state_update; either add a boolean input named disable_state_update to
the inputs section (e.g., description: "If true, recurrent state updates are
disabled and final_state remains unchanged") or remove the mention "Unchanged if
disable_state_update=True" from the "final_state" description; update the
"final_state" description or inputs accordingly so the documentation no longer
refers to an undefined symbol.
| "description": "Updated recurrent state pool in k-last layout [pool_size, H, V, K]. Unchanged if disable_state_update=True." | ||
| } | ||
| }, | ||
| "reference": "@torch.no_grad()\ndef _gdn_mtp_reference(\n q, k, v, initial_state, initial_state_indices, A_log, a, dt_bias, b, scale,\n intermediate_states_buffer=None,\n):\n \"\"\"\n Gated Delta Net MTP (Multi-Token Prediction) reference implementation.\n\n State layout: [pool_size, H, V, K] (k-last, K dimension at the end)\n\n Gate computation:\n g = exp(-exp(A_log) * softplus(a + dt_bias))\n beta = sigmoid(b)\n\n For each token t in sequence:\n state_new = g_t * state_old + k_t^T @ (beta_t * v_t + (1-beta_t) * k_t @ state_old) - k_t^T @ (k_t @ state_old)\n output_t = scale * q_t @ state_new\n state_old = state_new # Update for next token\n \"\"\"\n B, T, num_q_heads, head_size = q.shape\n _, _, num_k_heads, _ = k.shape\n _, _, num_v_heads, _ = v.shape\n device = q.device\n\n if scale is None or scale == 0.0:\n scale = 1.0 / math.sqrt(head_size)\n\n x = a.float() + dt_bias.float() # [B, T, HV]\n g = torch.exp(-torch.exp(A_log.float()) * F.softplus(x)) # [B, T, HV]\n beta = torch.sigmoid(b.float()) # [B, T, HV]\n\n q_exp = q.repeat_interleave(num_v_heads // num_q_heads, dim=2) # [B, T, HV, K]\n k_exp = k.repeat_interleave(num_v_heads // num_k_heads, dim=2) # [B, T, HV, K]\n\n output = torch.zeros(\n (B, T, num_v_heads, head_size), dtype=torch.bfloat16, device=device\n )\n cache_intermediate = intermediate_states_buffer is not None\n\n for b_idx in range(B):\n state_idx = int(initial_state_indices[b_idx].item())\n state_HVK = initial_state[state_idx].clone().float().transpose(-1, -2) # [H,V,K] -> [H,K,V]\n\n for t in range(T):\n q_HK = q_exp[b_idx, t].float() # [HV, K]\n k_HK = k_exp[b_idx, t].float() # [HV, K]\n v_HV = v[b_idx, t].float() # [HV, V]\n g_H = g[b_idx, t] # [HV]\n beta_H = beta[b_idx, t] # [HV]\n\n for h_idx in range(num_v_heads):\n q_h = q_HK[h_idx]\n k_h = k_HK[h_idx]\n v_h = v_HV[h_idx]\n h_state = state_HVK[h_idx]\n g_val = g_H[h_idx]\n beta_val = beta_H[h_idx]\n\n old_state = g_val * h_state\n old_v = k_h @ old_state\n new_v = beta_val * v_h + (1 - beta_val) * old_v\n state_remove = k_h.unsqueeze(1) @ old_v.unsqueeze(0)\n state_update = k_h.unsqueeze(1) @ new_v.unsqueeze(0)\n h_state = old_state - state_remove + state_update\n\n output[b_idx, t, h_idx] = (scale * (q_h @ h_state)).to(torch.bfloat16)\n state_HVK[h_idx] = h_state\n\n if cache_intermediate:\n intermediate_states_buffer[state_idx, t] = state_HVK.transpose(-1, -2) # [H,K,V] -> [H,V,K]\n\n final_state = initial_state.clone()\n return output, final_state\n" |
There was a problem hiding this comment.
Reference implementation does not return the updated state.
The reference function computes state updates in state_HVK for each batch element, but at the end returns initial_state.clone() instead of the accumulated updated state:
final_state = initial_state.clone()
return output, final_stateThis means final_state will always equal the input initial_state, discarding all computed state updates. The correct behavior should write state_HVK.transpose(-1, -2) back to final_state[state_idx] after processing each batch.
🐛 Proposed fix
+ final_state = initial_state.clone()
for b_idx in range(B):
state_idx = int(initial_state_indices[b_idx].item())
state_HVK = initial_state[state_idx].clone().float().transpose(-1, -2) # [H,V,K] -> [H,K,V]
for t in range(T):
# ... state update logic ...
if cache_intermediate:
intermediate_states_buffer[state_idx, t] = state_HVK.transpose(-1, -2) # [H,K,V] -> [H,V,K]
- final_state = initial_state.clone()
+ # Write back updated state for this batch element
+ final_state[state_idx] = state_HVK.transpose(-1, -2) # [H,K,V] -> [H,V,K]
+
return output, final_state🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/trace/fi_trace_out/gdn_mtp_qk4_v8_d128.json` at line 170, The reference
function _gdn_mtp_reference updates per-batch states in state_HVK but then
returns final_state = initial_state.clone(), discarding updates; fix by creating
final_state = initial_state.clone() before the batch loop and after processing
each batch element (using state_idx = int(initial_state_indices[b_idx].item()))
write the updated state back with final_state[state_idx] =
state_HVK.transpose(-1, -2) (matching the stored [H,V,K] layout); ensure types
remain consistent (match .float()/.to dtype as needed) and then return output,
final_state.
| "description": "The 2-based log-sum-exp of attention logits." | ||
| } | ||
| }, | ||
| "reference": "@torch.no_grad()\ndef _mla_paged_decode_reference(\n q_nope, q_pe, ckv_cache, kpe_cache, kv_indptr, kv_indices, sm_scale\n):\n batch_size, num_qo_heads, head_dim_ckv = q_nope.shape\n len_indptr = kv_indptr.shape[0]\n\n Kc_all = ckv_cache.squeeze(1).to(torch.float32) # [num_pages, head_dim_ckv]\n Kp_all = kpe_cache.squeeze(1).to(torch.float32) # [num_pages, head_dim_kpe]\n\n output = torch.zeros(\n (batch_size, num_qo_heads, head_dim_ckv), dtype=torch.bfloat16, device=q_nope.device\n )\n lse = torch.full(\n (batch_size, num_qo_heads), -float(\"inf\"), dtype=torch.float32, device=q_nope.device\n )\n\n for b in range(batch_size):\n page_beg = int(kv_indptr[b].item())\n page_end = int(kv_indptr[b + 1].item())\n if page_beg >= page_end:\n output[b].zero_()\n continue\n tok_idx = kv_indices[page_beg:page_end].to(torch.long)\n Kc = Kc_all[tok_idx] # [L, head_dim_ckv]\n Kp = Kp_all[tok_idx] # [L, head_dim_kpe]\n qn = q_nope[b].to(torch.float32) # [num_qo_heads, head_dim_ckv]\n qp = q_pe[b].to(torch.float32) # [num_qo_heads, head_dim_kpe]\n logits = ((qn @ Kc.T) + (qp @ Kp.T)) * sm_scale # [num_qo_heads, L]\n lse[b] = torch.logsumexp(logits, dim=-1) / math.log(2.0)\n output[b] = (torch.softmax(logits, dim=-1) @ Kc).to(torch.bfloat16)\n\n return output, lse\n" |
There was a problem hiding this comment.
🧩 Analysis chain
🌐 Web query:
For PyTorch 2.x: if qnis shape[H, D]andKcis shape[L, P, D], what are the rules for qn @ Kc.T? Also, what does .T do for tensors with more than 2 dimensions?
💡 Result:
For PyTorch 2.x, the @ operator uses torch.matmul rules for matrix multiplication. For qn (shape [H, D]) @ Kc.T where Kc is [L, P, D]: - tensor.T (or .t) on a tensor with more than 2 dimensions reverses all dimensions: for Kc [L, P, D] becomes [D, P, L]. This behavior is deprecated for >2D and will error in future releases; use .mT to transpose only the last two dims or .transpose(-2, -1). - However, matmul treats the 2D qn [H, D] and 3D Kc.T [D, P, L] as batched: prepends 1 to qn making it [...=1, H, D], broadcasts batch dims [1] with [D] (invalid, sizes differ), so raises RuntimeError: "matmul: Input operand 0 does not have enough dimensions" or shape mismatch. The inner matrix dims D==D match, but batch dims don't broadcast. To compute batched attention scores (intended [H, L*P] or reshaped), use qn @ Kc.transpose(-2, -1) which is [H, D] @ [L, P, D] -> [L, P, D]^T=[L, D, P], batched matmul gives [H, L, P] after broadcasting the missing batch dim on qn. .mT on Kc [L,P,D] gives [L, D, P] (transposes last two), then qn @ Kc.mT gives [H, L, P].
Citations:
- 1: https://docs.pytorch.org/docs/stable/generated/torch.matmul.html
- 2: https://dev.to/hyperkai/matrix-and-element-wise-multiplication-in-pytorch-25le
- 3: https://runebook.dev/en/docs/pytorch/generated/torch.mm
- 4: https://docs.pytorch.org/docs/1.10/generated/torch.matmul.html
- 5: https://pytorch.org/docs/stable/tensors.html
- 6: https://docs.pytorch.org/docs/1.10/generated/torch.t.html
- 7: https://www.codegenes.net/blog/pytorch-t/
- 8: https://docs.pytorch.org/docs/2.0/tensors.html
- 9: [ux] Proposal to have t() === transpose(-1, -2), since batches are very frequent pytorch/pytorch#51280
🏁 Script executed:
# First, let's look at the JSON file structure and find the input shapes
fd -type f -name "mla_paged_decode_h16_ckv512_kpe64_ps64.json" | head -1 | xargs -I {} sh -c '
# Get file size to decide how to read it
wc -l {} | awk "{print \$1}"
'Repository: flashinfer-ai/flashinfer
Length of output: 238
🏁 Script executed:
# Find the JSON file using git ls-files
git ls-files | grep "mla_paged_decode_h16_ckv512_kpe64_ps64.json"Repository: flashinfer-ai/flashinfer
Length of output: 136
🏁 Script executed:
# Get the line count and examine the JSON structure
wc -l tests/trace/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps64.jsonRepository: flashinfer-ai/flashinfer
Length of output: 140
🏁 Script executed:
# Read the entire JSON file to understand input shapes
cat -n tests/trace/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps64.jsonRepository: flashinfer-ai/flashinfer
Length of output: 5318
Reference function squeeze(1) is incorrect for paged layout.
On Line 123, squeeze(1) cannot remove the page_size dimension when it equals 64. The inputs ckv_cache and kpe_cache have shape [num_pages, page_size, head_dim_*], so squeeze(1) leaves them 3D. The subsequent indexing Kc_all[tok_idx] produces [L, 64, head_dim_ckv] instead of the intended 2D [L, head_dim_ckv], causing the matmul qn @ Kc.T to fail.
Replace squeeze(1) with reshape(num_pages * page_size, head_dim_ckv) to flatten pages into a token-major 2D tensor, or adjust the indexing logic if indices are page-level rather than token-level.
Additionally, specify explicit dtypes for kv_indptr and kv_indices (currently "unknown") to reduce schema ambiguity.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/trace/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps64.json` at line
123, In _mla_paged_decode_reference the use of ckv_cache.squeeze(1) and
kpe_cache.squeeze(1) is wrong for paged tensors (shape [num_pages, page_size,
head_dim_*]) and leaves a 3D tensor so Kc_all[tok_idx] yields [L, page_size,
head_dim]; replace squeeze(1) with a flattening reshape (e.g. reshape(num_pages
* page_size, head_dim_ckv) / reshape(..., head_dim_kpe) or view(-1, head_dim_*))
so Kc_all and Kp_all become 2D token-major tensors before indexing, and ensure
kv_indptr and kv_indices are cast to explicit integer dtype (torch.long/int64)
before use to remove schema ambiguity.
There was a problem hiding this comment.
Actionable comments posted: 12
♻️ Duplicate comments (1)
flashinfer/trace/templates/moe.py (1)
25-27:⚠️ Potential issue | 🟠 MajorThe MoE reference is still hard-wired to one hidden/intermediate shape.
H=7168andI=2048leak into the scale expansion, output allocation, andG1split, so any traced MoE with different shapes will either reshape incorrectly or slice the expert output wrong.Suggested fix
-H = 7168 -I = 2048 BLOCK = 128 @@ - T = hidden_states.shape[0] + T, H = hidden_states.shape + I = gemm2_weights.shape[2] + gemm1_out = gemm1_weights.shape[1] + if gemm1_out != 2 * I: + raise ValueError( + f"Invalid gemm1_out_size={gemm1_out}, expected 2 * intermediate_size={2 * I}" + )Also applies to: 53-57, 72-88
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/templates/moe.py` around lines 25 - 27, The code hard-codes H, I, and BLOCK which leak into scale expansion, output allocation, and the G1 split causing incorrect reshapes/slices for other MoE shapes; replace these constants with dynamic values derived from the model/tensor shapes (e.g., infer hidden_size and intermediate_size from the input/weight tensors or pass them as parameters), update any uses in scale expansion, output allocation, and the G1 split logic (references: H, I, BLOCK, and the G1 split/expert output slicing code in moe.py) to compute sizes at runtime and use those computed sizes for reshape, split and slice operations so traced models with different H/I/BLOCK work correctly.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/trace/templates/attention.py`:
- Line 357: The variables len_indptr and page_size are assigned but never used;
to fix the ruff F841 error, either remove the unused assignments or rename them
to _len_indptr and _page_size (or prefix with a single underscore) where they
are set (e.g., the len_indptr = kv_indptr.shape[0] assignment and the other
page_size assignment in the same module), and apply the same change to the
second occurrence that also triggers the warning so the linter no longer reports
unused locals.
- Around line 144-157: The prefill path incorrectly treats kv_indices as token
indices by indexing k_flat/v_flat directly with page_ids (kv_indices) and
computing num_kv_tokens from page_ids.shape[0]; instead expand the selected
pages to token-level rows first: use kv_indices[kv_start:kv_end] to select page
rows from the full per-page KV buffer (not the flattened token axis), then
concatenate or expand those page rows into token-level k_b and v_b and compute
num_kv_tokens from the resulting expanded KV token rows; update usages around
k_flat, v_flat, page_ids, k_b, v_b and num_kv_tokens so page->token expansion
happens before indexing the flattened token axis.
- Around line 42-53: kv_indices currently represents page IDs, so indexing
k_flat/v_flat directly with kv_indices selects wrong rows when page_size > 1;
instead, first gather the pages from the original k_cache and v_cache using
kv_indices (use kv_indices to index the page dimension of k_cache/v_cache to
produce per-token page slices), then flatten or reshape the gathered per-page
tensors into token-level rows and proceed (so create k_b/v_b by gathering pages
via kv_indices from k_cache/v_cache, then reshape to [T, num_kv_heads, head_dim]
before using them as k_b and v_b); update all uses of k_flat/v_flat and
token_ids accordingly and ensure kv_indptr logic still slices the kv_indices by
token count, not flattened token offsets.
- Around line 359-383: The reference implementations _mla_paged_decode_reference
and _mla_paged_prefill_reference assume page_size==1 by calling
ckv_cache.squeeze(1) and kpe_cache.squeeze(1); instead update these functions to
flatten the page and token dimensions so arbitrary page_size works (e.g.,
replace squeeze(1) with a reshape/flatten to (-1, head_dim_ckv) for Kc_all and
(-1, head_dim_kpe) for Kp_all or use flatten(0,1)), ensuring subsequent indexing
via kv_indices still selects the correct token rows; alternatively, if you
prefer to keep the current code, enforce page_size==1 in the TraceTemplate
schema, but do not leave squeeze(1) as-is.
In `@flashinfer/trace/templates/gdn.py`:
- Around line 153-157: The Tensor schema for the "output" entries in
flashinfer/trace/templates/gdn.py currently uses dtype_from="q" but the
implementation always casts outputs to torch.bfloat16; update the schema to
reflect the real emitted dtype by replacing dtype_from="q" with dtype="bfloat16"
for the "output" Tensor declarations (the entries named "output" in the
template), or alternatively model the runtime control explicitly if outputs can
vary; make the same change for the other "output" Tensor occurrences mentioned
so the trace metadata matches the torch.bfloat16 casts in the code.
- Around line 382-415: The function mutates per-example head states in state_HVK
but returns final_state built from the unchanged initial_state; fix by writing
the updated state_HVK back into final_state before returning. After the outer
loops (or just before return), clone initial_state into final_state as done now
and then for each b_idx set final_state[state_idx] = state_HVK.transpose(-1, -2)
(or assign the corresponding typed/ device-matched tensor) so the updated
[H,V,K] state for the sample index (state_idx derived from
initial_state_indices[b_idx]) is committed; ensure dtype/device matches
initial_state when assigning.
- Around line 205-206: gdn_prefill_trace currently expands q and k with
repeat_interleave using num_v_heads // num_q_heads and num_v_heads //
num_k_heads but does not validate the required head-ratio constraints; add
explicit checks in gdn_prefill_trace to assert num_v_heads >= num_q_heads and
num_v_heads % num_q_heads == 0 and also assert num_k_heads == num_q_heads (or
otherwise enforce the same constraints used by decode/MTP), and apply the same
fixes to the other expansion site (the block around the q/k/v repeat_interleave
at the later occurrence). Ensure the assertions raise clear errors mentioning
num_v_heads, num_q_heads, and num_k_heads so invalid head layouts are rejected
before repeat_interleave is called.
In `@flashinfer/trace/templates/gemm.py`:
- Around line 57-78: The template misdeclares packed uint8 inputs as logical FP4
shapes causing fi_trace to infer wrong K/N; update the public trace signatures
(or add a pre-trace extractor) so the runtime sees packed dimensions: treat A
and B as [M, K_packed] and [K_packed, N_packed] (or expose an extractor that
maps packed -> logical by doubling the last axis) and propagate corrected
logical axes into mm_fp4_trace before calling _mm_fp4_reference/_unpack_fp4;
apply the same change to the other occurrence around the second block (the
191-200 region) and ensure a_descale/b_descale shape metadata matches the
packed-block layout.
- Around line 22-35: The reference GEMM helpers currently transpose B (using .T)
even though B is modeled as the physical [K, N] tensor; update _mm_reference and
_mm_fp8_reference (and the other similar reference helpers in the file) to
multiply A by B directly (remove the .T and B_fp32.T), keeping the same dtype
conversions and return types (e.g., _mm_fp8_reference should still dequantize to
float32, matmul, then cast to bfloat16), and update any docstrings/comments that
incorrectly describe B as needing transpose.
In `@flashinfer/trace/templates/moe.py`:
- Around line 577-598: The direct attribute assignment
trtllm_fp8_block_scale_moe_trace_dispatch.templates causes mypy attr-defined
errors; replace that assignment with a setattr call to attach the templates list
at runtime (e.g., use setattr(trtllm_fp8_block_scale_moe_trace_dispatch,
"templates", list(_MOE_TRACE_BY_ROUTING_TYPE.values()))), keeping the same value
(list of _MOE_TRACE_BY_ROUTING_TYPE.values()) and preserving behavior for
_attach_fi_trace registration.
In `@tests/trace/test_fi_trace_template_consistency.py`:
- Around line 193-199: The test E2E generator currently assigns 0 for int32
scalars in the loop over template.inputs which can create impossible values
(e.g., block_size=0); update the assignment in the loop that inspects
isinstance(descriptor, Scalar) and uses _resolved_param(json_key, descriptor) so
that int32 defaults are positive (e.g., 1 or another small positive) and
preferably support per-parameter overrides for constrained scalars before
populating kwargs; ensure any change keeps optional descriptors skipped and
preserves the dtype branch for non-int32 floats, so assert_fi_trace_complete()
validates realistic traces.
---
Duplicate comments:
In `@flashinfer/trace/templates/moe.py`:
- Around line 25-27: The code hard-codes H, I, and BLOCK which leak into scale
expansion, output allocation, and the G1 split causing incorrect reshapes/slices
for other MoE shapes; replace these constants with dynamic values derived from
the model/tensor shapes (e.g., infer hidden_size and intermediate_size from the
input/weight tensors or pass them as parameters), update any uses in scale
expansion, output allocation, and the G1 split logic (references: H, I, BLOCK,
and the G1 split/expert output slicing code in moe.py) to compute sizes at
runtime and use those computed sizes for reshape, split and slice operations so
traced models with different H/I/BLOCK work correctly.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 8f771e55-82cf-425e-9083-fba1ef3390e8
📥 Commits
Reviewing files that changed from the base of the PR and between c5296b71eea86be83a63525183c5c31db0cf600a and f7e2129265f21b39f8b8f460ab9cb59648c88322.
📒 Files selected for processing (7)
.claude/skills/add-cuda-kernel/SKILL.mdflashinfer/api_logging.pyflashinfer/trace/templates/attention.pyflashinfer/trace/templates/gdn.pyflashinfer/trace/templates/gemm.pyflashinfer/trace/templates/moe.pytests/trace/test_fi_trace_template_consistency.py
| k_flat = k_cache.reshape(-1, num_kv_heads, head_dim).to(torch.float32) | ||
| v_flat = v_cache.reshape(-1, num_kv_heads, head_dim).to(torch.float32) | ||
|
|
||
| for b in range(batch_size): | ||
| page_start = int(kv_indptr[b].item()) | ||
| page_end = int(kv_indptr[b + 1].item()) | ||
| if page_start >= page_end: | ||
| output[b].zero_() | ||
| continue | ||
| token_ids = kv_indices[page_start:page_end].to(torch.long) | ||
| k_b = k_flat[token_ids] # [T, num_kv_heads, head_dim] | ||
| v_b = v_flat[token_ids] |
There was a problem hiding this comment.
kv_indices are page IDs, but decode reference indexes flattened tokens.
This reference is incorrect when page_size > 1: indexing k_flat/v_flat with page IDs selects wrong rows. Use page gather first, then flatten within selected pages.
Proposed fix
- k_flat = k_cache.reshape(-1, num_kv_heads, head_dim).to(torch.float32)
- v_flat = v_cache.reshape(-1, num_kv_heads, head_dim).to(torch.float32)
@@
- token_ids = kv_indices[page_start:page_end].to(torch.long)
- k_b = k_flat[token_ids] # [T, num_kv_heads, head_dim]
- v_b = v_flat[token_ids]
+ page_ids = kv_indices[page_start:page_end].to(torch.long)
+ k_b = k_cache[page_ids].reshape(-1, num_kv_heads, head_dim).to(torch.float32)
+ v_b = v_cache[page_ids].reshape(-1, num_kv_heads, head_dim).to(torch.float32)🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/trace/templates/attention.py` around lines 42 - 53, kv_indices
currently represents page IDs, so indexing k_flat/v_flat directly with
kv_indices selects wrong rows when page_size > 1; instead, first gather the
pages from the original k_cache and v_cache using kv_indices (use kv_indices to
index the page dimension of k_cache/v_cache to produce per-token page slices),
then flatten or reshape the gathered per-page tensors into token-level rows and
proceed (so create k_b/v_b by gathering pages via kv_indices from
k_cache/v_cache, then reshape to [T, num_kv_heads, head_dim] before using them
as k_b and v_b); update all uses of k_flat/v_flat and token_ids accordingly and
ensure kv_indptr logic still slices the kv_indices by token count, not flattened
token offsets.
| k_flat = k_cache.reshape(-1, num_kv_heads, head_dim).to(torch.float32) | ||
| v_flat = v_cache.reshape(-1, num_kv_heads, head_dim).to(torch.float32) | ||
|
|
||
| for b in range(len_indptr - 1): | ||
| q_start = int(qo_indptr[b].item()) | ||
| q_end = int(qo_indptr[b + 1].item()) | ||
| kv_start = int(kv_indptr[b].item()) | ||
| kv_end = int(kv_indptr[b + 1].item()) | ||
| if q_start >= q_end or kv_start >= kv_end: | ||
| continue | ||
| page_ids = kv_indices[kv_start:kv_end].to(torch.long) | ||
| k_b = k_flat[page_ids] | ||
| v_b = v_flat[page_ids] | ||
| num_kv_tokens = page_ids.shape[0] |
There was a problem hiding this comment.
Prefill reference has the same page-id/token-id mismatch.
kv_indices are documented as page IDs, but this path indexes a flattened token axis directly. Expand selected pages first, then derive num_kv_tokens from expanded KV rows.
Proposed fix
- k_flat = k_cache.reshape(-1, num_kv_heads, head_dim).to(torch.float32)
- v_flat = v_cache.reshape(-1, num_kv_heads, head_dim).to(torch.float32)
@@
- page_ids = kv_indices[kv_start:kv_end].to(torch.long)
- k_b = k_flat[page_ids]
- v_b = v_flat[page_ids]
- num_kv_tokens = page_ids.shape[0]
+ page_ids = kv_indices[kv_start:kv_end].to(torch.long)
+ k_b = k_cache[page_ids].reshape(-1, num_kv_heads, head_dim).to(torch.float32)
+ v_b = v_cache[page_ids].reshape(-1, num_kv_heads, head_dim).to(torch.float32)
+ num_kv_tokens = k_b.shape[0]🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/trace/templates/attention.py` around lines 144 - 157, The prefill
path incorrectly treats kv_indices as token indices by indexing k_flat/v_flat
directly with page_ids (kv_indices) and computing num_kv_tokens from
page_ids.shape[0]; instead expand the selected pages to token-level rows first:
use kv_indices[kv_start:kv_end] to select page rows from the full per-page KV
buffer (not the flattened token axis), then concatenate or expand those page
rows into token-level k_b and v_b and compute num_kv_tokens from the resulting
expanded KV token rows; update usages around k_flat, v_flat, page_ids, k_b, v_b
and num_kv_tokens so page->token expansion happens before indexing the flattened
token axis.
| Kc_all = ckv_cache.squeeze(1).to(torch.float32) # [num_pages, head_dim_ckv] | ||
| Kp_all = kpe_cache.squeeze(1).to(torch.float32) # [num_pages, head_dim_kpe] | ||
|
|
||
| output = torch.zeros( | ||
| (batch_size, num_qo_heads, head_dim_ckv), dtype=torch.bfloat16, device=q_nope.device | ||
| ) | ||
| lse = torch.full( | ||
| (batch_size, num_qo_heads), -float("inf"), dtype=torch.float32, device=q_nope.device | ||
| ) | ||
|
|
||
| for b in range(batch_size): | ||
| page_beg = int(kv_indptr[b].item()) | ||
| page_end = int(kv_indptr[b + 1].item()) | ||
| if page_beg >= page_end: | ||
| output[b].zero_() | ||
| continue | ||
| tok_idx = kv_indices[page_beg:page_end].to(torch.long) | ||
| Kc = Kc_all[tok_idx] # [L, head_dim_ckv] | ||
| Kp = Kp_all[tok_idx] # [L, head_dim_kpe] | ||
| qn = q_nope[b].to(torch.float32) # [num_qo_heads, head_dim_ckv] | ||
| qp = q_pe[b].to(torch.float32) # [num_qo_heads, head_dim_kpe] | ||
| logits = ((qn @ Kc.T) + (qp @ Kp.T)) * sm_scale # [num_qo_heads, L] | ||
| lse[b] = torch.logsumexp(logits, dim=-1) / math.log(2.0) | ||
| output[b] = (torch.softmax(logits, dim=-1) @ Kc).to(torch.bfloat16) | ||
|
|
There was a problem hiding this comment.
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
set -euo pipefail
# 1) Locate all MLA paged wrappers/templates and check whether page_size is constrained to 1.
rg -nP --type=py '\bmla_paged\b|\bpage_size\b|\bkv_indices\b|\bckv_cache\b|\bkpe_cache\b' -C2
# 2) Find explicit invariants/assertions for page_size==1 (if any).
rg -nP --type=py 'page_size\s*==\s*1|assert\s+.*page_size|squeeze\(\s*1\s*\)' -C2
# 3) Inspect fi_trace tests/fixtures for MLA coverage and page_size values used.
rg -nP --type=py --type=json 'mla_paged|page_size' tests -C2Repository: flashinfer-ai/flashinfer
Length of output: 50381
🏁 Script executed:
# Examine the specific file and lines mentioned in the review
head -n 520 flashinfer/trace/templates/attention.py | tail -n 180Repository: flashinfer-ai/flashinfer
Length of output: 6896
🏁 Script executed:
# Check for schema constraints on page_size for MLA
rg -nP 'page_size.*==.*1|page_size.*constraint|mla.*page_size|page_size.*schema' flashinfer -A3 -B3Repository: flashinfer-ai/flashinfer
Length of output: 6949
🏁 Script executed:
# Look for comments or documentation about MLA page_size limitations
rg -nP 'MLA|mla_paged|mla' flashinfer/trace/templates/ -A5 -B5 | head -n 200Repository: flashinfer-ai/flashinfer
Length of output: 6947
🏁 Script executed:
# Check how the reference functions are used
rg -n "reference.*mla_paged|_mla_paged.*reference" flashinfer/trace/templates/attention.py -A2 -B2Repository: flashinfer-ai/flashinfer
Length of output: 620
🏁 Script executed:
# Look at test_fi_trace.py to see if it uses the reference implementations
grep -n "reference\|fi_trace.*mla" tests/test_fi_trace.py | head -30Repository: flashinfer-ai/flashinfer
Length of output: 50
🏁 Script executed:
# Check if TraceTemplate reference functions are ever called
rg -n "\.reference\(\|run_reference" flashinfer/trace -A3 -B3Repository: flashinfer-ai/flashinfer
Length of output: 50
🏁 Script executed:
# Check how TraceTemplate is used and if reference is called
rg -n "class TraceTemplate|def.*reference" flashinfer/trace -A5 | head -100Repository: flashinfer-ai/flashinfer
Length of output: 6857
🏁 Script executed:
# Look at how fi_trace is implemented
find flashinfer -name "*.py" -exec grep -l "fi_trace" {} \; | head -10Repository: flashinfer-ai/flashinfer
Length of output: 270
🏁 Script executed:
# Check if there's validation code that calls the reference implementations
rg -n "reference" flashinfer/trace/core.py -A3 -B3Repository: flashinfer-ai/flashinfer
Length of output: 132
🏁 Script executed:
# Check the fi_trace.py file
cat -n flashinfer/fi_trace.py | head -100Repository: flashinfer-ai/flashinfer
Length of output: 4308
🏁 Script executed:
# Check the template.py file for TraceTemplate
grep -n "class TraceTemplate\|reference" flashinfer/trace/template.py | head -50Repository: flashinfer-ai/flashinfer
Length of output: 476
🏁 Script executed:
# Look at how reference is used in tests
grep -n "reference\|\.run\|\.plan" tests/test_fi_trace.py -A2 -B2 | head -80Repository: flashinfer-ai/flashinfer
Length of output: 2127
🏁 Script executed:
# Check if reference functions are ever actually invoked/called
rg -n "template\.reference\|_mla_paged_decode_reference\|_mla_paged_prefill_reference" flashinfer -A3Repository: flashinfer-ai/flashinfer
Length of output: 50
🏁 Script executed:
# Look at the actual CUDA/kernel implementations to see if they have page_size constraints
find flashinfer/mla -name "*.py" -o -name "*.cu" | head -10Repository: flashinfer-ai/flashinfer
Length of output: 316
🏁 Script executed:
# Check the cute_dsl MLA implementation for page_size constraints
grep -n "page_size.*1\|squeeze" flashinfer/mla/cute_dsl/mla_decode_fp16.py | head -30Repository: flashinfer-ai/flashinfer
Length of output: 318
🏁 Script executed:
# Check what the actual MLA wrapper implementation looks like
grep -n "class.*MLA\|def.*run\|page_size" flashinfer/mla/_core.py | head -50Repository: flashinfer-ai/flashinfer
Length of output: 1405
🏁 Script executed:
# Look at tests to see if they verify correctness with page_size != 1
grep -n "page_size.*64\|page_size.*16\|test_mla" tests/test_fi_trace.py -B5 -A10Repository: flashinfer-ai/flashinfer
Length of output: 4860
🏁 Script executed:
# Check if there are actual runtime tests for MLA with different page sizes
rg -n "BatchMLAPagedAttentionWrapper\|mla_paged" tests/ -A10 | head -100Repository: flashinfer-ai/flashinfer
Length of output: 50
🏁 Script executed:
# Check if reference implementations are actually called during testing or validation
rg -n "reference\(" flashinfer -A5 | head -40Repository: flashinfer-ai/flashinfer
Length of output: 3161
🏁 Script executed:
# Look at the actual wrapper implementation to see how it handles different page_size
sed -n '442,550p' flashinfer/mla/_core.pyRepository: flashinfer-ai/flashinfer
Length of output: 4684
🏁 Script executed:
# Check if squeeze is used anywhere in actual MLA implementations
rg -n "squeeze" flashinfer/mla/_core.pyRepository: flashinfer-ai/flashinfer
Length of output: 213
🏁 Script executed:
# Check if the reference implementations are actually ever invoked during normal operation
rg -n "\.reference\|call.*reference" flashinfer -A5 -B5Repository: flashinfer-ai/flashinfer
Length of output: 50
🏁 Script executed:
# Verify that actual kernels handle arbitrary page_size
grep -n "page_size\|ckv_cache\|kpe_cache" flashinfer/mla/_core.py | grep -A3 -B3 "cutlass\|cuda"Repository: flashinfer-ai/flashinfer
Length of output: 708
🏁 Script executed:
# Check what the cute_dsl actually does with page_size
head -n 100 flashinfer/mla/cute_dsl/mla_decode_fp16.py | grep -n "page_size"Repository: flashinfer-ai/flashinfer
Length of output: 50
🏁 Script executed:
# Check if reference implementations are embedded in JSON or used elsewhere
grep -rn "reference" flashinfer/trace/template.py | head -20Repository: flashinfer-ai/flashinfer
Length of output: 451
🏁 Script executed:
# Look for how flashinfer-bench uses the reference
find . -name "*bench*" -type f | grep -E "\.py$|\.json$" | head -10Repository: flashinfer-ai/flashinfer
Length of output: 515
🏁 Script executed:
# Verify that tests with page_size=64 actually pass/work
grep -A 20 "def test_mla_paged_fi_trace" tests/test_fi_trace.py | head -30Repository: flashinfer-ai/flashinfer
Length of output: 917
Reference implementations for MLA paged templates assume single-token pages but schema permits arbitrary page_size.
The _mla_paged_decode_reference() and _mla_paged_prefill_reference() functions use squeeze(1) to remove the page dimension, which is only valid when page_size == 1. However, the TraceTemplate schema defines page_size as a generic Const() without constraining it to 1. These reference implementations are embedded in the JSON trace output and used by flashinfer-bench for validation.
Tests already use page_size=64 (test_fi_trace.py:253, test_fi_trace.py:419), which would cause the reference implementations to fail or produce incorrect results.
Either enforce page_size == 1 as a schema constraint, or rewrite the reference implementations to properly flatten pages and tokens (e.g., Kc_all = ckv_cache.reshape(-1, head_dim_ckv) instead of squeeze) before the indexing operations.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/trace/templates/attention.py` around lines 359 - 383, The
reference implementations _mla_paged_decode_reference and
_mla_paged_prefill_reference assume page_size==1 by calling ckv_cache.squeeze(1)
and kpe_cache.squeeze(1); instead update these functions to flatten the page and
token dimensions so arbitrary page_size works (e.g., replace squeeze(1) with a
reshape/flatten to (-1, head_dim_ckv) for Kc_all and (-1, head_dim_kpe) for
Kp_all or use flatten(0,1)), ensuring subsequent indexing via kv_indices still
selects the correct token rows; alternatively, if you prefer to keep the current
code, enforce page_size==1 in the TraceTemplate schema, but do not leave
squeeze(1) as-is.
| "output": Tensor( | ||
| ["batch_size", "seq_len", "num_v_heads", "head_size"], | ||
| dtype_from="q", | ||
| description="Attention output. Shape follows num_v_heads in GVA mode.", | ||
| ), |
There was a problem hiding this comment.
The templates report output as dtype_from="q", but the references always emit bfloat16.
Lines 91, 208-210, and 377-379 cast the output tensors to torch.bfloat16, so these schemas become wrong as soon as q is not already bfloat16. The trace metadata should either fix the dtype to bfloat16 or model the real output-dtype control explicitly.
Suggested fix
- dtype_from="q",
+ dtype="bfloat16",
@@
- dtype_from="q",
+ dtype="bfloat16",
@@
- dtype_from="q",
+ dtype="bfloat16",Also applies to: 321-325, 486-490
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/trace/templates/gdn.py` around lines 153 - 157, The Tensor schema
for the "output" entries in flashinfer/trace/templates/gdn.py currently uses
dtype_from="q" but the implementation always casts outputs to torch.bfloat16;
update the schema to reflect the real emitted dtype by replacing dtype_from="q"
with dtype="bfloat16" for the "output" Tensor declarations (the entries named
"output" in the template), or alternatively model the runtime control explicitly
if outputs can vary; make the same change for the other "output" Tensor
occurrences mentioned so the trace metadata matches the torch.bfloat16 casts in
the code.
| def _mm_reference(A, B): | ||
| return torch.matmul(A, B.T) | ||
|
|
||
|
|
||
| def _mm_fp8_reference(A, B): | ||
| """Dequantize FP8 block-scale inputs and compute C = A @ B.T. | ||
|
|
||
| B is in TRT-LLM block layout [K//block_size, N, block_size] and is | ||
| reshaped to [K, N] before the matmul. | ||
| """ | ||
| K_div_bs, N, block_size = B.shape | ||
| B_fp32 = B.reshape(K_div_bs * block_size, N).to(torch.float32) | ||
| A_fp32 = A.to(torch.float32) | ||
| return torch.matmul(A_fp32, B_fp32.T).to(torch.bfloat16) |
There was a problem hiding this comment.
The GEMM references are transposing B after already modeling it as physical [K, N].
For mm_bf16 this is shape-invalid as soon as N != K, and the quantized helpers have the same problem after dequantization. Given these templates describe b as the physical [K, N] tensor, the reference path should multiply by B directly.
Suggested fix
def _mm_reference(A, B):
- return torch.matmul(A, B.T)
+ return torch.matmul(A, B)
@@
- return torch.matmul(A_fp32, B_fp32.T).to(torch.bfloat16)
+ return torch.matmul(A_fp32, B_fp32).to(torch.bfloat16)
@@
- return torch.matmul(A_scaled, B_scaled.T).to(torch.bfloat16)
+ return torch.matmul(A_scaled, B_scaled).to(torch.bfloat16)
@@
- return torch.matmul(A_scaled, B_scaled.T).to(torch.bfloat16)
+ return torch.matmul(A_scaled, B_scaled).to(torch.bfloat16)Also applies to: 38-55, 57-85
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/trace/templates/gemm.py` around lines 22 - 35, The reference GEMM
helpers currently transpose B (using .T) even though B is modeled as the
physical [K, N] tensor; update _mm_reference and _mm_fp8_reference (and the
other similar reference helpers in the file) to multiply A by B directly (remove
the .T and B_fp32.T), keeping the same dtype conversions and return types (e.g.,
_mm_fp8_reference should still dequantize to float32, matmul, then cast to
bfloat16), and update any docstrings/comments that incorrectly describe B as
needing transpose.
| def _mm_fp4_reference(A, B, a_descale, b_descale, block_size=16): | ||
| """Dequantize FP4 inputs and compute C = A @ B.T. | ||
|
|
||
| A and B are fp4 e2m1fn values packed two-per-byte as uint8. | ||
| a_descale: [M, K//block_size], b_descale: [K, N//block_size]. | ||
| The reference unpacks the nibbles and applies the block scales. | ||
| """ | ||
| def _unpack_fp4(packed, rows, cols): | ||
| # Each byte holds two fp4 nibbles (low nibble = first element). | ||
| lo = (packed & 0x0F).to(torch.float32) | ||
| hi = ((packed >> 4) & 0x0F).to(torch.float32) | ||
| # Interleave low/high nibbles along the last dimension. | ||
| out = torch.stack([lo, hi], dim=-1).reshape(rows, cols) | ||
| return out | ||
|
|
||
| M, K_packed = A.shape | ||
| K = K_packed * 2 | ||
| _, N_packed = B.shape | ||
| N = N_packed * 2 | ||
|
|
||
| A_fp32 = _unpack_fp4(A, M, K) | ||
| B_fp32 = _unpack_fp4(B, K, N) |
There was a problem hiding this comment.
mm_fp4_trace cannot infer the right logical axes from packed inputs.
Lines 72-78 make it clear the runtime tensors are packed uint8 shapes, but the template still declares a and b as [M, K] and [K, N]. fi_trace will therefore report halved or conflicting K/N values for real FP4 calls. This needs packed-dimension axes or a custom extractor before the public API can emit correct runtime traces.
Also applies to: 191-200
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/trace/templates/gemm.py` around lines 57 - 78, The template
misdeclares packed uint8 inputs as logical FP4 shapes causing fi_trace to infer
wrong K/N; update the public trace signatures (or add a pre-trace extractor) so
the runtime sees packed dimensions: treat A and B as [M, K_packed] and
[K_packed, N_packed] (or expose an extractor that maps packed -> logical by
doubling the last axis) and propagate corrected logical axes into mm_fp4_trace
before calling _mm_fp4_reference/_unpack_fp4; apply the same change to the other
occurrence around the second block (the 191-200 region) and ensure
a_descale/b_descale shape metadata matches the packed-block layout.
| def trtllm_fp8_block_scale_moe_trace_dispatch(**kwargs): | ||
| """Return the appropriate TraceTemplate for the given ``routing_method_type``. | ||
|
|
||
| Pass this as ``trace=trtllm_fp8_block_scale_moe_trace_dispatch`` to | ||
| ``@flashinfer_api`` so the correct template is selected at call time:: | ||
|
|
||
| @flashinfer_api(trace=trtllm_fp8_block_scale_moe_trace_dispatch) | ||
| def trtllm_fp8_block_scale_moe(..., routing_method_type: int = 0, ...): | ||
| ... | ||
|
|
||
| Returns ``None`` for ``RoutingMethodType.Unspecified`` (6), which | ||
| suppresses trace generation. | ||
| """ | ||
| routing_method_type = int(kwargs.get("routing_method_type", 0)) | ||
| return _MOE_TRACE_BY_ROUTING_TYPE.get(routing_method_type) | ||
|
|
||
|
|
||
| # Expose all possible templates so _attach_fi_trace can auto-register them | ||
| # in _TRACE_REGISTRY for consistency testing. | ||
| trtllm_fp8_block_scale_moe_trace_dispatch.templates = list( | ||
| _MOE_TRACE_BY_ROUTING_TYPE.values() | ||
| ) |
There was a problem hiding this comment.
🧩 Analysis chain
🏁 Script executed:
# First, let's examine the file to see the current state
cat -n flashinfer/trace/templates/moe.py | sed -n '570,605p'Repository: flashinfer-ai/flashinfer
Length of output: 1461
🏁 Script executed:
# Check if there's a mypy configuration or pre-commit hooks
fd -type f -name ".pre-commit*" -o -name "mypy.ini" -o -name "pyproject.toml" | head -20Repository: flashinfer-ai/flashinfer
Length of output: 238
🏁 Script executed:
# Look for mypy configuration in pyproject.toml or setup.cfg
fd -type f \( -name "pyproject.toml" -o -name "setup.cfg" \) -exec head -50 {} +Repository: flashinfer-ai/flashinfer
Length of output: 238
🏁 Script executed:
find . -maxdepth 2 -type f \( -name "pyproject.toml" -o -name ".pre-commit-config.yaml" -o -name "setup.cfg" \) 2>/dev/nullRepository: flashinfer-ai/flashinfer
Length of output: 182
🏁 Script executed:
# Search for mypy configuration
rg -l "mypy|attr-defined" --type toml --type yaml --type cfg 2>/dev/null | head -10Repository: flashinfer-ai/flashinfer
Length of output: 50
🏁 Script executed:
# Check if there are any comments or errors mentioning attr-defined in the moe.py file
rg "attr-defined|setattr" flashinfer/trace/templates/moe.pyRepository: flashinfer-ai/flashinfer
Length of output: 50
🏁 Script executed:
# Check if setattr is used elsewhere in the codebase for similar purposes
rg "setattr\(" flashinfer/ -A 2 | head -40Repository: flashinfer-ai/flashinfer
Length of output: 50
🏁 Script executed:
cat .pre-commit-config.yamlRepository: flashinfer-ai/flashinfer
Length of output: 1722
🏁 Script executed:
cat pyproject.tomlRepository: flashinfer-ai/flashinfer
Length of output: 3158
🏁 Script executed:
# Check if there are any .pyi stub files for this module
find flashinfer/trace -name "*.pyi" 2>/dev/nullRepository: flashinfer-ai/flashinfer
Length of output: 50
🏁 Script executed:
# Look for Protocol or TypedDict definitions related to function templates
rg "Protocol|TypedDict" flashinfer/trace/ --type python | head -20Repository: flashinfer-ai/flashinfer
Length of output: 98
🏁 Script executed:
# Check if _attach_fi_trace function expects templates attribute
rg "_attach_fi_trace" flashinfer/ -B 5 -A 10 | head -50Repository: flashinfer-ai/flashinfer
Length of output: 3287
🏁 Script executed:
# Check for other similar patterns where attributes are assigned to functions
rg "^\w+\.\w+\s*=\s*" flashinfer/trace/templates/ -A 0 | head -20Repository: flashinfer-ai/flashinfer
Length of output: 161
🏁 Script executed:
# Verify the exact mypy behavior - look at the api_logging.py more carefully
cat -n flashinfer/api_logging.py | sed -n '1,100p'Repository: flashinfer-ai/flashinfer
Length of output: 4356
🏁 Script executed:
# Look for any existing patterns where setattr might be used or recommended
rg "templates\s*=" flashinfer/ --type py -B 2 -A 2Repository: flashinfer-ai/flashinfer
Length of output: 460
🏁 Script executed:
# Check if there are any mypy: ignore comments or type: ignore in moe.py
rg "mypy|type.*ignore" flashinfer/trace/templates/moe.pyRepository: flashinfer-ai/flashinfer
Length of output: 50
🏁 Script executed:
# Look for the function signature and what it returns
cat -n flashinfer/trace/templates/moe.py | sed -n '577,592p'Repository: flashinfer-ai/flashinfer
Length of output: 881
🏁 Script executed:
# Check what TraceTemplate is to understand the type signature better
rg "class TraceTemplate|def TraceTemplate" flashinfer/ --type py -A 5 | head -30Repository: flashinfer-ai/flashinfer
Length of output: 380
🏁 Script executed:
# Verify the fix would work by checking how setattr is commonly used in the codebase
rg "setattr" flashinfer/ --type py -B 1 -A 1 | head -20Repository: flashinfer-ai/flashinfer
Length of output: 356
🏁 Script executed:
# Find where the .templates attribute is accessed
rg "\.templates" flashinfer/ --type py -B 2 -A 2Repository: flashinfer-ai/flashinfer
Length of output: 4375
🏁 Script executed:
# Check if there's any hasattr check
rg "hasattr.*templates" flashinfer/ --type pyRepository: flashinfer-ai/flashinfer
Length of output: 50
🏁 Script executed:
# Look at the complete _attach_fi_trace function to understand how templates is used
rg "_attach_fi_trace" flashinfer/api_logging.py -A 50 | head -80Repository: flashinfer-ai/flashinfer
Length of output: 3908
Use setattr() to avoid mypy's attr-defined error.
The direct assignment at line 596 triggers mypy's strict mode because the function object has no declared templates attribute. Using setattr() preserves identical runtime behavior while satisfying type checking requirements, which is necessary for pre-commit validation.
Suggested fix
-trtllm_fp8_block_scale_moe_trace_dispatch.templates = list(
- _MOE_TRACE_BY_ROUTING_TYPE.values()
-)
+setattr(
+ trtllm_fp8_block_scale_moe_trace_dispatch,
+ "templates",
+ list(_MOE_TRACE_BY_ROUTING_TYPE.values()),
+)📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| def trtllm_fp8_block_scale_moe_trace_dispatch(**kwargs): | |
| """Return the appropriate TraceTemplate for the given ``routing_method_type``. | |
| Pass this as ``trace=trtllm_fp8_block_scale_moe_trace_dispatch`` to | |
| ``@flashinfer_api`` so the correct template is selected at call time:: | |
| @flashinfer_api(trace=trtllm_fp8_block_scale_moe_trace_dispatch) | |
| def trtllm_fp8_block_scale_moe(..., routing_method_type: int = 0, ...): | |
| ... | |
| Returns ``None`` for ``RoutingMethodType.Unspecified`` (6), which | |
| suppresses trace generation. | |
| """ | |
| routing_method_type = int(kwargs.get("routing_method_type", 0)) | |
| return _MOE_TRACE_BY_ROUTING_TYPE.get(routing_method_type) | |
| # Expose all possible templates so _attach_fi_trace can auto-register them | |
| # in _TRACE_REGISTRY for consistency testing. | |
| trtllm_fp8_block_scale_moe_trace_dispatch.templates = list( | |
| _MOE_TRACE_BY_ROUTING_TYPE.values() | |
| ) | |
| def trtllm_fp8_block_scale_moe_trace_dispatch(**kwargs): | |
| """Return the appropriate TraceTemplate for the given ``routing_method_type``. | |
| Pass this as ``trace=trtllm_fp8_block_scale_moe_trace_dispatch`` to | |
| ``@flashinfer_api`` so the correct template is selected at call time:: | |
| `@flashinfer_api`(trace=trtllm_fp8_block_scale_moe_trace_dispatch) | |
| def trtllm_fp8_block_scale_moe(..., routing_method_type: int = 0, ...): | |
| ... | |
| Returns ``None`` for ``RoutingMethodType.Unspecified`` (6), which | |
| suppresses trace generation. | |
| """ | |
| routing_method_type = int(kwargs.get("routing_method_type", 0)) | |
| return _MOE_TRACE_BY_ROUTING_TYPE.get(routing_method_type) | |
| # Expose all possible templates so _attach_fi_trace can auto-register them | |
| # in _TRACE_REGISTRY for consistency testing. | |
| setattr( | |
| trtllm_fp8_block_scale_moe_trace_dispatch, | |
| "templates", | |
| list(_MOE_TRACE_BY_ROUTING_TYPE.values()), | |
| ) |
🧰 Tools
🪛 GitHub Actions: pre-commit
[error] 596-596: mypy error: "Callable[[KwArg(Any)], Any]" has no attribute "templates" [attr-defined]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/trace/templates/moe.py` around lines 577 - 598, The direct
attribute assignment trtllm_fp8_block_scale_moe_trace_dispatch.templates causes
mypy attr-defined errors; replace that assignment with a setattr call to attach
the templates list at runtime (e.g., use
setattr(trtllm_fp8_block_scale_moe_trace_dispatch, "templates",
list(_MOE_TRACE_BY_ROUTING_TYPE.values()))), keeping the same value (list of
_MOE_TRACE_BY_ROUTING_TYPE.values()) and preserving behavior for
_attach_fi_trace registration.
| for json_key, descriptor in template.inputs.items(): | ||
| if isinstance(descriptor, Scalar): | ||
| if descriptor.optional: | ||
| continue | ||
| p = _resolved_param(json_key, descriptor) | ||
| kwargs[p] = 0 if descriptor.dtype == "int32" else 1.0 | ||
|
|
There was a problem hiding this comment.
The E2E generator should not synthesize int32 inputs as zero.
This currently creates impossible definitions such as block_size=0, so assert_fi_trace_complete() can pass on traces that would never be valid at runtime. Use a positive default, ideally with per-parameter overrides for constrained scalars.
Suggested fix
+_INT_SAMPLE_DEFAULTS = {
+ "block_size": 16,
+ "top_k": 1,
+ "n_group": 1,
+ "topk_group": 1,
+}
+
@@
if isinstance(descriptor, Scalar):
if descriptor.optional:
continue
p = _resolved_param(json_key, descriptor)
- kwargs[p] = 0 if descriptor.dtype == "int32" else 1.0
+ if descriptor.dtype == "int32":
+ kwargs[p] = _INT_SAMPLE_DEFAULTS.get(p, axis_size)
+ else:
+ kwargs[p] = 1.0🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/trace/test_fi_trace_template_consistency.py` around lines 193 - 199,
The test E2E generator currently assigns 0 for int32 scalars in the loop over
template.inputs which can create impossible values (e.g., block_size=0); update
the assignment in the loop that inspects isinstance(descriptor, Scalar) and uses
_resolved_param(json_key, descriptor) so that int32 defaults are positive (e.g.,
1 or another small positive) and preferably support per-parameter overrides for
constrained scalars before populating kwargs; ensure any change keeps optional
descriptors skipped and preserves the dtype branch for non-int32 floats, so
assert_fi_trace_complete() validates realistic traces.
There was a problem hiding this comment.
Actionable comments posted: 8
♻️ Duplicate comments (12)
flashinfer/trace/templates/gdn.py (3)
165-169:⚠️ Potential issue | 🟠 MajorReport
outputasbfloat16in the schema.All three references cast their outputs to
torch.bfloat16, so inheriting dtype fromqmakes the trace metadata wrong whenever callers use another input dtype.Also applies to: 351-355, 537-541
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/templates/gdn.py` around lines 165 - 169, The schema currently sets the attention "output" Tensor using dtype_from="q", which misreports dtype because outputs are cast to torch.bfloat16; update the Tensor definition for the "output" field in the GDN templates to use an explicit dtype of "bfloat16" (replace dtype_from="q" with dtype="bfloat16") for the occurrences around the shown block and the other two occurrences (near lines 351-355 and 537-541) so the trace metadata correctly reflects torch.bfloat16 outputs.
421-458:⚠️ Potential issue | 🟠 MajorPersist the updated pooled state before returning.
state_HVKis updated for every token, butfinal_stateis cloned frominitial_stateafter the loop and never receives those updates. The returnedfinal_stateis therefore stale, and the generated JSON fixture will be stale too.Suggested fix
- for b_idx in range(B): + final_state = initial_state.clone() + for b_idx in range(B): state_idx = int(initial_state_indices[b_idx].item()) state_HVK = ( initial_state[state_idx].clone().float().transpose(-1, -2) ) # [H,V,K] -> [H,K,V] @@ if cache_intermediate: intermediate_states_buffer[state_idx, t] = state_HVK.transpose( -1, -2 ) # [H,K,V] -> [H,V,K] - - final_state = initial_state.clone() + final_state[state_idx] = state_HVK.transpose(-1, -2) return output, final_state🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/templates/gdn.py` around lines 421 - 458, The loop updates state_HVK per batch/state but final_state is created from initial_state and never updated, so return value is stale; update final_state with the pooled (transposed) state_HVK for each corresponding state index (initial_state_indices) after finishing updates for that state (or after the outer loops) so final_state[state_idx] = state_HVK.transpose(-1, -2) (match the same [H,V,K] ↔ [H,K,V] orientation used for initial_state/state_HVK) before returning output and final_state.
362-365:⚠️ Potential issue | 🟠 Major
gdn_prefill_traceneeds the same head-ratio constraints as the other GDN templates.The reference divides by
num_v_heads // num_q_headsandnum_v_heads // num_k_heads, but this template currently accepts layouts that make those expansions invalid.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/templates/gdn.py` around lines 362 - 365, The gdn_prefill_trace template is missing head-ratio validity checks: add constraints ensuring num_v_heads is divisible by num_q_heads and by num_k_heads (e.g., num_v_heads % num_q_heads == 0 and num_v_heads % num_k_heads == 0) so the downstream divisions (num_v_heads // num_q_heads and num_v_heads // num_k_heads) used elsewhere are valid; update the constraints list in gdn_prefill_trace to include these checks referencing the variables num_v_heads, num_q_heads, and num_k_heads.flashinfer/trace/templates/gemm.py (2)
180-217:⚠️ Potential issue | 🟠 MajorThe FP4 trace schema still advertises unpacked shapes.
AandBare packeduint8tensors at runtime, so exposing them as[M, K]and[K, N]makesfi_traceinfer the wrong dimensions for real FP4 calls. Model the packed axes explicitly or add an extractor that maps packed sizes back to logicalK/N.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/templates/gemm.py` around lines 180 - 217, The mm_fp4_trace TraceTemplate currently lists inputs "A" and "B" with unpacked shapes ["M","K"] and ["K","N"], but at runtime these are packed uint8 FP4 buffers; update mm_fp4_trace so the Tensor entries for "A" and "B" describe the packed axes (e.g., K_packed/K_block or bytes per packed row) or add an extractor that converts the packed dimensions back to logical K and N (use the existing "block_size" Var/Scalar to compute K//block_size and N//block_size); specifically modify the Tensor definitions for "A" and "B" in mm_fp4_trace (and any related axis defs such as "K" or "N") so fi_trace will infer correct runtime shapes for FP4-packed inputs.
22-35:⚠️ Potential issue | 🟠 MajorMultiply by the physical
[K, N]weight matrix in these references.Each template models
Bas a physical[K, N]tensor, but the references all call... @ B.T. That breaksmm_bf16as soon asN != Kand skews the quantized references the same way.Suggested fix
def _mm_reference(A, B): - return torch.matmul(A, B.T) + return torch.matmul(A, B) @@ - return torch.matmul(A_fp32, B_fp32.T).to(torch.bfloat16) + return torch.matmul(A_fp32, B_fp32).to(torch.bfloat16) @@ - return torch.matmul(A_scaled, B_scaled.T).to(torch.bfloat16) + return torch.matmul(A_scaled, B_scaled).to(torch.bfloat16) @@ - return torch.matmul(A_scaled, B_scaled.T).to(torch.bfloat16) + return torch.matmul(A_scaled, B_scaled).to(torch.bfloat16)Also applies to: 38-55, 57-86
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/templates/gemm.py` around lines 22 - 35, The reference implementations currently multiply by B.T, but B represents the physical [K, N] weight matrix so using B.T swaps dims and breaks cases where N != K; update _mm_reference to compute torch.matmul(A, B) (not A @ B.T), and in _mm_fp8_reference reshape B into [K, N] (B_fp32 = B.reshape(K_div_bs * block_size, N)) and use torch.matmul(A_fp32, B_fp32) (remove the trailing .T), applying the same fix to the other reference helpers mentioned (the FP8 and bf16 variants in the file).flashinfer/trace/templates/attention.py (3)
140-169:⚠️ Potential issue | 🟠 MajorExpand selected pages before applying the prefill causal window.
The reference currently indexes
k_flat/v_flatwith page ids and setsnum_kv_tokens = page_ids.shape[0], so both the causal window and the gathered KV tensors are off bypage_sizefor real paged caches.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/templates/attention.py` around lines 140 - 169, The code is indexing k_flat/v_flat by page_ids and treating num_kv_tokens as page_ids.shape[0], which is incorrect for paged caches; you must expand page indices to per-token indices before building k_b/v_b and computing num_kv_tokens so the causal window and gathers operate at token granularity. Change the gather so that page_ids are multiplied/expanded by page_size into token_indices (e.g. token_indices = page_ids.unsqueeze(1)*page_size + torch.arange(page_size, device=...)) and then use those token_indices to index the original k_flat/v_flat (or reshape k_cache/v_cache into per-token and gather by token_indices) so k_b/v_b contain all tokens from the selected pages, set num_kv_tokens = token_indices.numel() (or actual token count if last page partial), and adjust uses of max_kv, delta, and slicing (k_b[:max_kv], v_b[:max_kv]) accordingly.
357-385:⚠️ Potential issue | 🟠 MajorThese MLA references still assume
page_size == 1.Both paths call
squeeze(1)on paged caches, but the schema accepts arbitrarypage_sizeand the tests already use larger values like 64. Flatten the page/token dimensions or constrain the template to single-token pages.Also applies to: 476-518
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/templates/attention.py` around lines 357 - 385, The code currently assumes page_size == 1 by calling ckv_cache.squeeze(1) and kpe_cache.squeeze(1) (variables Kc_all/Kp_all) which collapses the page dimension; instead merge the page and token dimensions so arbitrary page_size works: replace the squeeze(1) usage with a reshape/view that flattens the first two dims (e.g. ckv_cache.reshape(-1, head_dim_ckv) and kpe_cache.reshape(-1, head_dim_kpe)) and keep using kv_indices[kv_indptr[b]:kv_indptr[b+1]] (tok_idx) to index into the flattened Kc_all/Kp_all so Kc = Kc_all[tok_idx] and Kp = Kp_all[tok_idx] work for multi-token pages; apply the same change to the other block (lines ~476-518) where ckv_cache/kpe_cache are squeezed.
39-58:⚠️ Potential issue | 🟠 MajorTreat
kv_indicesas page ids in the decode reference.
kv_indicesare documented as page ids, but this code indexes the flattened token buffer with them. That only stays correct whenpage_size == 1; otherwise the reference gathers the wrong KV rows.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/templates/attention.py` around lines 39 - 58, kv_indices are page IDs but the code treats them as token indices when building token_ids; fix by expanding each page id into its page_size token-row indices before indexing k_flat/v_flat. In the loop over b replace the current token_ids = kv_indices[page_start:page_end].to(torch.long) with logic that maps each page id p to the contiguous token index range p*page_size .. (p+1)*page_size-1 (preserving dtype/device), then flatten that to a 1D tensor and use it to build k_b and v_b so k_b/v_b remain shaped [T, num_kv_heads, head_dim]; keep using k_flat/v_flat, kv_indptr, kv_indices, page_size, k_b, v_b, token_ids, and ensure device/torch.long handling remains correct.tests/trace/example.py (1)
54-373:⚠️ Potential issue | 🟠 MajorPytest won’t collect this trace example.
Everything after the env setup runs as import-time side effects, but the file defines no
test_...entrypoint. CI will never exercise trace generation unless this body is moved into a real pytest test and the script entrypoint is kept separate.As per coding guidelines,
tests/**/*.py: Prefix test functions withtest_and structure tests by feature intests/subdirectories matching kernel categories.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/trace/example.py` around lines 54 - 373, The file runs the entire trace-generation body at import time so pytest won't collect it; keep the environment setup (the FLASHINFER_TRACE_* os.environ lines and SAVE_DIR) and the imports as-is but move everything that performs work (starting from device/WORKSPACE and all calls that exercise flashinfer APIs, e.g., the loops that call flashinfer.rmsnorm, flashinfer.mm_bf16, BatchDecodeWithPagedKVCacheWrapper.plan/run, BatchPrefillWithPagedKVCacheWrapper.plan/run, BatchPrefillWithRaggedKVCacheWrapper.plan/run, BatchMLAPagedAttentionWrapper.plan/run, flashinfer.gdn_decode.*, flashinfer.fused_moe.* and the final JSON summary) into a pytest test function named with a test_ prefix (e.g., test_generate_fi_traces) so CI will execute it; also add a minimal if __name__ == "__main__" guard to call that function when run as a script so the example remains runnable standalone.flashinfer/trace/templates/moe.py (3)
648-652:⚠️ Potential issue | 🟡 MinorUse
setattr()to avoid mypyattr-definederror.The direct attribute assignment triggers mypy's
attr-definederror because function objects don't have a declaredtemplatesattribute. Usesetattr()to preserve runtime behavior while satisfying type checking.Suggested fix
# Expose all possible templates so _attach_fi_trace can auto-register them # in _TRACE_REGISTRY for consistency testing. -trtllm_fp8_block_scale_moe_trace_dispatch.templates = list( - _MOE_TRACE_BY_ROUTING_TYPE.values() -) +setattr( + trtllm_fp8_block_scale_moe_trace_dispatch, + "templates", + list(_MOE_TRACE_BY_ROUTING_TYPE.values()), +)🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/templates/moe.py` around lines 648 - 652, The assignment to add a dynamic attribute on the function trtllm_fp8_block_scale_moe_trace_dispatch causes mypy attr-defined errors; replace the direct assignment with a setattr call so the templates attribute is attached at runtime (e.g., setattr(trtllm_fp8_block_scale_moe_trace_dispatch, "templates", list(_MOE_TRACE_BY_ROUTING_TYPE.values()))) to preserve behavior while satisfying type checking.
25-27:⚠️ Potential issue | 🟠 MajorHardcoded
HandIconstants make reference execution shape-fragile.The module-level constants
H=7168andI=2048are used in_fp8_moe_run_expertsbut the actual hidden_size and intermediate_size can vary. This will produce incorrect results or errors for other valid MoE configurations.Suggested fix — derive H and I from tensor shapes
-H = 7168 -I = 2048 BLOCK = 128 `@torch.no_grad`() def _fp8_moe_run_experts( hidden_states, hidden_states_scale, gemm1_weights, gemm1_weights_scale, gemm2_weights, gemm2_weights_scale, weights, topk_idx, local_expert_offset, E_global, ): - T = hidden_states.shape[0] + T, H = hidden_states.shape + I = gemm2_weights.shape[2] E_local = gemm1_weights.shape[0]🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/templates/moe.py` around lines 25 - 27, The module currently hardcodes H=7168 and I=2048 which breaks _fp8_moe_run_experts for models with different hidden/intermediate sizes; change the code to derive hidden_size and intermediate_size at runtime from tensor shapes (e.g., infer hidden_size from the input/hidden tensor shape[-1] or the expert weight shapes, and infer intermediate_size from the feedforward weight/output shapes) and replace uses of H and I with those derived values (also ensure BLOCK is computed/validated against hidden_size if needed); update all references in _fp8_moe_run_experts to use the derived variables so the function works for arbitrary MoE shapes.
126-131:⚠️ Potential issue | 🟠 MajorReference implementations hardcode routing parameters that should be configurable.
TOP_K=8,N_GROUP=8,TOPK_GROUP=4are hardcoded, but the public API accepts these as arguments. If these references are used for numerical validation, they will only be correct for one configuration.If these references are only for schema validation (not numerical correctness), consider adding a comment to clarify their limited scope.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/templates/moe.py` around lines 126 - 131, The template hardcodes routing parameters TOP_K, N_GROUP, TOPK_GROUP which conflict with the public API; update the code that defines TOP_K, N_GROUP, TOPK_GROUP to read the corresponding function arguments (e.g., top_k, n_group, topk_group) or the routing parameters object instead of fixed literals so the template matches whatever configuration is passed via routing_logits' caller (or if these values are truly only for shape/schema checks, replace the literals with a clarifying comment near TOP_K/N_GROUP/TOPK_GROUP stating they are placeholder defaults used only for schema validation and not numeric correctness). Ensure you change the occurrences of TOP_K, N_GROUP, TOPK_GROUP in this module (referenced with routing_logits, E_global, T) accordingly.
🧹 Nitpick comments (8)
flashinfer/trace/template.py (3)
473-473: Consider using list unpacking for slightly cleaner syntax.Ruff suggests
[f"fi_api:{fi_api}", *template.tags]instead of list concatenation.Suggested fix
- all_tags = [f"fi_api:{fi_api}"] + template.tags + all_tags = [f"fi_api:{fi_api}", *template.tags]🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/template.py` at line 473, Replace the list concatenation that builds all_tags with list unpacking for clearer syntax: in the function/block where variable all_tags is assigned (currently using all_tags = [f"fi_api:{fi_api}"] + template.tags), change it to construct the list using [f"fi_api:{fi_api}", *template.tags] so it directly prepends the formatted fi_api tag to template.tags; keep the same variable name and semantics.
426-443: Auto-infer dtype uses first matching input — document this behavior.The auto-inference logic selects the dtype from the first input tensor with overlapping dimension names (line 443
break). This is a reasonable heuristic, but if multiple inputs have overlapping dims with different dtypes, the choice is arbitrary. Consider adding a brief inline comment noting this precedence.Suggested documentation
else: - # Auto-infer: find first input tensor with overlapping dims + # Auto-infer: use dtype from first input tensor with overlapping + # dims. If multiple inputs overlap, precedence follows dict order. dtype = "unknown"🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/template.py` around lines 426 - 443, The auto-infer branch in template.py sets dtype to the first matching input's type (looping over template.inputs, checking Tensor instances and overlapping descriptor.dim_names, using _get_tensor and _dtype_str, then break), which is arbitrary when multiple inputs overlap; add a concise inline comment near this logic (around the loop and the break) stating that this chooses the first matching input's dtype as the precedence rule and that other overlapping inputs may be ignored, so callers should avoid ambiguous multiple-dtype overlaps or explicitly provide dtype to override; keep the comment short and reference template.inputs, Tensor, descriptor, _get_tensor, and _dtype_str.
370-378: Silent exception swallowing may hide bugs during axis extraction.The bare
except Exception: passat lines 376-377 silently ignores all errors during axis value extraction. While this provides robustness, it can hide bugs in extractor logic or unexpected input types. Consider at minimum logging at debug level.Suggested fix
+import logging + +_logger = logging.getLogger(__name__) + # In fi_trace function: for axis_name, extractor in axis_extractors.items(): try: val = extractor(kwargs) if val is not None: axis_values[axis_name] = val - except Exception: - pass + except Exception as exc: + _logger.debug( + "Axis extraction failed for %s: %s", axis_name, exc + )🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/template.py` around lines 370 - 378, The code silently swallows all exceptions when running axis extractor functions (axis_extractors -> extractor(kwargs) populating axis_values), which hides bugs; change the bare "except Exception: pass" to catch the exception as e and log it at debug level (e.g., logger.debug("axis extractor %s failed for kwargs=%s: %s", axis_name, kwargs, e, exc_info=True)) so failures are recorded but extraction remains robust, and if there is no existing logger in this module create one via logging.getLogger(__name__) and import logging.tests/trace/test_fi_trace_template_consistency.py (4)
399-408: Variablekshadows the tensorkdefined earlier.The loop variable
kat line 400 shadows the tensorkdefined at line 391. While this doesn't affect correctness (the tensor is no longer needed at this point), it reduces readability.Suggested fix
non_optional_unknown = [ - k - for k, v in defn["inputs"].items() - if isinstance(v, dict) - and v.get("dtype") == "unknown" - and not v.get("optional", False) + key + for key, val in defn["inputs"].items() + if isinstance(val, dict) + and val.get("dtype") == "unknown" + and not val.get("optional", False) ]🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/trace/test_fi_trace_template_consistency.py` around lines 399 - 408, The loop variable k in the comprehension that builds non_optional_unknown shadows the tensor named k earlier; rename the loop variable (e.g., to input_name or inp_key) used in the comprehension and in the f-string so it no longer collides with the tensor k, updating the comprehension over defn["inputs"].items() and the f"Non-optional inputs with unknown dtype: {...}" reference accordingly.
495-496: Use a raw string for the regex pattern.The pattern contains backslashes and should be a raw string to avoid unintended escapes and satisfy Ruff RUF043.
Suggested fix
- with pytest.raises(AssertionError, match="param=.*hidden_state.*not found"): + with pytest.raises(AssertionError, match=r"param=.*hidden_state.*not found"):🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/trace/test_fi_trace_template_consistency.py` around lines 495 - 496, The regex passed to pytest.raises should be a raw string to avoid accidental escape sequences; update the call to pytest.raises(AssertionError, match=...) used around assert_template_signature_consistency(func, broken, label="meta-test") so the match argument is a raw string literal (prefix it with r, e.g. r"param=.*hidden_state.*not found") to satisfy Ruff RUF043 and ensure the pattern is interpreted correctly.
430-460: Rename ambiguous variableIin the MoE routing test.Ruff flags
Ias ambiguous (E741). Consider renaming tointermediateorinter_sizefor clarity.Suggested fix
- T, E, EL, H, I, BS = 4, 16, 2, 256, 64, 128 + T, E, EL, H, INTER, BS = 4, 16, 2, 256, 64, 128 defn = trtllm_fp8_block_scale_moe.fi_trace( routing_logits=torch.zeros(T, E, dtype=torch.float32), routing_bias=torch.zeros(E, dtype=torch.bfloat16), hidden_states=torch.zeros(T, H, dtype=torch.float8_e4m3fn), hidden_states_scale=torch.ones(H // BS, T, dtype=torch.float32), - gemm1_weights=torch.zeros(EL, 2 * I, H, dtype=torch.float8_e4m3fn), - gemm1_weights_scale=torch.ones(EL, (2 * I) // BS, H // BS, dtype=torch.float32), - gemm2_weights=torch.zeros(EL, H, I, dtype=torch.float8_e4m3fn), - gemm2_weights_scale=torch.ones(EL, H // BS, I // BS, dtype=torch.float32), + gemm1_weights=torch.zeros(EL, 2 * INTER, H, dtype=torch.float8_e4m3fn), + gemm1_weights_scale=torch.ones(EL, (2 * INTER) // BS, H // BS, dtype=torch.float32), + gemm2_weights=torch.zeros(EL, H, INTER, dtype=torch.float8_e4m3fn), + gemm2_weights_scale=torch.ones(EL, H // BS, INTER // BS, dtype=torch.float32), num_experts=E, top_k=top_k, - intermediate_size=I, + intermediate_size=INTER,🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/trace/test_fi_trace_template_consistency.py` around lines 430 - 460, The test function test_fi_trace_complete_moe_routing uses a single-letter variable I (intermediate size) which triggers an ambiguity lint (E741); rename I to a descriptive identifier (e.g., inter_size or intermediate) and update all references inside the function and the fi_trace(...) call (intermediate_size=I, shapes using I, 2 * I etc.) so the values and assertions remain identical but the variable name is clear and matches usage in trtllm_fp8_block_scale_moe.fi_trace.
369-370: Rename ambiguous loop variablelto improve readability.Ruff flags
las ambiguous (E741) because it can be confused with1. Consider renaming tolblorlabel.Suggested fix
-_E2E_PAIRS = [(f, t, l) for f, t, l in _ALL_PAIRS if l not in _E2E_SKIP] +_E2E_PAIRS = [(f, t, lbl) for f, t, lbl in _ALL_PAIRS if lbl not in _E2E_SKIP]🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/trace/test_fi_trace_template_consistency.py` around lines 369 - 370, The list comprehension _E2E_PAIRS uses an ambiguous loop variable named "l"; rename it to a clearer identifier (e.g., "label" or "lbl") in the comprehension and update the subsequent _E2E_IDS comprehension to unpack/use that new name so both _E2E_PAIRS = [(f, t, label) for f, t, label in _ALL_PAIRS if label not in _E2E_SKIP] and _E2E_IDS = [label for _, _, label in _E2E_PAIRS] remain consistent.flashinfer/trace/templates/moe.py (1)
85-88: Rename ambiguous variableOto improve readability.Ruff flags
Oas ambiguous (E741) because it can be confused with0. Consider renaming tooutoroutput_e.Suggested fix
- O = (silu_X2 * X1).matmul(W2[le].t()) + expert_out = (silu_X2 * X1).matmul(W2[le].t()) # per-expert contribution weight for each token w_tok = weights.index_select(0, token_idx) # find which slot in topk_idx[token_idx] corresponds to ge match = (topk_idx.index_select(0, token_idx) == ge).float() w_e = (w_tok * match).sum(dim=1) - output.index_add_(0, token_idx, O * w_e.unsqueeze(1)) + output.index_add_(0, token_idx, expert_out * w_e.unsqueeze(1))🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/templates/moe.py` around lines 85 - 88, Rename the ambiguous variable O used in the moe attention feedforward block to a clearer name (e.g., out or output_e) to avoid confusion with the digit zero; update the assignment and any subsequent uses where O appears (the expression "(silu_X2 * X1).matmul(W2[le].t())") and ensure references to G1, X1, X2, silu_X2, W13, W2, A_e, and le remain correct with the new variable name.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/api_logging.py`:
- Around line 1510-1531: The wrapper created by _attach_fi_trace (and returned
by flashinfer_api) adds runtime cost even when tracing/LOGLEVEL=0; change
_attach_fi_trace so that if tracing is disabled (i.e., _is_trace_dump_enabled()
is False and caller requested zero-overhead) it does not create
&_auto_dump_wrapper but instead attaches fi_trace to the original callable via
setattr(original, "fi_trace", fi_trace_fn) and returns original; otherwise keep
the current wrapper behavior. Also avoid direct attribute assignment on
Callable-typed objects that triggers mypy attr-defined errors by using
setattr(original, "fi_trace", fi_trace_fn) or by casting to Any/creating a small
Protocol for fi_trace to satisfy type-checkers (e.g., cast(original, Any) or
define Protocol with fi_trace) so the pipeline no longer errors.
- Around line 1508-1531: Replace direct attribute assignments to .fi_trace with
setattr to avoid mypy attr-defined errors: where the diff sets wrapped.fi_trace
= fi_trace_fn and _auto_dump_wrapper.fi_trace = fi_trace_fn, change those direct
assignments to use setattr(wrapped, "fi_trace", fi_trace_fn) and
setattr(_auto_dump_wrapper, "fi_trace", fi_trace_fn). Keep the same semantics
(assign the fi_trace_fn callable) and leave other code in _auto_dump_wrapper,
_sig, and fi_trace_fn unchanged.
In `@flashinfer/fi_trace.py`:
- Around line 238-285: The function fi_trace currently types func_or_method as
Callable but relies on the object (actual_func) exposing a .fi_trace attribute;
update the typing to make that contract explicit by introducing a Protocol
(e.g., TracedCallable with a fi_trace(self, save_dir: Optional[Union[str, Path]]
= None, **kwargs) -> Dict[str, Any]) and use that Protocol as the type for
func_or_method (or cast actual_func to TracedCallable before accessing
.fi_trace); ensure the Protocol signature matches how trace_fn is called in
fi_trace and import typing.Protocol and any necessary types so mypy recognizes
the requirement.
- Around line 103-110: The import line bringing in Const, Scalar, Tensor,
TraceTemplate, and Var from .trace.template is unused in build_fi_trace_fn and
causing Ruff F401 warnings; remove those five names (or the whole legacy import
if nothing else from that module is used) so only needed symbols remain imported
in flashinfer/fi_trace.py and eliminate the unused imports Const, Scalar,
Tensor, TraceTemplate, Var from the import statement that currently appears
alongside build_fi_trace_fn.
In `@flashinfer/trace/templates/gdn.py`:
- Around line 502-546: The template schema is missing the disable_state_update
input required by gated_delta_rule_mtp, so add a boolean Tensor/Scalar entry
named "disable_state_update" to the inputs dict (matching how other flags are
represented) and mark it optional or required consistent with
gated_delta_rule_mtp's signature; ensure you reference the symbol name
"disable_state_update" and update the inputs block near the existing
"initial_state"/"final_state" entries so the trace can distinguish
state-updating vs non-updating behavior described in "final_state".
In `@tests/trace/test_fi_trace_template_consistency.py`:
- Around line 384-394: Remove the unused import gqa_paged_decode_trace from the
test; locate the import statement that reads "from
flashinfer.trace.templates.attention import gqa_paged_decode_trace" and delete
it so the test only imports and uses
BatchDecodeWithPagedKVCacheWrapper.run.fi_trace (ensure no other references to
gqa_paged_decode_trace remain in the file).
- Around line 309-321: The import of flashinfer.sampling is unused and flagged
by pre-commit; either remove the import statement for flashinfer.sampling or
ensure it registers decorators used by _TRACE_REGISTRY (so the import has side
effects). Locate the import block containing flashinfer.sampling (near imports
for flashinfer.decode, flashinfer.gdn_decode, etc.) and delete the
flashinfer.sampling line if no decorated functions from that module are expected
to be registered, otherwise import the specific symbols that cause registration
or add a comment explaining the necessary side-effect to avoid removal by
linters.
In `@tests/trace/test_fi_trace.py`:
- Line 20: Remove the unused top-level import of pytest in
tests/trace/test_fi_trace.py: delete the line "import pytest" since the file
relies on pytest fixtures (tmp_path, monkeypatch) provided by pytest's runtime
and does not reference the pytest symbol directly; ensure no other code in the
module uses the pytest name before committing.
---
Duplicate comments:
In `@flashinfer/trace/templates/attention.py`:
- Around line 140-169: The code is indexing k_flat/v_flat by page_ids and
treating num_kv_tokens as page_ids.shape[0], which is incorrect for paged
caches; you must expand page indices to per-token indices before building
k_b/v_b and computing num_kv_tokens so the causal window and gathers operate at
token granularity. Change the gather so that page_ids are multiplied/expanded by
page_size into token_indices (e.g. token_indices =
page_ids.unsqueeze(1)*page_size + torch.arange(page_size, device=...)) and then
use those token_indices to index the original k_flat/v_flat (or reshape
k_cache/v_cache into per-token and gather by token_indices) so k_b/v_b contain
all tokens from the selected pages, set num_kv_tokens = token_indices.numel()
(or actual token count if last page partial), and adjust uses of max_kv, delta,
and slicing (k_b[:max_kv], v_b[:max_kv]) accordingly.
- Around line 357-385: The code currently assumes page_size == 1 by calling
ckv_cache.squeeze(1) and kpe_cache.squeeze(1) (variables Kc_all/Kp_all) which
collapses the page dimension; instead merge the page and token dimensions so
arbitrary page_size works: replace the squeeze(1) usage with a reshape/view that
flattens the first two dims (e.g. ckv_cache.reshape(-1, head_dim_ckv) and
kpe_cache.reshape(-1, head_dim_kpe)) and keep using
kv_indices[kv_indptr[b]:kv_indptr[b+1]] (tok_idx) to index into the flattened
Kc_all/Kp_all so Kc = Kc_all[tok_idx] and Kp = Kp_all[tok_idx] work for
multi-token pages; apply the same change to the other block (lines ~476-518)
where ckv_cache/kpe_cache are squeezed.
- Around line 39-58: kv_indices are page IDs but the code treats them as token
indices when building token_ids; fix by expanding each page id into its
page_size token-row indices before indexing k_flat/v_flat. In the loop over b
replace the current token_ids = kv_indices[page_start:page_end].to(torch.long)
with logic that maps each page id p to the contiguous token index range
p*page_size .. (p+1)*page_size-1 (preserving dtype/device), then flatten that to
a 1D tensor and use it to build k_b and v_b so k_b/v_b remain shaped [T,
num_kv_heads, head_dim]; keep using k_flat/v_flat, kv_indptr, kv_indices,
page_size, k_b, v_b, token_ids, and ensure device/torch.long handling remains
correct.
In `@flashinfer/trace/templates/gdn.py`:
- Around line 165-169: The schema currently sets the attention "output" Tensor
using dtype_from="q", which misreports dtype because outputs are cast to
torch.bfloat16; update the Tensor definition for the "output" field in the GDN
templates to use an explicit dtype of "bfloat16" (replace dtype_from="q" with
dtype="bfloat16") for the occurrences around the shown block and the other two
occurrences (near lines 351-355 and 537-541) so the trace metadata correctly
reflects torch.bfloat16 outputs.
- Around line 421-458: The loop updates state_HVK per batch/state but
final_state is created from initial_state and never updated, so return value is
stale; update final_state with the pooled (transposed) state_HVK for each
corresponding state index (initial_state_indices) after finishing updates for
that state (or after the outer loops) so final_state[state_idx] =
state_HVK.transpose(-1, -2) (match the same [H,V,K] ↔ [H,K,V] orientation used
for initial_state/state_HVK) before returning output and final_state.
- Around line 362-365: The gdn_prefill_trace template is missing head-ratio
validity checks: add constraints ensuring num_v_heads is divisible by
num_q_heads and by num_k_heads (e.g., num_v_heads % num_q_heads == 0 and
num_v_heads % num_k_heads == 0) so the downstream divisions (num_v_heads //
num_q_heads and num_v_heads // num_k_heads) used elsewhere are valid; update the
constraints list in gdn_prefill_trace to include these checks referencing the
variables num_v_heads, num_q_heads, and num_k_heads.
In `@flashinfer/trace/templates/gemm.py`:
- Around line 180-217: The mm_fp4_trace TraceTemplate currently lists inputs "A"
and "B" with unpacked shapes ["M","K"] and ["K","N"], but at runtime these are
packed uint8 FP4 buffers; update mm_fp4_trace so the Tensor entries for "A" and
"B" describe the packed axes (e.g., K_packed/K_block or bytes per packed row) or
add an extractor that converts the packed dimensions back to logical K and N
(use the existing "block_size" Var/Scalar to compute K//block_size and
N//block_size); specifically modify the Tensor definitions for "A" and "B" in
mm_fp4_trace (and any related axis defs such as "K" or "N") so fi_trace will
infer correct runtime shapes for FP4-packed inputs.
- Around line 22-35: The reference implementations currently multiply by B.T,
but B represents the physical [K, N] weight matrix so using B.T swaps dims and
breaks cases where N != K; update _mm_reference to compute torch.matmul(A, B)
(not A @ B.T), and in _mm_fp8_reference reshape B into [K, N] (B_fp32 =
B.reshape(K_div_bs * block_size, N)) and use torch.matmul(A_fp32, B_fp32)
(remove the trailing .T), applying the same fix to the other reference helpers
mentioned (the FP8 and bf16 variants in the file).
In `@flashinfer/trace/templates/moe.py`:
- Around line 648-652: The assignment to add a dynamic attribute on the function
trtllm_fp8_block_scale_moe_trace_dispatch causes mypy attr-defined errors;
replace the direct assignment with a setattr call so the templates attribute is
attached at runtime (e.g., setattr(trtllm_fp8_block_scale_moe_trace_dispatch,
"templates", list(_MOE_TRACE_BY_ROUTING_TYPE.values()))) to preserve behavior
while satisfying type checking.
- Around line 25-27: The module currently hardcodes H=7168 and I=2048 which
breaks _fp8_moe_run_experts for models with different hidden/intermediate sizes;
change the code to derive hidden_size and intermediate_size at runtime from
tensor shapes (e.g., infer hidden_size from the input/hidden tensor shape[-1] or
the expert weight shapes, and infer intermediate_size from the feedforward
weight/output shapes) and replace uses of H and I with those derived values
(also ensure BLOCK is computed/validated against hidden_size if needed); update
all references in _fp8_moe_run_experts to use the derived variables so the
function works for arbitrary MoE shapes.
- Around line 126-131: The template hardcodes routing parameters TOP_K, N_GROUP,
TOPK_GROUP which conflict with the public API; update the code that defines
TOP_K, N_GROUP, TOPK_GROUP to read the corresponding function arguments (e.g.,
top_k, n_group, topk_group) or the routing parameters object instead of fixed
literals so the template matches whatever configuration is passed via
routing_logits' caller (or if these values are truly only for shape/schema
checks, replace the literals with a clarifying comment near
TOP_K/N_GROUP/TOPK_GROUP stating they are placeholder defaults used only for
schema validation and not numeric correctness). Ensure you change the
occurrences of TOP_K, N_GROUP, TOPK_GROUP in this module (referenced with
routing_logits, E_global, T) accordingly.
In `@tests/trace/example.py`:
- Around line 54-373: The file runs the entire trace-generation body at import
time so pytest won't collect it; keep the environment setup (the
FLASHINFER_TRACE_* os.environ lines and SAVE_DIR) and the imports as-is but move
everything that performs work (starting from device/WORKSPACE and all calls that
exercise flashinfer APIs, e.g., the loops that call flashinfer.rmsnorm,
flashinfer.mm_bf16, BatchDecodeWithPagedKVCacheWrapper.plan/run,
BatchPrefillWithPagedKVCacheWrapper.plan/run,
BatchPrefillWithRaggedKVCacheWrapper.plan/run,
BatchMLAPagedAttentionWrapper.plan/run, flashinfer.gdn_decode.*,
flashinfer.fused_moe.* and the final JSON summary) into a pytest test function
named with a test_ prefix (e.g., test_generate_fi_traces) so CI will execute it;
also add a minimal if __name__ == "__main__" guard to call that function when
run as a script so the example remains runnable standalone.
---
Nitpick comments:
In `@flashinfer/trace/template.py`:
- Line 473: Replace the list concatenation that builds all_tags with list
unpacking for clearer syntax: in the function/block where variable all_tags is
assigned (currently using all_tags = [f"fi_api:{fi_api}"] + template.tags),
change it to construct the list using [f"fi_api:{fi_api}", *template.tags] so it
directly prepends the formatted fi_api tag to template.tags; keep the same
variable name and semantics.
- Around line 426-443: The auto-infer branch in template.py sets dtype to the
first matching input's type (looping over template.inputs, checking Tensor
instances and overlapping descriptor.dim_names, using _get_tensor and
_dtype_str, then break), which is arbitrary when multiple inputs overlap; add a
concise inline comment near this logic (around the loop and the break) stating
that this chooses the first matching input's dtype as the precedence rule and
that other overlapping inputs may be ignored, so callers should avoid ambiguous
multiple-dtype overlaps or explicitly provide dtype to override; keep the
comment short and reference template.inputs, Tensor, descriptor, _get_tensor,
and _dtype_str.
- Around line 370-378: The code silently swallows all exceptions when running
axis extractor functions (axis_extractors -> extractor(kwargs) populating
axis_values), which hides bugs; change the bare "except Exception: pass" to
catch the exception as e and log it at debug level (e.g., logger.debug("axis
extractor %s failed for kwargs=%s: %s", axis_name, kwargs, e, exc_info=True)) so
failures are recorded but extraction remains robust, and if there is no existing
logger in this module create one via logging.getLogger(__name__) and import
logging.
In `@flashinfer/trace/templates/moe.py`:
- Around line 85-88: Rename the ambiguous variable O used in the moe attention
feedforward block to a clearer name (e.g., out or output_e) to avoid confusion
with the digit zero; update the assignment and any subsequent uses where O
appears (the expression "(silu_X2 * X1).matmul(W2[le].t())") and ensure
references to G1, X1, X2, silu_X2, W13, W2, A_e, and le remain correct with the
new variable name.
In `@tests/trace/test_fi_trace_template_consistency.py`:
- Around line 399-408: The loop variable k in the comprehension that builds
non_optional_unknown shadows the tensor named k earlier; rename the loop
variable (e.g., to input_name or inp_key) used in the comprehension and in the
f-string so it no longer collides with the tensor k, updating the comprehension
over defn["inputs"].items() and the f"Non-optional inputs with unknown dtype:
{...}" reference accordingly.
- Around line 495-496: The regex passed to pytest.raises should be a raw string
to avoid accidental escape sequences; update the call to
pytest.raises(AssertionError, match=...) used around
assert_template_signature_consistency(func, broken, label="meta-test") so the
match argument is a raw string literal (prefix it with r, e.g.
r"param=.*hidden_state.*not found") to satisfy Ruff RUF043 and ensure the
pattern is interpreted correctly.
- Around line 430-460: The test function test_fi_trace_complete_moe_routing uses
a single-letter variable I (intermediate size) which triggers an ambiguity lint
(E741); rename I to a descriptive identifier (e.g., inter_size or intermediate)
and update all references inside the function and the fi_trace(...) call
(intermediate_size=I, shapes using I, 2 * I etc.) so the values and assertions
remain identical but the variable name is clear and matches usage in
trtllm_fp8_block_scale_moe.fi_trace.
- Around line 369-370: The list comprehension _E2E_PAIRS uses an ambiguous loop
variable named "l"; rename it to a clearer identifier (e.g., "label" or "lbl")
in the comprehension and update the subsequent _E2E_IDS comprehension to
unpack/use that new name so both _E2E_PAIRS = [(f, t, label) for f, t, label in
_ALL_PAIRS if label not in _E2E_SKIP] and _E2E_IDS = [label for _, _, label in
_E2E_PAIRS] remain consistent.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: a970c3e2-8b74-4f9f-b236-d08657910713
📥 Commits
Reviewing files that changed from the base of the PR and between f7e2129265f21b39f8b8f460ab9cb59648c88322 and 2f4aceb637fabe2f75c180b8086e97f602b147d0.
📒 Files selected for processing (32)
flashinfer/api_logging.pyflashinfer/fi_trace.pyflashinfer/trace/template.pyflashinfer/trace/templates/attention.pyflashinfer/trace/templates/gdn.pyflashinfer/trace/templates/gemm.pyflashinfer/trace/templates/moe.pytests/trace/example.pytests/trace/fi_trace_out/fused_add_rmsnorm_h5120.jsontests/trace/fi_trace_out/gdn_decode_qk4_v8_d128.jsontests/trace/fi_trace_out/gdn_mtp_qk4_v8_d128.jsontests/trace/fi_trace_out/gemm_bf16_N256_K7168.jsontests/trace/fi_trace_out/gemm_bf16_N4096_K4096.jsontests/trace/fi_trace_out/gemm_fp4_N2048_K7168_block_size16.jsontests/trace/fi_trace_out/gemm_fp8_N1536_K7168.jsontests/trace/fi_trace_out/gemm_mxfp8_N4096_K4096.jsontests/trace/fi_trace_out/gqa_paged_decode_h32_kv8_d128_ps16.jsontests/trace/fi_trace_out/gqa_paged_decode_h32_kv8_d128_ps64.jsontests/trace/fi_trace_out/gqa_paged_prefill_h32_kv8_d128_ps16.jsontests/trace/fi_trace_out/gqa_ragged_h32_kv8_d128.jsontests/trace/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps1.jsontests/trace/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps64.jsontests/trace/fi_trace_out/moe_fp8_block_scale_default_routing_topk8_e32_h7168_i2048.jsontests/trace/fi_trace_out/rmsnorm_h4096.jsontests/trace/fi_trace_out/rmsnorm_h7168.jsontests/trace/fi_trace_out/top_k_sampling_v128256.jsontests/trace/fi_trace_out/top_k_top_p_sampling_v128256.jsontests/trace/fi_trace_out/top_k_top_p_sampling_v151936.jsontests/trace/fi_trace_out/top_p_sampling_v128256.jsontests/trace/fi_trace_out/top_p_sampling_v151936.jsontests/trace/test_fi_trace.pytests/trace/test_fi_trace_template_consistency.py
✅ Files skipped from review due to trivial changes (15)
- tests/trace/fi_trace_out/gemm_bf16_N256_K7168.json
- tests/trace/fi_trace_out/rmsnorm_h4096.json
- tests/trace/fi_trace_out/rmsnorm_h7168.json
- tests/trace/fi_trace_out/gemm_bf16_N4096_K4096.json
- tests/trace/fi_trace_out/gemm_fp8_N1536_K7168.json
- tests/trace/fi_trace_out/fused_add_rmsnorm_h5120.json
- tests/trace/fi_trace_out/top_k_top_p_sampling_v128256.json
- tests/trace/fi_trace_out/gemm_mxfp8_N4096_K4096.json
- tests/trace/fi_trace_out/top_p_sampling_v151936.json
- tests/trace/fi_trace_out/top_k_sampling_v128256.json
- tests/trace/fi_trace_out/gqa_paged_decode_h32_kv8_d128_ps64.json
- tests/trace/fi_trace_out/top_k_top_p_sampling_v151936.json
- tests/trace/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps64.json
- tests/trace/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps1.json
- tests/trace/fi_trace_out/gqa_paged_prefill_h32_kv8_d128_ps16.json
🚧 Files skipped from review as they are similar to previous changes (6)
- tests/trace/fi_trace_out/top_p_sampling_v128256.json
- tests/trace/fi_trace_out/gemm_fp4_N2048_K7168_block_size16.json
- tests/trace/fi_trace_out/gqa_ragged_h32_kv8_d128.json
- tests/trace/fi_trace_out/gdn_decode_qk4_v8_d128.json
- tests/trace/fi_trace_out/gqa_paged_decode_h32_kv8_d128_ps16.json
- tests/trace/fi_trace_out/moe_fp8_block_scale_default_routing_topk8_e32_h7168_i2048.json
| wrapped.fi_trace = fi_trace_fn | ||
|
|
||
| # Auto-dump wrapper: checked lazily at call time so that callers | ||
| # can set FLASHINFER_TRACE_DUMP after importing flashinfer (e.g. | ||
| # when running via ``python -m``). | ||
| _inner = wrapped | ||
| _sig = inspect.signature(original) | ||
|
|
||
| @functools.wraps(_inner) | ||
| def _auto_dump_wrapper(*args, **kwargs): | ||
| # Generate trace BEFORE the actual call (crash-safe: schema | ||
| # depends only on input shapes/dtypes, not on whether the | ||
| # computation succeeds). | ||
| if _is_trace_dump_enabled(): | ||
| try: | ||
| bound = _sig.bind(*args, **kwargs) | ||
| bound.apply_defaults() | ||
| fi_trace_fn(**dict(bound.arguments)) | ||
| except Exception: | ||
| pass | ||
| return _inner(*args, **kwargs) | ||
|
|
||
| _auto_dump_wrapper.fi_trace = fi_trace_fn | ||
| return _auto_dump_wrapper |
There was a problem hiding this comment.
🧩 Analysis chain
🏁 Script executed:
find . -maxdepth 5 -name "*.py" -path "*flashinfer*" | grep -i api | head -10Repository: flashinfer-ai/flashinfer
Length of output: 95
🏁 Script executed:
ls -la flashinfer/ 2>/dev/null | head -20 || echo "Directory not found"Repository: flashinfer-ai/flashinfer
Length of output: 1275
🏁 Script executed:
git ls-files "*api_logging*"Repository: flashinfer-ai/flashinfer
Length of output: 93
🏁 Script executed:
sed -n '1508,1540p' flashinfer/api_logging.pyRepository: flashinfer-ai/flashinfer
Length of output: 1518
🏁 Script executed:
grep -n "\.fi_trace\s*=" flashinfer/api_logging.py | head -20Repository: flashinfer-ai/flashinfer
Length of output: 238
🏁 Script executed:
sed -n '1490,1545p' flashinfer/api_logging.pyRepository: flashinfer-ai/flashinfer
Length of output: 2559
🏁 Script executed:
head -100 flashinfer/api_logging.py | grep -E "^(import|from|class|def|__all__)"Repository: flashinfer-ai/flashinfer
Length of output: 406
🏁 Script executed:
grep -B 30 "wrapped.fi_trace = fi_trace_fn" flashinfer/api_logging.py | head -50Repository: flashinfer-ai/flashinfer
Length of output: 1594
🏁 Script executed:
sed -n '1450,1510p' flashinfer/api_logging.pyRepository: flashinfer-ai/flashinfer
Length of output: 3083
🏁 Script executed:
find . -name "pyproject.toml" -o -name "setup.cfg" -o -name "mypy.ini" | head -5Repository: flashinfer-ai/flashinfer
Length of output: 156
🏁 Script executed:
grep -n "def.*wrapped" flashinfer/api_logging.py | head -20Repository: flashinfer-ai/flashinfer
Length of output: 50
🏁 Script executed:
cat pyproject.toml | grep -A 20 "\[tool.mypy\]"Repository: flashinfer-ai/flashinfer
Length of output: 429
🏁 Script executed:
cat pyproject.toml | grep -A 20 "lint\|mypy"Repository: flashinfer-ai/flashinfer
Length of output: 712
🏁 Script executed:
sed -n '1450,1560p' flashinfer/api_logging.py | cat -nRepository: flashinfer-ai/flashinfer
Length of output: 6002
🏁 Script executed:
sed -n '1450,1545p' flashinfer/api_logging.py | head -100Repository: flashinfer-ai/flashinfer
Length of output: 4750
🏁 Script executed:
grep -n "def flashinfer_api\|def _attach_fi_trace" flashinfer/api_logging.pyRepository: flashinfer-ai/flashinfer
Length of output: 169
🏁 Script executed:
python3 -m mypy flashinfer/api_logging.py --no-error-summary 2>&1 | grep -i "fi_trace\|attr-defined" | head -20Repository: flashinfer-ai/flashinfer
Length of output: 50
🏁 Script executed:
sed -n '1435,1460p' flashinfer/api_logging.pyRepository: flashinfer-ai/flashinfer
Length of output: 1278
🏁 Script executed:
sed -n '1557,1620p' flashinfer/api_logging.pyRepository: flashinfer-ai/flashinfer
Length of output: 2910
🏁 Script executed:
python3 << 'EOF'
# Check what mypy would say about direct assignment to Callable
import sys
import tempfile
import os
code = '''
from typing import Callable
def my_decorator(original: Callable) -> Callable:
wrapped: Callable = original
wrapped.fi_trace = lambda: None # This is the issue
return wrapped
'''
with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
f.write(code)
temp_file = f.name
try:
os.system(f'python3 -m mypy {temp_file} 2>&1 | head -20')
finally:
os.unlink(temp_file)
EOFRepository: flashinfer-ai/flashinfer
Length of output: 106
🏁 Script executed:
grep -n "fi_trace" flashinfer/api_logging.py | head -30Repository: flashinfer-ai/flashinfer
Length of output: 1728
🏁 Script executed:
grep -n "class.*Protocol\|@.*Protocol" flashinfer/api_logging.pyRepository: flashinfer-ai/flashinfer
Length of output: 50
🏁 Script executed:
grep -B 5 "def _attach_fi_trace" flashinfer/api_logging.py | head -10Repository: flashinfer-ai/flashinfer
Length of output: 287
🏁 Script executed:
sed -n '1400,1445p' flashinfer/api_logging.pyRepository: flashinfer-ai/flashinfer
Length of output: 1579
🏁 Script executed:
python3 << 'EOF'
# Check if setattr vs direct assignment matters for mypy
code1 = '''
from typing import Callable
def decorator(func: Callable) -> Callable:
func.fi_trace = lambda: None # Direct assignment - mypy error
return func
'''
code2 = '''
from typing import Callable
def decorator(func: Callable) -> Callable:
setattr(func, "fi_trace", lambda: None) # setattr - mypy OK
return func
'''
print("Direct assignment would trigger attr-defined error on Callable types")
print("setattr() is the mypy-friendly approach for dynamic attributes")
EOFRepository: flashinfer-ai/flashinfer
Length of output: 199
🏁 Script executed:
grep -C 3 "wrapped: Callable\|original: Callable" flashinfer/api_logging.py | head -30Repository: flashinfer-ai/flashinfer
Length of output: 230
Use setattr() for dynamic attribute assignment to avoid mypy attr-defined errors.
Direct assignment of .fi_trace to Callable-typed variables triggers mypy's attr-defined check. Replace direct assignment with setattr() to explicitly model the dynamic attribute:
Required changes
- wrapped.fi_trace = fi_trace_fn
+ setattr(wrapped, "fi_trace", fi_trace_fn)
@@
- _auto_dump_wrapper.fi_trace = fi_trace_fn
+ setattr(_auto_dump_wrapper, "fi_trace", fi_trace_fn)
@@
- wrapped.fi_trace = build_fi_trace_fn(spec)
+ setattr(wrapped, "fi_trace", build_fi_trace_fn(spec))🧰 Tools
🪛 GitHub Actions: pre-commit
[error] 1508-1539: mypy error(s): "Callable[..., Any]" has no attribute "fi_trace" / "_Wrapped[...]" has no attribute "fi_trace" [attr-defined]
🪛 Ruff (0.15.9)
[error] 1526-1527: try-except-pass detected, consider logging the exception
(S110)
[warning] 1526-1526: Do not catch blind exception: Exception
(BLE001)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/api_logging.py` around lines 1508 - 1531, Replace direct attribute
assignments to .fi_trace with setattr to avoid mypy attr-defined errors: where
the diff sets wrapped.fi_trace = fi_trace_fn and _auto_dump_wrapper.fi_trace =
fi_trace_fn, change those direct assignments to use setattr(wrapped, "fi_trace",
fi_trace_fn) and setattr(_auto_dump_wrapper, "fi_trace", fi_trace_fn). Keep the
same semantics (assign the fi_trace_fn callable) and leave other code in
_auto_dump_wrapper, _sig, and fi_trace_fn unchanged.
| # Import the old implementation from the trace package for backwards compat. | ||
| from .trace.template import ( # noqa: PLC0415 | ||
| Const, | ||
| Scalar, | ||
| Tensor, | ||
| TraceTemplate, | ||
| Var, | ||
| ) |
There was a problem hiding this comment.
🧩 Analysis chain
🏁 Script executed:
# First, check if the file exists and examine the relevant lines
if [ -f "flashinfer/fi_trace.py" ]; then
echo "=== File exists, examining lines 100-115 ==="
sed -n '100,115p' flashinfer/fi_trace.py | cat -n
else
echo "File not found"
fiRepository: flashinfer-ai/flashinfer
Length of output: 675
🏁 Script executed:
# Search for usages of the imported names in the file
echo "=== Searching for usages of Const, Scalar, Tensor, TraceTemplate, Var ==="
for name in Const Scalar Tensor TraceTemplate Var; do
echo "--- Searching for '$name' ---"
rg "\b${name}\b" flashinfer/fi_trace.py -n | head -20
doneRepository: flashinfer-ai/flashinfer
Length of output: 732
🏁 Script executed:
# Check the full file to understand the import context
echo "=== Counting lines and examining structure ==="
wc -l flashinfer/fi_trace.py
echo ""
echo "=== Checking for function definitions ==="
rg "^def " flashinfer/fi_trace.py -nRepository: flashinfer-ai/flashinfer
Length of output: 332
🏁 Script executed:
# Examine the build_fi_trace_fn function in detail to see if these imports are used
echo "=== Examining build_fi_trace_fn function (lines 97-237) ==="
sed -n '97,237p' flashinfer/fi_trace.py | head -80Repository: flashinfer-ai/flashinfer
Length of output: 2919
🏁 Script executed:
# Check if there are any references to these imports after line 110
echo "=== Checking if imports are used after definition ==="
sed -n '110,237p' flashinfer/fi_trace.py | grep -E "Const|Scalar|Tensor[^_]|TraceTemplate|Var" | grep -v "torch.Tensor" | grep -v "isinstance"Repository: flashinfer-ai/flashinfer
Length of output: 121
🏁 Script executed:
# Verify the file can be linted - check if Ruff would report F401
echo "=== Checking for Ruff/linting setup ==="
if command -v ruff &> /dev/null; then
echo "Ruff is available, checking for F401 violations..."
ruff check flashinfer/fi_trace.py --select F401 2>&1 | head -20
else
echo "Ruff not available in sandbox"
fiRepository: flashinfer-ai/flashinfer
Length of output: 1078
Remove unused template imports from the legacy builder.
The names Const, Scalar, Tensor, TraceTemplate, and Var imported at lines 105–109 are not referenced anywhere in build_fi_trace_fn. Ruff F401 violations are already flagged on these lines.
🧰 Tools
🪛 GitHub Actions: pre-commit
[error] 105-105: ruff (F401): .trace.template.Const imported but unused
[error] 106-106: ruff (F401): .trace.template.Scalar imported but unused
[error] 107-107: ruff (F401): .trace.template.Tensor imported but unused
[error] 108-108: ruff (F401): .trace.template.TraceTemplate imported but unused
[error] 109-109: ruff (F401): .trace.template.Var imported but unused
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/fi_trace.py` around lines 103 - 110, The import line bringing in
Const, Scalar, Tensor, TraceTemplate, and Var from .trace.template is unused in
build_fi_trace_fn and causing Ruff F401 warnings; remove those five names (or
the whole legacy import if nothing else from that module is used) so only needed
symbols remain imported in flashinfer/fi_trace.py and eliminate the unused
imports Const, Scalar, Tensor, TraceTemplate, Var from the import statement that
currently appears alongside build_fi_trace_fn.
| def fi_trace( | ||
| func_or_method: Callable, | ||
| save_dir: Optional[Union[str, Path]] = None, | ||
| **kwargs: Any, | ||
| ) -> Dict[str, Any]: | ||
| """Generate a flashinfer-bench definition JSON for any FlashInfer API call. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| func_or_method: | ||
| A ``@flashinfer_api``-decorated function or (bound) method. | ||
| save_dir: | ||
| Directory where the JSON definition file should be written. | ||
| Falls back to ``FLASHINFER_TRACE_DUMP_DIR`` env-var when *None*. | ||
| **kwargs: | ||
| The same tensor arguments you would pass to the real API. | ||
|
|
||
| Returns | ||
| ------- | ||
| dict | ||
| A flashinfer-bench compatible definition dictionary. | ||
|
|
||
| Examples | ||
| -------- | ||
| Standalone function:: | ||
|
|
||
| defn = fi_trace(flashinfer.norm.rmsnorm, input=hidden, weight=weight) | ||
|
|
||
| Bound method (instance.run):: | ||
|
|
||
| defn = fi_trace(wrapper.run, q=q_tensor, paged_kv_cache=(k, v)) | ||
|
|
||
| Class-level (unbound):: | ||
|
|
||
| defn = fi_trace( | ||
| flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper.run, | ||
| q=q_tensor, paged_kv_cache=(k, v), | ||
| ) | ||
| """ | ||
| actual_func = getattr(func_or_method, "__func__", func_or_method) | ||
| trace_fn = getattr(actual_func, "fi_trace", None) | ||
| if trace_fn is None: | ||
| qualname = getattr(actual_func, "__qualname__", repr(actual_func)) | ||
| raise ValueError( | ||
| f"No fi_trace spec is registered for '{qualname}'. " | ||
| "Only @flashinfer_api(trace=...)-decorated functions support fi_trace." | ||
| ) | ||
| return trace_fn(save_dir=save_dir, **kwargs) |
There was a problem hiding this comment.
🧩 Analysis chain
🏁 Script executed:
# First, let's find and examine the fi_trace.py file
find . -name "fi_trace.py" -type fRepository: flashinfer-ai/flashinfer
Length of output: 92
🏁 Script executed:
# Count lines in the file to understand its size
wc -l flashinfer/fi_trace.pyRepository: flashinfer-ai/flashinfer
Length of output: 94
🏁 Script executed:
# Read the function in question with context
sed -n '230,295p' flashinfer/fi_trace.py | cat -nRepository: flashinfer-ai/flashinfer
Length of output: 2241
🏁 Script executed:
# Check the imports at the top of the file to see what type annotations are available
head -40 flashinfer/fi_trace.py | cat -nRepository: flashinfer-ai/flashinfer
Length of output: 1954
🏁 Script executed:
# Search for any Protocol definitions or type hints related to fi_trace in the codebase
rg "Protocol|fi_trace" flashinfer/*.py --type py -B 2 -A 2 | head -100Repository: flashinfer-ai/flashinfer
Length of output: 5252
🏁 Script executed:
# Check if there's a pyproject.toml or setup.cfg with mypy configuration
find . -name "pyproject.toml" -o -name "setup.cfg" -o -name "mypy.ini" | head -5Repository: flashinfer-ai/flashinfer
Length of output: 156
🏁 Script executed:
# Read full imports section of fi_trace.py
head -80 flashinfer/fi_trace.py | cat -nRepository: flashinfer-ai/flashinfer
Length of output: 3516
🏁 Script executed:
# Check pyproject.toml for mypy configuration
cat pyproject.toml | grep -A 20 "tool.mypy" | head -30Repository: flashinfer-ai/flashinfer
Length of output: 429
🏁 Script executed:
# Search for existing Protocol definitions in the codebase
rg "from typing import|Protocol" flashinfer/*.py --type py -B 1 -A 1 | head -80Repository: flashinfer-ai/flashinfer
Length of output: 2994
🏁 Script executed:
# Check if there are any type-related issues or comments in fi_trace.py
rg "type:|Protocol|HasAttr" flashinfer/fi_trace.py --type py -B 2 -A 2Repository: flashinfer-ai/flashinfer
Length of output: 50
🏁 Script executed:
# Verify the exact issue by looking at the complete function implementation once more
sed -n '238,256p' flashinfer/fi_trace.py | cat -nRepository: flashinfer-ai/flashinfer
Length of output: 792
🏁 Script executed:
# Check if there's a typing.Protocol or typing_extensions.Protocol import anywhere
rg "^from typing|^import typing" flashinfer/fi_trace.pyRepository: flashinfer-ai/flashinfer
Length of output: 123
Make the .fi_trace attribute requirement explicit in the type signature.
The function expects func_or_method (or its __func__) to have a .fi_trace attribute, but the signature declares plain Callable. Define a Protocol for this contract (e.g., class TracedCallable(Protocol): def fi_trace(...) -> Dict[str, Any]: ...) or use a cast when resolving the actual function. This will satisfy mypy and make the requirement clear to callers.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/fi_trace.py` around lines 238 - 285, The function fi_trace
currently types func_or_method as Callable but relies on the object
(actual_func) exposing a .fi_trace attribute; update the typing to make that
contract explicit by introducing a Protocol (e.g., TracedCallable with a
fi_trace(self, save_dir: Optional[Union[str, Path]] = None, **kwargs) ->
Dict[str, Any]) and use that Protocol as the type for func_or_method (or cast
actual_func to TracedCallable before accessing .fi_trace); ensure the Protocol
signature matches how trace_fn is called in fi_trace and import typing.Protocol
and any necessary types so mypy recognizes the requirement.
| "initial_state": Tensor( | ||
| ["pool_size", "num_v_heads", "head_size", "head_size"], | ||
| description="Initial recurrent state pool in k-last layout [pool_size, H, V, K].", | ||
| ), | ||
| "initial_state_indices": Tensor( | ||
| ["batch_size"], | ||
| description="Indices mapping each batch to its initial state in the pool.", | ||
| ), | ||
| "A_log": Tensor( | ||
| ["num_v_heads"], | ||
| description="Log decay parameter (learnable). Used to compute g = exp(-exp(A_log) * softplus(a + dt_bias)).", | ||
| ), | ||
| "a": Tensor( | ||
| ["batch_size", "seq_len", "num_v_heads"], | ||
| description="Input-dependent decay from projection.", | ||
| ), | ||
| "dt_bias": Tensor( | ||
| ["num_v_heads"], | ||
| description="Decay bias (learnable). Added to 'a' before softplus.", | ||
| ), | ||
| "b": Tensor( | ||
| ["batch_size", "seq_len", "num_v_heads"], | ||
| description="Update gate input from projection. beta = sigmoid(b).", | ||
| ), | ||
| "scale": Scalar( | ||
| "float32", | ||
| description="Scale factor. Default is 1/sqrt(head_size).", | ||
| ), | ||
| "intermediate_states_buffer": Tensor( | ||
| ["pool_size", "seq_len", "num_v_heads", "head_size", "head_size"], | ||
| optional=True, | ||
| description="Optional buffer for caching intermediate states for potential rollback.", | ||
| ), | ||
| }, | ||
| outputs={ | ||
| "output": Tensor( | ||
| ["batch_size", "seq_len", "num_v_heads", "head_size"], | ||
| dtype_from="q", | ||
| description="Attention output for all T tokens. Shape follows num_v_heads in GVA mode.", | ||
| ), | ||
| "final_state": Tensor( | ||
| ["pool_size", "num_v_heads", "head_size", "head_size"], | ||
| dtype="float32", | ||
| description="Updated recurrent state pool in k-last layout [pool_size, H, V, K]. Unchanged if disable_state_update=True.", | ||
| ), |
There was a problem hiding this comment.
🧩 Analysis chain
🏁 Script executed:
# First, find the gated_delta_rule_mtp function definition
rg "def gated_delta_rule_mtp" -A 20Repository: flashinfer-ai/flashinfer
Length of output: 3167
🏁 Script executed:
# Also look at the template file to understand its full context
fd "gdn.py" --type fRepository: flashinfer-ai/flashinfer
Length of output: 123
🏁 Script executed:
# Check for references to disable_state_update in the codebase
rg "disable_state_update" -A 2 -B 2Repository: flashinfer-ai/flashinfer
Length of output: 49287
🏁 Script executed:
# Get the full template definition including inputs
sed -n '450,550p' flashinfer/trace/templates/gdn.pyRepository: flashinfer-ai/flashinfer
Length of output: 4136
🏁 Script executed:
# Look for where disable_state_update might be defined elsewhere in the template
rg "disable_state_update" flashinfer/trace/templates/gdn.py -B 5 -A 5Repository: flashinfer-ai/flashinfer
Length of output: 558
🏁 Script executed:
# Check if there are other GDN-related templates that might model disable_state_update
fd "gdn" flashinfer/trace/templates/ -type fRepository: flashinfer-ai/flashinfer
Length of output: 238
🏁 Script executed:
# Look for all trace templates in the gdn area
find flashinfer/trace/templates -name "*gdn*" -o -name "*delta*"Repository: flashinfer-ai/flashinfer
Length of output: 101
🏁 Script executed:
# Check what trace templates exist
ls -la flashinfer/trace/templates/Repository: flashinfer-ai/flashinfer
Length of output: 598
🏁 Script executed:
# Check the gdn_decode.py function signature to see all parameters
sed -n '1,150p' flashinfer/gdn_decode.py | grep -A 50 "def gated_delta_rule_mtp"Repository: flashinfer-ai/flashinfer
Length of output: 50
🏁 Script executed:
# Get the exact parameter list from gdn_decode.py
grep -A 15 "def gated_delta_rule_mtp" flashinfer/gdn_decode.py | head -20Repository: flashinfer-ai/flashinfer
Length of output: 586
Add disable_state_update as a template input.
The gated_delta_rule_mtp() function supports a disable_state_update parameter that controls whether final_state is updated, yet the template does not expose it as an input. While the output description correctly mentions this behavior, the schema omission causes fi_trace to emit identical specifications for both state-updating and non-updating modes.
Add disable_state_update as a boolean input (optional or required, as per the function's design) to accurately model the two distinct operational modes.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/trace/templates/gdn.py` around lines 502 - 546, The template
schema is missing the disable_state_update input required by
gated_delta_rule_mtp, so add a boolean Tensor/Scalar entry named
"disable_state_update" to the inputs dict (matching how other flags are
represented) and mark it optional or required consistent with
gated_delta_rule_mtp's signature; ensure you reference the symbol name
"disable_state_update" and update the inputs block near the existing
"initial_state"/"final_state" entries so the trace can distinguish
state-updating vs non-updating behavior described in "final_state".
| def test_fi_trace_complete_gqa_paged_decode(): | ||
| """GQA paged decode: tuple paged_kv_cache input handled correctly.""" | ||
| from flashinfer.decode import BatchDecodeWithPagedKVCacheWrapper | ||
| from flashinfer.trace.templates.attention import gqa_paged_decode_trace | ||
|
|
||
| B, H, KV, D, P, NP = 4, 8, 4, 64, 16, 8 | ||
| q = torch.zeros(B, H, D, dtype=torch.bfloat16) | ||
| k = torch.zeros(NP, P, KV, D, dtype=torch.bfloat16) | ||
| v = torch.zeros(NP, P, KV, D, dtype=torch.bfloat16) | ||
|
|
||
| defn = BatchDecodeWithPagedKVCacheWrapper.run.fi_trace(q=q, paged_kv_cache=(k, v)) |
There was a problem hiding this comment.
Remove unused import gqa_paged_decode_trace.
The import gqa_paged_decode_trace at line 387 is flagged as unused by pre-commit. The test only uses BatchDecodeWithPagedKVCacheWrapper.run.fi_trace.
Suggested fix
def test_fi_trace_complete_gqa_paged_decode():
"""GQA paged decode: tuple paged_kv_cache input handled correctly."""
from flashinfer.decode import BatchDecodeWithPagedKVCacheWrapper
- from flashinfer.trace.templates.attention import gqa_paged_decode_trace
B, H, KV, D, P, NP = 4, 8, 4, 64, 16, 8📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| def test_fi_trace_complete_gqa_paged_decode(): | |
| """GQA paged decode: tuple paged_kv_cache input handled correctly.""" | |
| from flashinfer.decode import BatchDecodeWithPagedKVCacheWrapper | |
| from flashinfer.trace.templates.attention import gqa_paged_decode_trace | |
| B, H, KV, D, P, NP = 4, 8, 4, 64, 16, 8 | |
| q = torch.zeros(B, H, D, dtype=torch.bfloat16) | |
| k = torch.zeros(NP, P, KV, D, dtype=torch.bfloat16) | |
| v = torch.zeros(NP, P, KV, D, dtype=torch.bfloat16) | |
| defn = BatchDecodeWithPagedKVCacheWrapper.run.fi_trace(q=q, paged_kv_cache=(k, v)) | |
| def test_fi_trace_complete_gqa_paged_decode(): | |
| """GQA paged decode: tuple paged_kv_cache input handled correctly.""" | |
| from flashinfer.decode import BatchDecodeWithPagedKVCacheWrapper | |
| B, H, KV, D, P, NP = 4, 8, 4, 64, 16, 8 | |
| q = torch.zeros(B, H, D, dtype=torch.bfloat16) | |
| k = torch.zeros(NP, P, KV, D, dtype=torch.bfloat16) | |
| v = torch.zeros(NP, P, KV, D, dtype=torch.bfloat16) | |
| defn = BatchDecodeWithPagedKVCacheWrapper.run.fi_trace(q=q, paged_kv_cache=(k, v)) |
🧰 Tools
🪛 GitHub Actions: pre-commit
[error] 387-387: ruff (F401): flashinfer.trace.templates.attention.gqa_paged_decode_trace imported but unused
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/trace/test_fi_trace_template_consistency.py` around lines 384 - 394,
Remove the unused import gqa_paged_decode_trace from the test; locate the import
statement that reads "from flashinfer.trace.templates.attention import
gqa_paged_decode_trace" and delete it so the test only imports and uses
BatchDecodeWithPagedKVCacheWrapper.run.fi_trace (ensure no other references to
gqa_paged_decode_trace remain in the file).
There was a problem hiding this comment.
Actionable comments posted: 1
♻️ Duplicate comments (3)
flashinfer/api_logging.py (1)
1510-1531:⚠️ Potential issue | 🟠 Major
FLASHINFER_LOGLEVEL=0no longer preserves the zero-overhead path for traced APIs.For any
@flashinfer_api(trace=...)function, this still returns_attach_fi_trace(...), and_attach_fi_trace()always builds_auto_dump_wrapper. That means every call pays an extra Python frame and_is_trace_dump_enabled()check even when logging is disabled, which contradicts the decorator's documented “returns original function” contract.Also applies to: 1629-1634
tests/trace/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps64.json (1)
126-126:⚠️ Potential issue | 🟠 MajorFlatten the paged cache before the reference matmul.
The schema above says
ckv_cache/kpe_cacheare[num_pages, page_size, head_dim_*], so withpage_size=64thesqueeze(1)calls do nothing.Kc_all[tok_idx]/Kp_all[tok_idx]therefore stay 3D, and the laterqn @ Kc.Tpath no longer matches the intended[num_qo_heads, L]attention score computation. Reshape the caches to token-major 2D tensors before indexing, or rewrite the reference to handle paged tensors directly.For PyTorch 2.x, if ckv_cache has shape [num_pages, page_size, head_dim], what does ckv_cache.squeeze(1) return when page_size=64, and what shape does qn @ Kc.T use when qn is [num_qo_heads, head_dim] and Kc is [L, 64, head_dim]?🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/trace/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps64.json` at line 126, The reference incorrectly uses ckv_cache.squeeze(1) / kpe_cache.squeeze(1) which leaves a 3D paged tensor when page_size>1; in _mla_paged_decode_reference you must flatten the page-major caches to token-major 2D tensors (num_pages*page_size, head_dim_ckv/kpe) before indexing (i.e., compute Kc_all and Kp_all as [num_tokens, head_dim_*] rather than [num_pages, page_size, head_dim_*]), then select Kc/Kp with tok_idx so that qn @ Kc.T and qp @ Kp.T produce [num_qo_heads, L] logits; update the Kc_all/Kp_all creation near their assignments and ensure subsequent uses (Kc, Kp, logits, output) operate on the flattened shapes.tests/trace/example.py (1)
54-551:⚠️ Potential issue | 🟠 MajorPytest still won't execute this trace generator.
tests/trace/example.pyis still a standalone script with top-level side effects and notest_*entrypoint, so CI won't collect it or validate the generated fixtures. Please move the body into a test/helper and keep a__main__guard only for manual runs.As per coding guidelines,
tests/**/*.py: Prefix test functions withtest_and structure tests by feature intests/subdirectories matching kernel categories.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/trace/example.py` around lines 54 - 551, The file currently executes trace-generation at import (top-level side effects: SAVE_DIR, the whole sequence of flashinfer.* calls, and the final files/print summary), so move the entire body into a callable function (e.g. generate_fi_traces or build_example_traces) that encapsulates planning/running wrappers and the final JSON-summary logic, then add a pytest entrypoint test_example_traces() in the same module (or a new tests/ submodule) that calls that function and asserts expected output (e.g. presence/count of files from SAVE_DIR or that no exceptions occur), and retain an if __name__ == "__main__": guard to call generate_fi_traces() for manual runs; reference SAVE_DIR, the trace-generation sequence (all flashinfer.* calls and wrapper usages like BatchDecodeWithPagedKVCacheWrapper, BatchPrefillWithPagedKVCacheWrapper, BatchPrefillWithRaggedKVCacheWrapper, BatchMLAPagedAttentionWrapper, and the files variable/summary) when making the changes.
🧹 Nitpick comments (8)
flashinfer/trace/templates/attention.py (3)
28-41: Prefix unused unpacked variables with underscore.
page_sizeis unpacked fromk_cache.shapebut never used in the function. Prefix with_to satisfy linter.Proposed fix
def _gqa_paged_decode_reference(q, k_cache, v_cache, kv_indptr, kv_indices, sm_scale): batch_size, num_qo_heads, head_dim = q.shape - _, page_size, num_kv_heads, _ = k_cache.shape + _, _page_size, num_kv_heads, _ = k_cache.shape🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/templates/attention.py` around lines 28 - 41, The variable page_size unpacked in _gqa_paged_decode_reference is unused and should be prefixed with an underscore to satisfy the linter; change the unpacking from "_, page_size, num_kv_heads, _ = k_cache.shape" style to use "_page_size" (or simply "_" if preferred) so the function signature still captures batch dimensions but removes the unused symbol while keeping references to q, k_cache, v_cache, kv_indptr, kv_indices, and sm_scale intact.
244-248: Prefix unused unpacked variable with underscore.
total_kvis unpacked but never used. Prefix with_to satisfy linter.Proposed fix
def _gqa_ragged_prefill_reference(q, k, v, qo_indptr, kv_indptr, sm_scale): total_q, num_qo_heads, head_dim = q.shape - total_kv, num_kv_heads, _ = k.shape + _total_kv, num_kv_heads, _ = k.shape len_indptr = qo_indptr.shape[0]🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/templates/attention.py` around lines 244 - 248, In _gqa_ragged_prefill_reference, the unpacking assigns total_kv from k.shape but it's unused; change the unpacked name to _total_kv (or prefix with underscore) in the q, k, v shape assignment so the linter recognizes it as intentionally unused (i.e., update the tuple unpack on the line with "total_q, num_qo_heads, head_dim = q.shape" / "total_kv, num_kv_heads, _ = k.shape" to use _total_kv).
125-144: Prefix unused unpacked variables with underscore.Both
num_pagesandpage_sizeare unpacked but never used. This triggers ruff RUF059.Proposed fix
def _gqa_paged_prefill_reference( q, k_cache, v_cache, qo_indptr, kv_indptr, kv_indices, sm_scale ): total_q, num_qo_heads, head_dim = q.shape - num_pages, page_size, num_kv_heads, _ = k_cache.shape + _num_pages, _page_size, num_kv_heads, _ = k_cache.shape len_indptr = qo_indptr.shape[0]🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/templates/attention.py` around lines 125 - 144, The variables num_pages and page_size are unpacked in _gqa_paged_prefill_reference but never used, triggering ruff RUF059; update the tuple unpacking to prefix unused names with an underscore (e.g., _num_pages, _page_size or simply _, _) in the line that unpacks k_cache.shape so the intent is clear and the linter warning is resolved while leaving k_cache usage unchanged.tests/trace/test_fi_trace_template_consistency.py (3)
440-466: Consider renamingItoINTERMEDIATEorINTER_SIZEfor clarity.The variable name
Iis flagged as ambiguous (E741) because it can be confused with1orl. The same applies to line 488.Proposed fix
- T, E, EL, H, I, BS = 4, 16, 2, 256, 64, 128 + T, E, EL, H, INTER, BS = 4, 16, 2, 256, 64, 128 defn = trtllm_fp8_block_scale_moe.fi_trace( routing_logits=torch.zeros(T, E, dtype=torch.float32), routing_bias=torch.zeros(E, dtype=torch.bfloat16), hidden_states=torch.zeros(T, H, dtype=torch.float8_e4m3fn), hidden_states_scale=torch.ones(H // BS, T, dtype=torch.float32), - gemm1_weights=torch.zeros(EL, 2 * I, H, dtype=torch.float8_e4m3fn), - gemm1_weights_scale=torch.ones(EL, (2 * I) // BS, H // BS, dtype=torch.float32), - gemm2_weights=torch.zeros(EL, H, I, dtype=torch.float8_e4m3fn), - gemm2_weights_scale=torch.ones(EL, H // BS, I // BS, dtype=torch.float32), + gemm1_weights=torch.zeros(EL, 2 * INTER, H, dtype=torch.float8_e4m3fn), + gemm1_weights_scale=torch.ones(EL, (2 * INTER) // BS, H // BS, dtype=torch.float32), + gemm2_weights=torch.zeros(EL, H, INTER, dtype=torch.float8_e4m3fn), + gemm2_weights_scale=torch.ones(EL, H // BS, INTER // BS, dtype=torch.float32), num_experts=E, top_k=top_k, - intermediate_size=I, + intermediate_size=INTER,🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/trace/test_fi_trace_template_consistency.py` around lines 440 - 466, Rename the ambiguous single-letter variable I to a clearer name like INTERMEDIATE or INTER_SIZE throughout the test (e.g., the variable declaration T, E, EL, H, INTERMEDIATE, BS = ... and all uses: gemm1_weights shape (EL, 2 * INTERMEDIATE, H), gemm2_weights shape (EL, H, INTERMEDIATE), and intermediate_size=INTERMEDIATE in the trtllm_fp8_block_scale_moe.fi_trace(...) call) and similarly update any other occurrence on the nearby line 488 so all references remain consistent.
374-376: Rename ambiguous variableltolabelfor clarity.The single-letter
lcan be confused with1orI. Use a more descriptive name.Proposed fix
-_E2E_PAIRS = [(f, t, l) for f, t, l in _ALL_PAIRS if l not in _E2E_SKIP] -_E2E_IDS = [label for _, _, label in _E2E_PAIRS] +_E2E_PAIRS = [(f, t, label) for f, t, label in _ALL_PAIRS if label not in _E2E_SKIP] +_E2E_IDS = [lbl for _, _, lbl in _E2E_PAIRS]🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/trace/test_fi_trace_template_consistency.py` around lines 374 - 376, Rename the ambiguous single-letter variable in the list comprehension: change the unpacking in _E2E_PAIRS from (f, t, l) to (f, t, label) and update the filter to use label instead of l; also update _E2E_IDS to unpack/use the same name (e.g., [label for _, _, label in _E2E_PAIRS]) so all references use the descriptive symbol label while keeping the existing logic with _ALL_PAIRS and _E2E_SKIP.
560-563: Use raw string for regex pattern with metacharacters.The pattern contains regex metacharacters (
=,*,.) but is not a raw string. While it works due to no escape conflicts, usingr"..."is safer and clearer.Proposed fix
- with pytest.raises(AssertionError, match="param=.*hidden_state.*not found"): + with pytest.raises(AssertionError, match=r"param=.*hidden_state.*not found"): assert_template_signature_consistency(func, broken, label="meta-test")🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/trace/test_fi_trace_template_consistency.py` around lines 560 - 563, The regex string passed to pytest.raises in the test (match="param=.*hidden_state.*not found") uses metacharacters but isn't a raw string; change it to a raw string (e.g., r"param=.*hidden_state.*not found") in the pytest.raises call that wraps assert_template_signature_consistency(func, broken, label="meta-test") to ensure backslashes and metacharacters are interpreted correctly; update the test invocation around _make_gdn_decode_func(), func, and broken accordingly.flashinfer/trace/templates/moe.py (2)
674-678: Replace ambiguous×(multiplication sign) with ASCIIx.The Unicode multiplication sign
×(U+00D7) can cause confusion. Use ASCIIxor*instead.Proposed fix
"gemm1_out_size": Const( - description="Output size of FC1 (2 × intermediate_size for SwiGLU).", + description="Output size of FC1 (2 * intermediate_size for SwiGLU).", abbrev="", ),🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/templates/moe.py` around lines 674 - 678, The description string for the Const named "gemm1_out_size" contains a Unicode multiplication sign (`×`); update the description in the gemm1_out_size Const to use an ASCII "x" (or "*" if you prefer) instead (e.g., change "2 × intermediate_size for SwiGLU" to "2 x intermediate_size for SwiGLU") so the comment uses plain ASCII characters.
795-806: FP4 MoE templates cannot be validated against reference implementations.All FP4 templates pass
reference=Nonebecause the_make_standard_fp4_moe_tracefactory does not accept a reference parameter (unlike the FP8 factory). No FP4 MoE reference implementations are defined. Given that FP4 templates are marked asstatus:experimental, either implement reference functions or document why validation is deferred.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/trace/templates/moe.py` around lines 795 - 806, The FP4 MoE factory _make_standard_fp4_moe_trace currently hardcodes reference=None so FP4 templates cannot be validated; update the factory signature to accept an optional reference parameter (e.g., reference=None) and pass that through into TraceTemplate(reference=reference), then update all call sites that construct FP4 MoE traces to provide a proper reference function or explicitly pass None with a comment; additionally, either implement the missing FP4 MoE reference functions (and register them where other references live) or add clear documentation in the template module explaining that FP4 validation is deferred and why.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/api_logging.py`:
- Around line 1521-1527: The trace auto-dump currently swallows all exceptions
in the block guarded by _is_trace_dump_enabled(), making failures invisible;
change the except Exception: pass to catch Exception as e and emit a non-fatal
log (e.g., processLogger.warning or module logger) that includes the failing
trace function name (use fi_trace_fn.__name__ or _sig.signature info) and the
exception information (e) or traceback, then continue; update the try/except
around _sig.bind(*args, **kwargs), bound.apply_defaults(), and
fi_trace_fn(**dict(bound.arguments)) to log the diagnostic while keeping the
call non-fatal.
---
Duplicate comments:
In `@tests/trace/example.py`:
- Around line 54-551: The file currently executes trace-generation at import
(top-level side effects: SAVE_DIR, the whole sequence of flashinfer.* calls, and
the final files/print summary), so move the entire body into a callable function
(e.g. generate_fi_traces or build_example_traces) that encapsulates
planning/running wrappers and the final JSON-summary logic, then add a pytest
entrypoint test_example_traces() in the same module (or a new tests/ submodule)
that calls that function and asserts expected output (e.g. presence/count of
files from SAVE_DIR or that no exceptions occur), and retain an if __name__ ==
"__main__": guard to call generate_fi_traces() for manual runs; reference
SAVE_DIR, the trace-generation sequence (all flashinfer.* calls and wrapper
usages like BatchDecodeWithPagedKVCacheWrapper,
BatchPrefillWithPagedKVCacheWrapper, BatchPrefillWithRaggedKVCacheWrapper,
BatchMLAPagedAttentionWrapper, and the files variable/summary) when making the
changes.
In `@tests/trace/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps64.json`:
- Line 126: The reference incorrectly uses ckv_cache.squeeze(1) /
kpe_cache.squeeze(1) which leaves a 3D paged tensor when page_size>1; in
_mla_paged_decode_reference you must flatten the page-major caches to
token-major 2D tensors (num_pages*page_size, head_dim_ckv/kpe) before indexing
(i.e., compute Kc_all and Kp_all as [num_tokens, head_dim_*] rather than
[num_pages, page_size, head_dim_*]), then select Kc/Kp with tok_idx so that qn @
Kc.T and qp @ Kp.T produce [num_qo_heads, L] logits; update the Kc_all/Kp_all
creation near their assignments and ensure subsequent uses (Kc, Kp, logits,
output) operate on the flattened shapes.
---
Nitpick comments:
In `@flashinfer/trace/templates/attention.py`:
- Around line 28-41: The variable page_size unpacked in
_gqa_paged_decode_reference is unused and should be prefixed with an underscore
to satisfy the linter; change the unpacking from "_, page_size, num_kv_heads, _
= k_cache.shape" style to use "_page_size" (or simply "_" if preferred) so the
function signature still captures batch dimensions but removes the unused symbol
while keeping references to q, k_cache, v_cache, kv_indptr, kv_indices, and
sm_scale intact.
- Around line 244-248: In _gqa_ragged_prefill_reference, the unpacking assigns
total_kv from k.shape but it's unused; change the unpacked name to _total_kv (or
prefix with underscore) in the q, k, v shape assignment so the linter recognizes
it as intentionally unused (i.e., update the tuple unpack on the line with
"total_q, num_qo_heads, head_dim = q.shape" / "total_kv, num_kv_heads, _ =
k.shape" to use _total_kv).
- Around line 125-144: The variables num_pages and page_size are unpacked in
_gqa_paged_prefill_reference but never used, triggering ruff RUF059; update the
tuple unpacking to prefix unused names with an underscore (e.g., _num_pages,
_page_size or simply _, _) in the line that unpacks k_cache.shape so the intent
is clear and the linter warning is resolved while leaving k_cache usage
unchanged.
In `@flashinfer/trace/templates/moe.py`:
- Around line 674-678: The description string for the Const named
"gemm1_out_size" contains a Unicode multiplication sign (`×`); update the
description in the gemm1_out_size Const to use an ASCII "x" (or "*" if you
prefer) instead (e.g., change "2 × intermediate_size for SwiGLU" to "2 x
intermediate_size for SwiGLU") so the comment uses plain ASCII characters.
- Around line 795-806: The FP4 MoE factory _make_standard_fp4_moe_trace
currently hardcodes reference=None so FP4 templates cannot be validated; update
the factory signature to accept an optional reference parameter (e.g.,
reference=None) and pass that through into TraceTemplate(reference=reference),
then update all call sites that construct FP4 MoE traces to provide a proper
reference function or explicitly pass None with a comment; additionally, either
implement the missing FP4 MoE reference functions (and register them where other
references live) or add clear documentation in the template module explaining
that FP4 validation is deferred and why.
In `@tests/trace/test_fi_trace_template_consistency.py`:
- Around line 440-466: Rename the ambiguous single-letter variable I to a
clearer name like INTERMEDIATE or INTER_SIZE throughout the test (e.g., the
variable declaration T, E, EL, H, INTERMEDIATE, BS = ... and all uses:
gemm1_weights shape (EL, 2 * INTERMEDIATE, H), gemm2_weights shape (EL, H,
INTERMEDIATE), and intermediate_size=INTERMEDIATE in the
trtllm_fp8_block_scale_moe.fi_trace(...) call) and similarly update any other
occurrence on the nearby line 488 so all references remain consistent.
- Around line 374-376: Rename the ambiguous single-letter variable in the list
comprehension: change the unpacking in _E2E_PAIRS from (f, t, l) to (f, t,
label) and update the filter to use label instead of l; also update _E2E_IDS to
unpack/use the same name (e.g., [label for _, _, label in _E2E_PAIRS]) so all
references use the descriptive symbol label while keeping the existing logic
with _ALL_PAIRS and _E2E_SKIP.
- Around line 560-563: The regex string passed to pytest.raises in the test
(match="param=.*hidden_state.*not found") uses metacharacters but isn't a raw
string; change it to a raw string (e.g., r"param=.*hidden_state.*not found") in
the pytest.raises call that wraps assert_template_signature_consistency(func,
broken, label="meta-test") to ensure backslashes and metacharacters are
interpreted correctly; update the test invocation around
_make_gdn_decode_func(), func, and broken accordingly.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 4aacac64-d181-4f0b-a266-998fd025caf7
📥 Commits
Reviewing files that changed from the base of the PR and between 2f4aceb637fabe2f75c180b8086e97f602b147d0 and c2843a5e4a5c38d1cac9cf32c29db34aa7935f82.
📒 Files selected for processing (22)
flashinfer/api_logging.pyflashinfer/fi_trace.pyflashinfer/fused_moe/core.pyflashinfer/trace/templates/attention.pyflashinfer/trace/templates/moe.pytests/trace/example.pytests/trace/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps1.jsontests/trace/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps64.jsontests/trace/fi_trace_out/moe_fp4_block_scale_default_routing_topk8_e32_h7168_i2048.jsontests/trace/fi_trace_out/moe_fp4_block_scale_ds_routing_topk8_e32_h7168_i2048_ng8_kg4.jsontests/trace/fi_trace_out/moe_fp4_block_scale_llama4_routing_topk1_e32_h7168_i2048.jsontests/trace/fi_trace_out/moe_fp4_block_scale_renormalize_naive_routing_topk8_e32_h7168_i2048.jsontests/trace/fi_trace_out/moe_fp4_block_scale_renormalize_routing_topk8_e32_h7168_i2048.jsontests/trace/fi_trace_out/moe_fp4_block_scale_topk_routing_topk8_e32_h7168_i2048.jsontests/trace/fi_trace_out/moe_fp8_block_scale_default_routing_topk8_e32_h7168_i2048.jsontests/trace/fi_trace_out/moe_fp8_block_scale_ds_routing_topk8_ng8_kg4_e32_h7168_i2048.jsontests/trace/fi_trace_out/moe_fp8_block_scale_llama4_routing_topk1_e32_h7168_i2048.jsontests/trace/fi_trace_out/moe_fp8_block_scale_renormalize_naive_routing_topk8_e32_h7168_i2048.jsontests/trace/fi_trace_out/moe_fp8_block_scale_renormalize_routing_topk8_e32_h7168_i2048.jsontests/trace/fi_trace_out/moe_fp8_block_scale_topk_routing_topk8_e32_h7168_i2048.jsontests/trace/test_fi_trace.pytests/trace/test_fi_trace_template_consistency.py
✅ Files skipped from review due to trivial changes (10)
- tests/trace/fi_trace_out/moe_fp4_block_scale_ds_routing_topk8_e32_h7168_i2048_ng8_kg4.json
- tests/trace/fi_trace_out/moe_fp8_block_scale_renormalize_routing_topk8_e32_h7168_i2048.json
- tests/trace/fi_trace_out/moe_fp4_block_scale_renormalize_routing_topk8_e32_h7168_i2048.json
- tests/trace/fi_trace_out/moe_fp4_block_scale_llama4_routing_topk1_e32_h7168_i2048.json
- tests/trace/fi_trace_out/moe_fp8_block_scale_default_routing_topk8_e32_h7168_i2048.json
- tests/trace/fi_trace_out/moe_fp8_block_scale_renormalize_naive_routing_topk8_e32_h7168_i2048.json
- tests/trace/fi_trace_out/moe_fp8_block_scale_topk_routing_topk8_e32_h7168_i2048.json
- tests/trace/fi_trace_out/mla_paged_decode_h16_ckv512_kpe64_ps1.json
- tests/trace/fi_trace_out/moe_fp4_block_scale_renormalize_naive_routing_topk8_e32_h7168_i2048.json
- tests/trace/fi_trace_out/moe_fp4_block_scale_default_routing_topk8_e32_h7168_i2048.json
🚧 Files skipped from review as they are similar to previous changes (1)
- flashinfer/fused_moe/core.py
Three CI failures on SM12x (Blackwell refresh):
1. test_trtllm_batch_context (and test_trtllm_batch_decode): the
TllmGenFmhaRunner kernel is only instantiated for SM100/SM103 and
raises "Unsupported architecture" at runtime on SM12x. Our prior
guard was _skip_if_not_sm100() which only rejects pre-Blackwell
GPUs (major < 10). Add _skip_if_not_sm100_or_103() which rejects
anything except (10, 0) and (10, 3), and apply it to both tests.
2. test_tgv_gemm_sm100: the kernel explicitly checks _match_sm_version
for "100" or "103" (see flashinfer/gemm/gemm_base.py:1140) and
raises ValueError on SM12x. Same narrowed guard applied.
3. test_xqa_mla: on SM12x the kernel runs correctly but ~0.1% of
output elements exceed the element-wise bf16 tolerance due to the
kernel's internal FP8 quantization of Q and the KV cache. Switch
to cosine similarity (cos_sim > 0.99), which is the standard FP8
correctness metric elsewhere in the repo (tests/gemm/test_mm_fp8.py).
Added a _close_fp8 helper for this.
On B200 (SM100): 56 passed, 4 skipped.
Expected on SM12x CI: 54 passed, 4 skipped, 2 extra skips for the
trtllm_batch_{decode,context} tests — no failures.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
|
/bot run |
# Conflicts: # flashinfer/attention/_core.py
CI reports test_xqa_mla_reference_correctness failing on SM12x (5090,
RTX PRO 6000 Blackwell) with cos_sim=0.4576 — the signature of a
fundamentally different softmax distribution, not a quantization error.
Root cause: unlike the regular XQA kernel (which applies
``rsqrtf(validElemsPerHead)`` to the QK product internally — see
csrc/xqa/mha.cu:1765-1767 and mha_sm90.cu:781-783), the MLA variant
does NOT. csrc/xqa/mla_sm120.cu:456 computes:
qkScaleLog2e = qScale * kvCacheScale * log2e
with no ``1/sqrt(head_dim)`` factor — the MLA kernel leaves that
scaling to the caller (it's typically absorbed into q_scale at model
load time). The existing tests/attention/test_xqa.py mirrors this by
passing ``q_scale * math.sqrt(576)`` to its reference so the two
sqrt(576) factors cancel.
The trace reference I wrote used ``sm_scale = 1/sqrt(head_dim)``
unconditionally, copying the regular-XQA pattern. On random inputs
this produced nearly uniform attention (tiny logits at scale ~0.042)
while the kernel produced sharp attention (logits at scale 1.0) — two
completely different distributions, hence cos_sim ~= 0.46.
Fix: drop the 1/sqrt(head_dim) factor in _xqa_mla_reference and
compute ``qk_scale = q_scale * kv_scale`` to match the kernel.
Also read V from ``v_cache`` (the dedicated buffer) rather than
slicing ``k_cache``, so the reference stays correct when V and K are
genuinely separate tensors (as in the existing test). The test still
passes K as V to the kernel, so the consistency check is preserved.
Verified on B200: all 354 tests/trace tests pass (xqa_mla skips, gated
to SM12x). SM12x-specific verification happens in CI.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Audit trigger: after fixing _xqa_mla_reference to not apply
1/sqrt(head_dim) (since the MLA kernel doesn't either), I audited every
other attention reference in flashinfer/trace/templates/ to confirm none
of them carry the same mistake. The audit found no additional bugs:
- _single_decode_reference / _single_prefill_reference:
default sm_scale = 1/sqrt(head_dim), matches the Python wrapper's
default (flashinfer/decode.py, flashinfer/prefill.py).
- _gqa_paged_{decode,prefill}_reference / _gqa_ragged_prefill_reference /
_mla_paged_{decode,prefill}_reference / _dsa_paged_reference:
take sm_scale as a required argument and apply it as-is; matches the
flashinfer kernel path which treats sm_scale as opaque
(include/flashinfer/attention/mla.cuh:40 —
sm_scale_log2 = params.sm_scale * log2e, no rsqrt).
- _batch_attention_run_reference / _pod_run_reference /
_block_sparse_run_reference:
default sm_scale = 1/sqrt(head_dim), matches the decode/prefill
wrappers they delegate to.
- _trtllm_paged_attention_reference / _trtllm_fmha_v2_prefill_reference /
_cudnn_batch_{decode,prefill}_reference: take scale as a required
argument (bmm1_scale / scale); the test harness always passes the
pre-computed sm_scale * q_scale * k_scale.
- _xqa_reference: applies q_scale * kv_scale * rsqrtf(head_dim)
to match the regular-XQA kernel (csrc/xqa/mha.cu:1765,
mha_sm90.cu:781) which does the rsqrt internally.
While auditing _xqa_reference I noticed its signature previously ignored
q_scale / kv_scale (they were silently dropped via **_unused) — the kernel
takes them explicitly. Thread them through so the reference stays correct
if a caller ever passes non-default scales; default remains 1.0 and the
existing test keeps passing (56 passed, 4 skipped on B200).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
|
/bot run |
…_fp8, allreduce_fusion Wires @flashinfer_api(trace=...) on four public APIs that sglang/vllm hit during DeepSeek R1 inference but that the auto-dump had previously skipped: - flashinfer.mla.trtllm_batch_decode_with_kv_cache_mla (new trtllm_batch_decode_mla_trace in attention.py; reference models MLA attention as softmax(Q_nope @ K_nope.T + Q_pe @ K_pe.T) * bmm1_scale then attn @ K_nope * bmm2_scale, returning [B, q_len, H, kv_lora_rank]) - flashinfer.rope.rope_quantize_fp8 and flashinfer.rope.mla_rope_quantize_fp8 (new rope_quantize_fp8_trace and mla_rope_quantize_fp8_trace in rope.py; reference applies RoPE to the rotary half then per-tensor FP8 e4m3 quantize, matching tests/attention/test_rope.py atol=1e-2 rtol=2e-1) - flashinfer.comm.allreduce_fusion (new comm.py template; single-rank reference models the fusion side — pattern 0 passthrough and pattern 1 AR+Residual+RMSNorm — and raises NotImplementedError for quantized / MoE patterns that are multi-rank-only) Unit-test coverage (tests/trace/test_reference_correctness.py): - test_rope_quantize_fp8_reference_correctness, test_mla_rope_quantize_fp8_reference_correctness use assert_close(atol=1e-2, rtol=2e-1) matching tests/attention/test_rope.py's rope_quantize_fp8 coverage. - test_trtllm_batch_decode_mla_reference_correctness uses (atol=1e-2, rtol=1e-2) matching tests/attention/test_cute_dsl_mla_decode.py. Also audits every closeness check in test_reference_correctness.py against the kernel's own unit test for standard compliance: - RoPE (apply_rope*): _ROPE_TOL = (1e-2, 1e-2), matches bf16 rope test - RMSNorm / LayerNorm / gemma_*: 1e-3, matches tests/utils/test_norm.py - silu_and_mul / gelu_and_mul / gelu_tanh_and_mul: 1e-3, matches tests/utils/test_activation.py - single_decode / single_prefill: 1e-2, matches tests/attention/test_single_prefill.py - trtllm_batch_decode / context / fmha_v2_prefill: 1e-2, matches tests/attention/test_fmha_v2_prefill.py - cudnn_batch_decode / prefill: 1e-2, matches tests/attention/test_cudnn_*.py - block_sparse / var_block_sparse: 1e-2, matches tests/attention/test_block_sparse.py - batch_attention / multi_level_cascade: 1e-2, matches tests/attention/test_batch_attention.py - pod_with_paged_kv_cache / batch_pod: 1e-3, matches tests/utils/test_pod_kernels.py - segment_gemm: (2e-3, 1e-3), matches tests/gemm/test_group_gemm.py - cutlass_fused_moe: 1e-2, matches tests/moe/test_trtllm_cutlass_fused_moe.py - merge_state(s) / merge_state_in_place: 1e-3 on fp16 V, matches tests/attention/test_shared_prefix_kernels.py (switched inputs to fp16) - xqa / xqa_mla: new _close_pass_ratio helper with (atol=0.05, rtol=0.05, pass_ratio=0.98/0.95) matching tests/attention/test_xqa.py's pass-ratio check instead of cosine similarity (only tgv_gemm_sm100 still uses cos_sim because tests/gemm/test_tgv_gemm.py uses cos_sim itself) - test_tgv_gemm_sm100 gated to SM100-only (cubin absent on SM103) Also updates tests/trace/example.py with example calls that trigger auto-dump for the three new traces. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…pe_quantize_append Wires @flashinfer_api(trace=...) on four DeepSeek-R1-path APIs that were previously running un-traced: - flashinfer.decode.xqa_batch_decode_with_kv_cache (new xqa_batch_decode_trace; reuses _trtllm_paged_attention_reference since the math is the same — XQA is a different backend for the same paged decode op) - flashinfer.mla.xqa_batch_decode_with_kv_cache_mla (new xqa_batch_decode_mla_trace sharing _trtllm_batch_decode_mla_reference since the math matches the trtllm-gen MLA decode) - flashinfer.concat_ops.concat_mla_k (new concat_mla_k_trace; reference broadcasts k_rope across heads and writes [k_nope ‖ k_rope] in-place — exact-copy comparison) - flashinfer.rope.rope_quantize_fp8_append_paged_kv_cache (new rope_quantize_fp8_append_paged_kv_cache_trace; reference applies RoPE → FP8 quantize then appends K/V into the paged cache. Test compares the returned Q outputs only — paged K/V append uses an implementation-specific internal layout) Unit-test coverage in tests/trace/test_reference_correctness.py: - test_concat_mla_k uses (atol=0, rtol=0) — exact copy semantics. - test_xqa_batch_decode uses (atol=1e-2, rtol=1e-2) — matches its trtllm_batch_decode sibling (both backends, same math). - test_xqa_batch_decode_mla uses _close_pass_ratio with (atol=0.05, rtol=0.05, pass_ratio=0.95), matching tests/attention/test_xqa.py for FP8 MLA outputs. Skipped on non-SM120/121 (kernel hardware-gated). - test_rope_quantize_fp8_append uses (atol=1e-2, rtol=2e-1) for Q, matching tests/attention/test_rope.py. Also fixes pre-existing tolerance mismatches uncovered by the audit: - silu_and_mul / gelu_and_mul / gelu_tanh_and_mul: switch test inputs from bf16 to fp16 (matches tests/utils/test_activation.py — bf16 ULP is 3e-2, exceeds the 1e-3 tolerance the existing test uses). example.py adds calls for all 4 new APIs to drive auto-dump on DeepSeek-R1-style runs. On B200: 377 passed, 6 skipped in tests/trace/ (4 hardware gates + xqa_batch_decode_mla SM120-only + cuDNN libcudart conflict). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Wires @flashinfer_api(trace=...) on the batched matmul APIs: - flashinfer.bmm_bf16 - flashinfer.bmm_fp8 - flashinfer.bmm_mxfp8 New TraceTemplates in flashinfer/trace/templates/gemm.py: - bmm_bf16_trace: standard batched matmul C[b] = A[b] @ B[b]. - bmm_fp8_trace: per-tensor FP8 BMM with A_scale/B_scale dequant scalars. - bmm_mxfp8_trace: MXFP8 BMM (MX block size 32) with uint8 block scales. Correctness tests use cosine similarity (cos_sim > 0.99), matching the existing tests/gemm/test_bmm_*.py convention. The bmm_fp8 test is gated to SM100/103 since the bf16 path is the only widely-available fallback. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…spose, selective_state_update Adds @flashinfer_api(trace=...) for three previously-untraced public APIs that DeepSeek-R1 / Mamba / SM100 inference paths can hit: - flashinfer.norm.fused_rmsnorm_silu (new fused_rmsnorm_silu_trace in trace/templates/norm.py; reference is SiLU(RMSNorm(x))). - flashinfer.gdn_decode.gated_delta_rule_decode_pretranspose (reuses existing gated_delta_rule_decode_trace — same math, just a pretranspose variant of the same kernel). - flashinfer.mamba.selective_state_update (new selective_state_update_trace in trace/templates/mamba.py; reference implements the SSM discrete recurrence dA = exp(dt'*A); state = state*dA + dB*x.unsqueeze(-1); y = state @ C + D*x [* silu(z)] for the single-token decode shapes). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…ek, tinygemm_bf16, fmha_v2_prefill_deepseek Wires @flashinfer_api(trace=...) on four more public APIs: - flashinfer.topk.top_k_ragged_transform — fused per-row top-k + index-rebase used in sparse attention's second stage. - flashinfer.fused_moe.fused_topk_deepseek — DeepSeek-V3 grouped expert routing (sigmoid+bias score → top-k group → top-k expert → normalize). Reference implements all 5 routing steps and writes topk_values / topk_indices in-place. - flashinfer.gemm.tinygemm_bf16 — SM90+ small-batch bf16 GEMM (out = input @ weight.T + bias). Reference is straight F.linear. - flashinfer.prefill.fmha_v2_prefill_deepseek — separate Q/K/V causal SDPA with fixed seq_len. Distinct from the (packed-QKV) trtllm_fmha_v2_prefill trace we already have. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Adds two new TraceTemplates and wires @flashinfer_api(trace=...) on: - flashinfer.cute_dsl.attention.wrappers.batch_mla.CuteDslBatchMLAPagedAttentionWrapper.run (cute_dsl_batch_mla_run_trace): same MLA decode math as trtllm_batch_decode_with_kv_cache_mla but with the CuteDSL kernel's signature (q / softmax_scale / output_scale instead of query / bmm1_scale / bmm2_scale). Reference assumes DeepSeek-V3 default qk_rope_head_dim=64; users with non-default kpe sizes can re-trace. - flashinfer.cute_dsl.attention.wrappers.batch_prefill.CuteDslBatchPrefillWrapper.run (cute_dsl_batch_prefill_run_trace): ragged batch prefill on separate q/k/v with indptr baked into plan(). Reference is causal SDPA per head, treating the full ragged tensor as a single sequence. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Adds two new TraceTemplates and wires @flashinfer_api(trace=...) on both CuteDSL fused norm + FP4-quantize APIs: - flashinfer.cute_dsl.rmsnorm_fp4quant.rmsnorm_fp4quant (rmsnorm_fp4quant_trace): RMSNorm * weight, optional global scaling, then per-block FP4 (e2m1) quantize. Reference models the math; the kernel uses a bespoke shuffled scale layout that the reference does not reproduce — use dequantized round-trip for correctness. - flashinfer.cute_dsl.add_rmsnorm_fp4quant.add_rmsnorm_fp4quant (add_rmsnorm_fp4quant_trace): residual += input then the rmsnorm_fp4quant pipeline; mutates residual in-place with the prenorm sum. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…cp_a2a_alltoall, fp8_blockscale_gemm_sm90, grouped_gemm_nt_masked Final batch of Priority-2 trace coverage. Wires @flashinfer_api(trace=...) on four more public APIs: - flashinfer.activation.silu_and_mul_scaled_nvfp4_experts_quantize (silu_and_mul_scaled_nvfp4_experts_quantize_trace): SiLU+mul activation on a 3-D batched MoE expert tensor [B, M, 2K] followed by masked NVFP4 (e2m1, block_size=16) quantization. - flashinfer.comm.decode_cp_a2a_alltoall (decode_cp_a2a_alltoall_trace): context-parallel attention all-to-all reduction. Single-rank reference is a passthrough; multi-rank correctness lives in tests/comm. - flashinfer.gemm.fp8_blockscale_gemm_sm90 (fp8_blockscale_gemm_sm90_trace): SM90 FP8 block-scale GEMM with auto-swapAB. Reference dequantizes via per-block scales (1x128 for input, 128x128 for weight) then matmul. - flashinfer.gemm.kernels.grouped_gemm_nt_masked (grouped_gemm_nt_masked_trace): Blackwell grouped GEMM with masked-M per group, used in MoE expert FC2 path. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…, batch_deepgemm_fp8_nt_groupwise, mm_M1_16_K7168_N256, nvfp4_kv_quantize, trtllm_ragged_attention_deepseek Adds TraceTemplates and @flashinfer_api(trace=...) wiring for the remaining DeepSeek R1 inference-path APIs identified in the audit: - top_k_page_table_transform: fused per-row top-k + page-table translation - batch_deepgemm_fp8_nt_groupwise: batched FP8 group-wise GEMM (masked-M) - mm_M1_16_K7168_N256: DeepSeek-V3 router GEMM specialization - nvfp4_kv_quantize: NVFP4 KV-cache quantization (linear SF layout) - trtllm_ragged_attention_deepseek: DeepSeek ragged-batch attention All templates ship with reference implementations (where applicable) and pass tests/trace/test_fi_trace_template_consistency.py (356 → 356, +15 new) and tests/trace/test_reference_correctness.py (61 passed, 8 skipped — no regressions). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…t tests Cross-check every tolerance in tests/trace/test_reference_correctness.py against the upstream API's own unit test. Three deviations fixed: - mm_bf16: upstream tests/gemm/test_mm_bf16.py uses cos_sim > 0.99 (same convention as tgv_gemm/bmm_bf16). Replace atol=5e-1/rtol=5e-2 with _close_fp8(cos_sim_min=0.99). - rmsnorm_quant: upstream tests/utils/test_norm.py (line 156) uses atol=1, rtol=1 on dequantized FP8 output. Relax from atol/rtol=0.3. - fused_add_rmsnorm_quant: same upstream tolerance for dequantized FP8 path (line 264). Tighten residual comparison to 1e-3 per upstream. - softmax / top_k_renorm_probs: upstream tests/utils/test_sampling.py lines 443/477 use 1e-3. Tighten from 5e-3. No kernel changes. 61 passed, 8 skipped (unchanged). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Tested PR flashinfer-ai#424's Path A workflow on Llama-3.2-3B (TP=1). The trace dumper in flashinfer-ai/flashinfer#2931 emits JSONs with three mismatches against flashinfer-bench's Definition validator: 1. reference declares def _<name>_reference(...) instead of def run(...) 2. plan-time index tensors (kv_indptr, kv_indices, qo_indptr) get dtype="unknown" because the dumper only inspects run() kwargs 3. in-place ops (fused_add_rmsnorm) declare the same name in both inputs and outputs, which the validator rejects Add a §A3b post-staging normalization snippet to extract-kernel-definitions that patches all three so the staged JSONs validate. Cross-reference it from collect-workloads §Phase 2b for the piggyback case. These patches should land upstream in FlashInfer so A3b becomes a no-op. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…ce dump (#424) ## Summary Update the onboarding skills to use FlashInfer's new `@flashinfer_api(trace=...)` trace dumper ([flashinfer-ai/flashinfer#2931](flashinfer-ai/flashinfer#2931)) as the primary way to produce Definition JSONs. The dumper writes a per-(op, shape) JSON to `FLASHINFER_TRACE_DUMP_DIR` before each decorated kernel runs (crash-safe, deduplicated), so a single short SGLang inference pass replaces most of the manual SGLang-source parsing the agent used to do. ## What changed - **`extract-kernel-definitions`** rewritten as a two-path skill (file: 885 → ~330 lines): - **Path A (primary)** — run a short SGLang `Engine.generate(...)` with `FLASHINFER_TRACE_DUMP=1`, `FLASHINFER_TRACE_DUMP_DIR=<dir>`, and `attention_backend=\"flashinfer\"`. Stage the resulting JSONs into `tmp/flashinfer-trace/definitions/{op_type}/`. Tags `fi_api:*` and `status:verified` plus the `reference` implementation are emitted by the dumper; `tp:*`/`ep:*`/`model:*` get appended in a small post-step. - **Path B (fallback)** — manual SGLang-source extraction for `fi_missing` kernels or APIs not yet decorated. Condensed from the old 885-line walkthrough to a per-op-type formula table + a small JSON-writing recipe. - Includes the FlashInfer trace coverage table (lifted from [`docs/fi_trace.rst`](https://github.com/flashinfer-ai/flashinfer/blob/main/docs/fi_trace.rst)) and pitfalls (env-var-before-import ordering, disable CUDA graphs, page-size runs, MoE routing variants). - **`collect-workloads`** gains a \"Phase 2b: piggyback definition trace dump\" section — setting the trace env vars alongside the existing `FLASHINFER_DUMP_*` logging vars produces workloads **and** definitions from one SGLang run, useful when a shape was missed during Phase 2. - **`discover-models`** Phase 1e now grep-checks for the `@flashinfer_api(trace=...)` decorator and records `fi_trace_template` (true/false) on each manifest entry. Phase 2 reads this to decide between Path A and Path B. - **`onboard-model`** Phase 2 description + manifest schema updated to reflect the new path. Manifest entries gain `fi_trace_template` and `phase2_method` fields. ## Why - Eliminates manual axis derivation, name generation, and `reference`-impl writing for ~95% of kernels (anything in the FlashInfer trace registry). - Aligns flashinfer-bench with the canonical schema produced by FlashInfer itself — no more drift between hand-written definitions and the templates in `flashinfer/trace/templates/*.py`. - Lets workload collection double as definition collection at zero extra cost. ## Test plan - [ ] Spot-check the trace dumper on a small model (e.g. `Llama-3.2-3B`) using the snippet in Path A2; confirm the dump dir contains the expected definitions and that staging into `tmp/flashinfer-trace/definitions/` produces a clean `flashinfer-bench validate` pass. - [ ] Walk the onboard-model overview table → confirm Phase 2 description matches the rewritten extract-kernel-definitions content. - [ ] Sanity-check the discover-models output schema (`fi_trace_template`) matches what onboard-model + extract-kernel-definitions claim to read. ## References - Feature PR: [flashinfer-ai/flashinfer#2931](flashinfer-ai/flashinfer#2931) - Schema docs: [`flashinfer/docs/fi_trace.rst`](https://github.com/flashinfer-ai/flashinfer/blob/main/docs/fi_trace.rst) - SGLang harness reference: [`flashinfer/tests/trace/example_sglang.py`](https://github.com/flashinfer-ai/flashinfer/blob/main/tests/trace/example_sglang.py) 🤖 Generated with [Claude Code](https://claude.com/claude-code) <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Documentation** * Added an optional trace-dump phase to emit per-kernel Definition JSONs during workload collection and instructions to stage them into datasets. * Added decorator-based readiness detection and a new manifest field fi_trace_template to select automated (trace) vs manual extraction. * Rewrote extraction guidance to prioritize a trace-driven path with a clear manual fallback; updated onboarding, phase2_method notes, CLI guidance, and examples. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Co-authored-by: Avery Huang <averyh@nvidia.com>
📌 Description
Adds a trace layer to FlashInfer so every public kernel can be described as a portable benchmark / replay definition without tying the description to any particular launcher.
flashinfer/trace/template.py— newTraceTemplateschema with namedaxes(Var/Const), typedinputs/outputs(Tensor/Scalar), optionalreferenceimplementation, and tag/constraint metadata.flashinfer/trace/templates/*.py— one module per operator family (attention, cascade, GDN, GEMM, MoE, norm, activation, sampling). Each file declares the schema and, where feasible, an executable reference.@flashinfer_api(trace=...)(extends the existing decorator inflashinfer/api_logging.py) — attaches.fi_trace()to the decorated function/method and, whenFLASHINFER_TRACE_DUMP=1, writes a per-shape JSON definition toFLASHINFER_TRACE_DUMP_DIRbefore the kernel runs (crash-safe).fi_trace()helpers — public entry points for programmatic trace generation from any@flashinfer_api-decorated API or a bound method.tests/trace/— template-consistency tests (signature ↔ axes/inputs), end-to-end reference checks, and anexample.pythat drives a realistic workload and dumps 45tests/trace/fi_trace_out/*.jsondefinitions across LLaMA-3.1, DeepSeek-V3, Gemma, Qwen3-Next, etc.Why
flashinfer-bench) consume a single self-describing JSON per op instead of reverse-engineering Python call sites.FLASHINFER_LOGLEVEL=0/FLASHINFER_TRACE_DUMPunset (decorator is a no-op in that path).Covered APIs
Attention (paged/ragged prefill, paged decode, MLA), sampling (top-k/top-p/top-k-top-p), GEMM (bf16, fp8, mxfp8, fp4), fused MoE (fp8/fp4 block-scale × 6 routing methods), norm (rmsnorm, fused-add, quant variants, Gemma variants, layernorm), activation (silu/gelu/gelu-tanh + mul), cascade merge (state/state-in-place/states), and GDN (decode, MTP, chunk prefill).
🔍 Related Issues
🚀 Pull Request Checklist
✅ Pre-commit Checks
pre-commitby runningpip install pre-commit.pre-commit install.pre-commit run --all-filesand fixed any reported issues.🧪 Tests
tests/trace/test_fi_trace.py,tests/trace/test_fi_trace_template_consistency.py).pytest tests/trace/ -v→ 139 passed).Reviewer Notes
@flashinfer_api(trace=...)must be innermost so trace dump runs even when surrounding@backend_requirementraises for unsupported capability. Formm_fp4/mm_mxfp8on SM<100 the outer@backend_requirementraises before the dump, which is why their JSONs are only regenerated on Blackwell. Seetests/trace/example.pyfor the realistic workload.@flashinfer_apidecorators that caused double-logging atFLASHINFER_LOGLEVEL=3+(subclass__init__overrides andtrtllm_low_latency_gemminternal helper).H/I/top_k/n_group;gdn_prefill_tracelacks the head-ratio constraints thatgdn_decode/gdn_mtpalready have; the E2E test synthesizer uses0forint32inputs, which makes some synthesized definitions nonsensical.🤖 Generated with Claude Code