From 2551fdb136303438eb5c34b07b3b27fd68f8eaf2 Mon Sep 17 00:00:00 2001 From: Runchu Zhao Date: Wed, 11 Mar 2026 09:17:45 +0000 Subject: [PATCH 01/10] First version without jaggedFA. --- examples/sid_gr/beam_search/beam_search.py | 79 +++++ examples/sid_gr/model/attention_mask.py | 175 ++++++++++ examples/sid_gr/model/gpt_model.py | 211 +++++++++++ .../sid_gr/tests/test_method_a_generate.py | 330 ++++++++++++++++++ 4 files changed, 795 insertions(+) create mode 100644 examples/sid_gr/tests/test_method_a_generate.py diff --git a/examples/sid_gr/beam_search/beam_search.py b/examples/sid_gr/beam_search/beam_search.py index 72274ff5b..0f8c8f55a 100644 --- a/examples/sid_gr/beam_search/beam_search.py +++ b/examples/sid_gr/beam_search/beam_search.py @@ -45,6 +45,9 @@ def __init__( self.history_topk_sids: List[torch.Tensor] = [] self.history_accumulate_topk_probs: List[torch.Tensor] = [] self.history_probs: List[torch.Tensor] = [] + + # parent beam indices per step for Method A ancestor tracking + self.parent_indices: List[torch.Tensor] = [] self.reset() def propagate( @@ -113,6 +116,7 @@ def propagate( self.history_topk_sids.append(generated_sids) self.history_accumulate_topk_probs.append(torch.exp(topk_probs)) self.history_probs.append(torch.exp(log_probs_this_step)) + self.parent_indices.append(last_step_indices) self.generated_sids = generated_sids self.accumulated_log_probs = topk_probs self.step += 1 @@ -126,6 +130,7 @@ def reset(self): self.history_topk_sids = [] self.history_accumulate_topk_probs = [] self.history_probs = [] + self.parent_indices = [] def get_sids( self, @@ -143,6 +148,80 @@ def get_sids( else: raise ValueError(f"Step {step} is not valid, current step is {self.step}") + def get_ancestor_positions( + self, + hist_len: torch.Tensor, + ) -> torch.Tensor: + """ + Compute physical positions of each current beam's ancestors in a + Method A (Incremental Append) sequence layout. 
+ + In Method A, the sequence is laid out as: + [history..., step0_beam0, step0_beam1, ..., step1_beam0, step1_beam1, ...] + + The physical position of the token generated at step s for beam b is: + hist_len + sum(beam_widths[:s]) + b + + This method traces back through parent_indices to find, for each + current beam, its ancestor beam index at every previous step. + + Args: + hist_len: [B] or scalar, the history length per sample (including BOS). + + Returns: + ancestor_positions: [B, topk_current, step] int64 tensor. + ancestor_positions[b, k, s] = physical position of beam k's + ancestor at step s in the Method A sequence. + Returns None if step == 0 (no ancestors yet). + """ + if self.step == 0: + return None + + batch_size = self.parent_indices[0].shape[0] + topk_current = self.parent_indices[-1].shape[1] + device = self.parent_indices[0].device + + # beam_index_at_step[s] has shape [B, topk_current]: + # the beam index that each current beam occupied at step s + beam_index_at_step: List[torch.Tensor] = [None] * self.step + + # Start from the current beams and trace backwards + current_beam_idx = torch.arange( + topk_current, device=device + ).unsqueeze(0).expand(batch_size, -1) + beam_index_at_step[self.step - 1] = current_beam_idx + + for s in range(self.step - 1, 0, -1): + parent_idx = self.parent_indices[s] # [B, topk_at_step_s] + current_beam_idx = torch.gather( + parent_idx, dim=1, index=current_beam_idx + ) + beam_index_at_step[s - 1] = current_beam_idx + + # Convert beam indices to physical positions + # physical_pos(step=s, beam=b) = hist_len + sum(beam_widths[:s]) + b + if not isinstance(hist_len, torch.Tensor): + hist_len = torch.tensor( + [hist_len] * batch_size, device=device, dtype=torch.long + ) + hist_len = hist_len.unsqueeze(1) # [B, 1] + + step_offsets = [] + cumulative = 0 + for s in range(self.step): + step_offsets.append(cumulative) + cumulative += self.beam_widths[s] + + ancestor_positions = torch.stack( + [ + hist_len + step_offsets[s] + 
beam_index_at_step[s] + for s in range(self.step) + ], + dim=-1, + ) # [B, topk_current, step] + + return ancestor_positions + def generate_valid_mask(self) -> torch.Tensor: """ update the valid mask between current step and previous step, diff --git a/examples/sid_gr/model/attention_mask.py b/examples/sid_gr/model/attention_mask.py index 63234f6df..5ef1301e4 100644 --- a/examples/sid_gr/model/attention_mask.py +++ b/examples/sid_gr/model/attention_mask.py @@ -12,6 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import List + import torch @@ -207,6 +209,179 @@ def padded_target_aware_causal_mask( return ~valid_attn_mask +def build_incremental_append_arbitrary_func( + history_seqlens: torch.Tensor, + max_history_seqlen: int, + current_step: int, + beam_widths: List[int], + ancestor_positions: torch.Tensor, +) -> torch.Tensor: + """ + Build an arbitrary_func tensor for the Method A (Incremental Append) beam + search layout, compatible with flash_attn's interval-based mask encoding. + + Method A sequence layout (at the time of predicting step ``current_step``): + [history..., step0_beams, step1_beams, ..., step_{current_step-1}_beams] + + Mask semantics: + - History tokens: causal within history (token i sees [0, i+1)) + - Generated token at (step s, beam b): sees + [0, hist_len) (all history) + plus its ancestor chain encoded via ancestor_positions. + - The last step's tokens (step current_step-1) use precise tree mask. + - Earlier steps' tokens use a causal-like approximation. + + Args: + history_seqlens: [B] per-sample history lengths. + max_history_seqlen: max history length (for padding). + current_step: the hierarchy step being predicted (0-based). + The sequence contains tokens from steps 0..current_step-1. + beam_widths: list of beam widths per step. 
+ ancestor_positions: [B, topk_current, current_step] physical positions + of each beam's ancestors at steps 0..current_step-1. + None if current_step == 0. + + Returns: + arbitrary_func: [B, 1, n_func, total_seqlen + 256] int32 tensor. + """ + device = history_seqlens.device + B = history_seqlens.shape[0] + total_generated = sum(beam_widths[:current_step]) + total_seqlen = max_history_seqlen + total_generated + + if current_step == 0: + # Pure causal mask over history, n_func=1 + af = torch.zeros( + B, 1, 1, total_seqlen + 256, dtype=torch.int32, device=device + ) + hist_positions = torch.arange(max_history_seqlen, device=device) + hist_f0 = (hist_positions + 1).unsqueeze(0).expand(B, -1) + hist_f0 = torch.min( + hist_f0, history_seqlens.unsqueeze(1).expand_as(hist_f0) + ) + af[:, 0, 0, :max_history_seqlen] = hist_f0.to(torch.int32) + return af + + topk_current = beam_widths[current_step - 1] + + # n_func must be odd. The last-step tokens need at most + # (current_step + 1) intervals: 1 (history) + current_step (ancestors incl self). 
+ max_intervals = current_step + 1 + n_func = max(2 * max_intervals - 1, 3) + + af = torch.zeros( + B, 1, n_func, total_seqlen + 256, dtype=torch.int32, device=device + ) + + # --- History tokens: causal mask, F0 = pos + 1 --- + hist_positions = torch.arange(max_history_seqlen, device=device) + hist_f0 = (hist_positions + 1).unsqueeze(0).expand(B, -1) + hist_f0 = torch.min( + hist_f0, history_seqlens.unsqueeze(1).expand_as(hist_f0) + ) + af[:, 0, 0, :max_history_seqlen] = hist_f0.to(torch.int32) + + # --- Generated tokens --- + gen_offset = 0 + for s in range(current_step): + bw_s = beam_widths[s] + gen_start = max_history_seqlen + gen_offset + is_last_step = s == current_step - 1 + + for b in range(bw_s): + pos = gen_start + b + + # F0: base interval = [0, hist_len) — see all history + af[:, 0, 0, pos] = history_seqlens.to(torch.int32) + + if is_last_step and ancestor_positions is not None: + # Tree mask: precise ancestor intervals from ancestor_positions. + # ancestor_positions[:, b, :] has positions at steps 0..current_step-1. + # Steps 0..current_step-2 are true ancestors; step current_step-1 is self. + for anc_idx in range(current_step): + anc_pos = ancestor_positions[:, b, anc_idx] # [B] + af[:, 0, 2 * anc_idx + 1, pos] = anc_pos.to(torch.int32) + af[:, 0, 2 * anc_idx + 2, pos] = (anc_pos + 1).to( + torch.int32 + ) + else: + # Earlier steps' tokens: causal over [hist_len, pos+1) + af[:, 0, 1, pos] = max_history_seqlen + af[:, 0, 2, pos] = pos + 1 + + gen_offset += bw_s + + return af + + +def build_incremental_append_dense_mask( + history_seqlens: torch.Tensor, + max_history_seqlen: int, + current_step: int, + beam_widths: List[int], + ancestor_positions: torch.Tensor, +) -> torch.Tensor: + """ + Build a dense [B, N, N] bool mask for Method A layout — for testing and + validation against arbitrary_func encoding. + + Args: + current_step: the hierarchy step being predicted (0-based). + The sequence contains tokens from steps 0..current_step-1. 
+ ancestor_positions: [B, topk, current_step] or None. + Covers steps 0..current_step-1. The last entry is the self position. + + Returns: + mask: [B, N, N] bool tensor where True = can attend. + """ + device = history_seqlens.device + B = history_seqlens.shape[0] + total_generated = sum(beam_widths[:current_step]) + N = max_history_seqlen + total_generated + + mask = torch.zeros(B, N, N, dtype=torch.bool, device=device) + + # History region: causal + for b_idx in range(B): + hl = history_seqlens[b_idx].item() + for i in range(hl): + mask[b_idx, i, : i + 1] = True + + if current_step == 0: + return mask + + topk_current = beam_widths[current_step - 1] + + # Generated tokens from earlier steps (steps 0..current_step-2): causal + gen_offset = 0 + for s in range(current_step - 1): + bw_s = beam_widths[s] + gen_start = max_history_seqlen + gen_offset + for b in range(bw_s): + pos = gen_start + b + for b_idx in range(B): + hl = history_seqlens[b_idx].item() + mask[b_idx, pos, :hl] = True + mask[b_idx, pos, max_history_seqlen : pos + 1] = True + gen_offset += bw_s + + # Last step tokens (step current_step-1): precise tree mask + gen_start = max_history_seqlen + gen_offset + for b in range(topk_current): + pos = gen_start + b + for b_idx in range(B): + hl = history_seqlens[b_idx].item() + mask[b_idx, pos, :hl] = True # history + if ancestor_positions is not None: + for anc_step in range(current_step): + anc_pos = ancestor_positions[b_idx, b, anc_step].item() + mask[b_idx, pos, anc_pos] = True + else: + mask[b_idx, pos, pos] = True # self only + + return mask + + if __name__ == "__main__": history_seqlen = torch.tensor([4, 3]).cuda() max_history_seqlen = 6 diff --git a/examples/sid_gr/model/gpt_model.py b/examples/sid_gr/model/gpt_model.py index a59f555ba..fbf2243b6 100644 --- a/examples/sid_gr/model/gpt_model.py +++ b/examples/sid_gr/model/gpt_model.py @@ -37,6 +37,8 @@ from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor from .attention_mask import ( 
+ build_incremental_append_arbitrary_func, + build_incremental_append_dense_mask, padded_causal_mask_with_optional_bos, padded_target_aware_causal_mask, ) @@ -776,3 +778,212 @@ def generate(self, batch: GPTSIDBatch) -> torch.Tensor: generated_sids = self.beam_search.get_sids() log_probs = self.beam_search.get_log_probs() return generated_sids, log_probs + + @torch.no_grad + def generate_method_a(self, batch: GPTSIDBatch) -> torch.Tensor: + """ + Generate using Method A (Incremental Append) token layout. + + Sequence layout at step i: + [history+BOS, step0_beams, step1_beams, ..., step_{i-1}_beams] + + Tokens are appended in generation order. The attention mask is a + tree-shaped arbitrary mask: each beam token sees all history plus + its ancestor chain, but not tokens from other beam branches. + + This method uses the existing dense-mask + pad-to-dense decoder path + for correctness validation. The arbitrary_func tensor is also built + for future use with jagged Flash Attention. + """ + # 0. prepare history+BOS embeddings (identical to generate()) + ( + history_embeddings, + input_offsets, + input_max_seqlen, + ) = self._prepare_embeddings( + batch, add_bos_to_history=False, is_generation=True + ) + batch_size = batch.actual_batch_size + input_offsets = input_offsets[: batch_size + 1] + history_seqlens = torch.diff(input_offsets) # [B] + + self.beam_search.reset() + + # physical_step_codes[s] stores the codes selected at step s + # shape: [batch_size, beam_widths[s]] + physical_step_codes: List[torch.Tensor] = [] + + for i in range(self._num_hierarchies): + # 1. Build the Method A sequence + if i == 0: + # No generated tokens yet; sequence is just [history+BOS] + cated_hidden_states = history_embeddings + cated_offsets = input_offsets + cated_max_seqlen = input_max_seqlen + topk_prev_step = 1 + else: + # Embed all previous steps' codes in Method A order: + # [step0_beam0, step0_beam1, ..., step1_beam0, ...] 
+ all_gen_codes = torch.cat( + [sc.view(-1) for sc in physical_step_codes] + ) # [total_gen_tokens] + total_gen_per_sample = sum( + self.beam_search.beam_widths[s] for s in range(i) + ) + + gen_sids_kjt = KeyedJaggedTensor.from_lengths_sync( + keys=[ + batch.candidate_feature_name, + batch.history_feature_name, + ], + values=all_gen_codes, + lengths=torch.cat( + [ + torch.full( + (batch_size,), + total_gen_per_sample, + device=all_gen_codes.device, + dtype=torch.long, + ), + torch.zeros( + (batch_size,), + device=all_gen_codes.device, + dtype=torch.long, + ), + ] + ), + ) + gen_embeddings = ( + self._codebooks_collection(gen_sids_kjt)[ + batch.candidate_feature_name + ] + .values() + .to(self._training_dtype) + ) # [batch_size * total_gen_per_sample, D] + + gen_offsets = torch.arange( + 0, + batch_size + 1, + device=input_offsets.device, + dtype=input_offsets.dtype, + ) * total_gen_per_sample + + ( + cated_hidden_states, + cated_offsets, + cated_max_seqlen, + ) = self._concat_jagged( + [history_embeddings, gen_embeddings], + [input_offsets, gen_offsets], + [input_max_seqlen, total_gen_per_sample], + ) + topk_prev_step = self.beam_search.beam_widths[i - 1] + + # 2. Build the Method A attention mask. + # ancestor_positions covers the beams from the latest propagate. + # At step i, the sequence has tokens from steps 0..i-1. + # ancestor_positions has shape [B, topk, i] covering steps 0..i-1. + # None when i == 0 (no beams exist yet). 
+ ancestor_positions = self.beam_search.get_ancestor_positions( + history_seqlens + ) + + dense_mask = build_incremental_append_dense_mask( + history_seqlens, + input_max_seqlen, + current_step=i, + beam_widths=self.beam_search.beam_widths, + ancestor_positions=ancestor_positions, + ) + attention_mask = ~dense_mask.unsqueeze(1) # [B, 1, N, N] + + # Also build arbitrary_func for future jagged FA integration + _arbitrary_func = build_incremental_append_arbitrary_func( + history_seqlens, + input_max_seqlen, + current_step=i, + beam_widths=self.beam_search.beam_widths, + ancestor_positions=ancestor_positions, + ) # stored for debugging / future use + + # 3. Decoder step (pad-to-dense path) + jagged_output_hidden_states = self.decoder_step( + cated_hidden_states, + cated_offsets, + cated_max_seqlen, + attention_mask=attention_mask, + padding_to_dense=True, + add_bos_to_history=False, + ) + + # 4. Extract the logit-producing tokens from the output. + # At step 0: the last 1 token per sample (BOS). + # At step i > 0: the last beam_widths[i-1] tokens (step i-1 beams). 
+ if i == 0: + # BOS is the last token in history; extract 1 per sample + bos_offsets = torch.arange( + 0, + batch_size + 1, + device=input_offsets.device, + dtype=input_offsets.dtype, + ).clamp(max=batch_size) + _, candidate_hidden_states = triton_split_2D_jagged( + jagged_output_hidden_states, + max_seq_len=cated_max_seqlen, + offsets_a=cated_offsets - bos_offsets, + offsets_b=bos_offsets, + ) + candidate_hidden_states = candidate_hidden_states.view( + batch_size, 1, self.embedding_dim + ) + else: + # Last beam_widths[i-1] tokens per sample + logit_token_count = topk_prev_step + logit_offsets = torch.arange( + 0, + batch_size + 1, + device=input_offsets.device, + dtype=input_offsets.dtype, + ) * logit_token_count + + _, candidate_hidden_states = triton_split_2D_jagged( + jagged_output_hidden_states, + max_seq_len=cated_max_seqlen, + offsets_a=cated_offsets - logit_offsets, + offsets_b=logit_offsets, + ) + candidate_hidden_states = candidate_hidden_states.view( + batch_size, topk_prev_step, self.embedding_dim + ) + + # 5. MLP projection → logits → log_softmax + mlp = ( + self._decoder_mlp[i] + if not self.share_lm_head_across_hierarchies + else self._decoder_mlp + ) + tuple_or_tensor: Union[ + Tuple[torch.Tensor, torch.Tensor], torch.Tensor + ] = mlp(candidate_hidden_states) + candidates_logits = ( + tuple_or_tensor[0] + if isinstance(tuple_or_tensor, tuple) + else tuple_or_tensor + ) + probs_this_step: torch.Tensor = torch.nn.functional.log_softmax( + candidates_logits.float(), dim=-1 + ) + + # 6. Beam search propagate + self.beam_search.propagate(probs_this_step) + + # 7. Record the physical step codes for the next iteration. + # history_topk_sids[i][:, :, i] gives the codes selected at step i + # for the beams that survived step i's top-k. 
+ physical_step_codes.append( + self.beam_search.history_topk_sids[-1][:, :, i] + ) + + generated_sids = self.beam_search.get_sids() + log_probs = self.beam_search.get_log_probs() + return generated_sids, log_probs diff --git a/examples/sid_gr/tests/test_method_a_generate.py b/examples/sid_gr/tests/test_method_a_generate.py new file mode 100644 index 000000000..c44bc8ed0 --- /dev/null +++ b/examples/sid_gr/tests/test_method_a_generate.py @@ -0,0 +1,330 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Tests for Method A (Incremental Append) beam search: + 1. Beam search ancestor tracking. + 2. Dense mask structure for Method A. + 3. arbitrary_func matches the dense reference mask. +""" + +import os +import sys + +import pytest +import torch + +from beam_search.beam_search import BeamSearch + +# Import attention_mask directly to avoid model/__init__.py which pulls in +# heavy dependencies (dynamicemb, megatron, torchrec). 
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "model")) +from attention_mask import ( + build_incremental_append_arbitrary_func, + build_incremental_append_dense_mask, +) +sys.path.pop(0) + + +# --------------------------------------------------------------------------- +# Helper: expand arbitrary_func to dense bool mask for comparison +# --------------------------------------------------------------------------- +def arbitrary_func_to_dense( + af: torch.Tensor, + seqlen_q: int, + seqlen_k: int, +) -> torch.Tensor: + """ + Expand an arbitrary_func tensor [B, 1, n_func, seqlen_q+256] into a dense + [B, seqlen_q, seqlen_k] bool mask using the interval semantics: + valid(q, k) = (k < F0[q]) OR (F1[q] <= k < F2[q]) OR ... + """ + B = af.shape[0] + n_func = af.shape[2] + mask = torch.zeros(B, seqlen_q, seqlen_k, dtype=torch.bool, device=af.device) + + kv_idx = torch.arange(seqlen_k, device=af.device) + + for b in range(B): + for q in range(seqlen_q): + f0 = af[b, 0, 0, q].item() + row_mask = kv_idx < f0 + for interval in range(n_func // 2): + f_start = af[b, 0, 2 * interval + 1, q].item() + f_end = af[b, 0, 2 * interval + 2, q].item() + row_mask = row_mask | ((kv_idx >= f_start) & (kv_idx < f_end)) + mask[b, q] = row_mask + + return mask + + +def _run_beam_search(batch_size, beam_width, num_hierarchies, codebook_sizes): + """Run a full beam search with random logits and return the BeamSearch object.""" + bs = BeamSearch(beam_width, num_hierarchies, codebook_sizes, record_history=True) + topk_prev = 1 + for step in range(num_hierarchies): + log_probs = torch.randn( + batch_size, topk_prev, codebook_sizes[step], device="cuda" + ) + bs.propagate(log_probs) + topk_prev = beam_width + return bs + + +# --------------------------------------------------------------------------- +# Test: beam search ancestor tracking +# --------------------------------------------------------------------------- +class TestBeamSearchAncestorTracking: + 
@pytest.mark.parametrize("batch_size", [1, 3]) + @pytest.mark.parametrize("beam_width", [2, 4]) + def test_ancestor_positions_shape(self, batch_size, beam_width): + """Verify shape and range of ancestor positions.""" + num_h = 3 + codebook_sizes = [10] * num_h + bs = _run_beam_search(batch_size, beam_width, num_h, codebook_sizes) + + hist_len = torch.full((batch_size,), 6, device="cuda", dtype=torch.long) + anc_pos = bs.get_ancestor_positions(hist_len) + + assert anc_pos is not None + assert anc_pos.shape == (batch_size, beam_width, num_h) + + for s in range(num_h): + offset = sum(bs.beam_widths[:s]) + for b_idx in range(batch_size): + hl = hist_len[b_idx].item() + positions = anc_pos[b_idx, :, s] + assert torch.all(positions >= hl + offset) + assert torch.all(positions < hl + offset + bs.beam_widths[s]) + + def test_ancestor_positions_step_zero(self): + """At step 0, get_ancestor_positions should return None.""" + bs = BeamSearch(2, 3, [10, 10, 10], record_history=True) + assert bs.get_ancestor_positions(torch.tensor([6], device="cuda")) is None + + def test_parent_indices_stored(self): + """Verify parent_indices are stored during propagate.""" + bs = _run_beam_search(1, 2, 3, [10, 10, 10]) + assert len(bs.parent_indices) == 3 + for pi in bs.parent_indices: + assert pi.shape[0] == 1 + assert pi.shape[1] == 2 + + def test_reset_clears_parent_indices(self): + """Verify reset() clears parent_indices.""" + bs = _run_beam_search(1, 2, 2, [10, 10]) + assert len(bs.parent_indices) == 2 + bs.reset() + assert len(bs.parent_indices) == 0 + + def test_ancestor_self_position(self): + """The last entry in ancestor_positions should be the token's own position.""" + bs = _run_beam_search(1, 2, 3, [10, 10, 10]) + hist_len = torch.tensor([6], device="cuda", dtype=torch.long) + anc_pos = bs.get_ancestor_positions(hist_len) + + last_step_offset = sum(bs.beam_widths[:2]) + for b in range(2): + expected_self_pos = 6 + last_step_offset + b + assert anc_pos[0, b, 2].item() == 
expected_self_pos + + +# --------------------------------------------------------------------------- +# Test: dense mask structure for Method A +# --------------------------------------------------------------------------- +class TestMethodADenseMask: + def test_step0_pure_causal(self): + """At step 0, mask is pure causal over history.""" + B, hist_len, max_hist = 1, 4, 4 + history_seqlens = torch.tensor([hist_len], device="cuda") + + mask = build_incremental_append_dense_mask( + history_seqlens, max_hist, current_step=0, + beam_widths=[2], ancestor_positions=None, + ) + assert mask.shape == (B, max_hist, max_hist) + expected = torch.tril(torch.ones(hist_len, hist_len, dtype=torch.bool, device="cuda")) + assert torch.equal(mask[0, :hist_len, :hist_len], expected) + + def test_step1_beam_isolation(self): + """At step 1, step-0 beams should not see each other.""" + B, hist_len, max_hist = 1, 4, 4 + beam_width = 2 + history_seqlens = torch.tensor([hist_len], device="cuda") + + bs = _run_beam_search(B, beam_width, 1, [10]) + anc_pos = bs.get_ancestor_positions(history_seqlens) + + mask = build_incremental_append_dense_mask( + history_seqlens, max_hist, current_step=1, + beam_widths=[beam_width, beam_width], + ancestor_positions=anc_pos, + ) + N = max_hist + beam_width + assert mask.shape == (B, N, N) + + beam_a_pos = max_hist + beam_b_pos = max_hist + 1 + + # Beams should NOT see each other + assert not mask[0, beam_a_pos, beam_b_pos].item() + assert not mask[0, beam_b_pos, beam_a_pos].item() + + # Both should see self + assert mask[0, beam_a_pos, beam_a_pos].item() + assert mask[0, beam_b_pos, beam_b_pos].item() + + # Both should see all history + assert torch.all(mask[0, beam_a_pos, :hist_len]) + assert torch.all(mask[0, beam_b_pos, :hist_len]) + + def test_tree_ancestry_step2(self): + """At step 2, each beam sees its ancestor chain but not other branches.""" + B, hist_len, max_hist = 1, 4, 4 + beam_width = 2 + history_seqlens = torch.tensor([hist_len], 
device="cuda") + + bs = _run_beam_search(B, beam_width, 2, [10, 10]) + anc_pos = bs.get_ancestor_positions(history_seqlens) + + mask = build_incremental_append_dense_mask( + history_seqlens, max_hist, current_step=2, + beam_widths=[beam_width, beam_width, beam_width], + ancestor_positions=anc_pos, + ) + N = max_hist + beam_width * 2 + assert mask.shape == (B, N, N) + + step1_start = max_hist + beam_width + for b in range(beam_width): + token_pos = step1_start + b + # Should see all history + assert torch.all(mask[0, token_pos, :hist_len]) + # Should see self + assert mask[0, token_pos, token_pos].item() + # Should see its step-0 ancestor + step0_ancestor = anc_pos[0, b, 0].item() + assert mask[0, token_pos, step0_ancestor].item() + + def test_padding_handled(self): + """Samples with different history lengths should have correct padding.""" + B = 2 + history_seqlens = torch.tensor([4, 2], device="cuda") + max_hist = 4 + + mask = build_incremental_append_dense_mask( + history_seqlens, max_hist, current_step=0, + beam_widths=[2], ancestor_positions=None, + ) + # Sample 0: 4 history tokens + assert mask[0, 3, 3].item() # pos 3 sees self + # Sample 1: only 2 history tokens + assert mask[1, 1, 1].item() # pos 1 sees self + assert not mask[1, 2, 2].item() # pos 2 is padding, all zeros + assert not mask[1, 3, 3].item() # pos 3 is padding + + +# --------------------------------------------------------------------------- +# Test: arbitrary_func matches dense mask for last-step tokens +# --------------------------------------------------------------------------- +class TestArbitraryFuncMatchesDense: + @pytest.mark.parametrize("beam_width", [2, 3]) + @pytest.mark.parametrize("num_hierarchies", [1, 2, 3]) + def test_last_step_tokens_match(self, beam_width, num_hierarchies): + """ + The arbitrary_func encoding should produce the same mask as the dense + reference for the last step's tokens (the logit-producing tokens). 
+ """ + B = 2 + hist_len = 6 + max_hist = hist_len + codebook_sizes = [10] * num_hierarchies + history_seqlens = torch.tensor([hist_len] * B, device="cuda") + beam_widths = [beam_width] * num_hierarchies + + bs = _run_beam_search(B, beam_width, num_hierarchies, codebook_sizes) + + # Test mask at the point of predicting the LAST hierarchy + test_step = num_hierarchies + anc_pos = bs.get_ancestor_positions(history_seqlens) + + dense = build_incremental_append_dense_mask( + history_seqlens, max_hist, current_step=test_step, + beam_widths=beam_widths, ancestor_positions=anc_pos, + ) + + af = build_incremental_append_arbitrary_func( + history_seqlens, max_hist, current_step=test_step, + beam_widths=beam_widths, ancestor_positions=anc_pos, + ) + + total_gen = sum(beam_widths[:test_step]) + N = max_hist + total_gen + af_dense = arbitrary_func_to_dense(af, N, N) + + # Check that last-step tokens match exactly + step_start = max_hist + sum(beam_widths[: test_step - 1]) + for b_idx in range(B): + for pos in range(step_start, step_start + beam_width): + dense_row = dense[b_idx, pos, :N] + af_row = af_dense[b_idx, pos, :N] + assert torch.equal(dense_row, af_row), ( + f"Mismatch at batch={b_idx}, pos={pos}\n" + f" dense: {dense_row.int().tolist()}\n" + f" af: {af_row.int().tolist()}" + ) + + def test_step0_causal_match(self): + """At step 0, arbitrary_func should produce a pure causal mask.""" + B = 1 + hist_len = 6 + history_seqlens = torch.tensor([hist_len], device="cuda") + + af = build_incremental_append_arbitrary_func( + history_seqlens, hist_len, current_step=0, + beam_widths=[2], ancestor_positions=None, + ) + + af_dense = arbitrary_func_to_dense(af, hist_len, hist_len) + expected = torch.tril(torch.ones(hist_len, hist_len, dtype=torch.bool, device="cuda")) + assert torch.equal(af_dense[0], expected) + + def test_history_region_matches(self): + """History region should be identical between dense and arbitrary_func.""" + B = 1 + hist_len = 6 + max_hist = hist_len + 
beam_widths = [2, 2] + history_seqlens = torch.tensor([hist_len], device="cuda") + + bs = _run_beam_search(B, 2, 2, [10, 10]) + anc_pos = bs.get_ancestor_positions(history_seqlens) + + dense = build_incremental_append_dense_mask( + history_seqlens, max_hist, current_step=2, + beam_widths=beam_widths, ancestor_positions=anc_pos, + ) + af = build_incremental_append_arbitrary_func( + history_seqlens, max_hist, current_step=2, + beam_widths=beam_widths, ancestor_positions=anc_pos, + ) + + N = max_hist + sum(beam_widths[:2]) + af_dense = arbitrary_func_to_dense(af, N, N) + + # History region [0:hist_len, 0:hist_len] should match + assert torch.equal( + dense[0, :hist_len, :hist_len], + af_dense[0, :hist_len, :hist_len], + ) From e957c1cdbc5a33f3200394db6a2995ace7e7158b Mon Sep 17 00:00:00 2001 From: Runchu Zhao Date: Tue, 31 Mar 2026 07:49:14 +0000 Subject: [PATCH 02/10] Method A with partial tests. --- examples/sid_gr/model/gpt_model.py | 71 ++-- .../sid_gr/model/jagged_flash_attn_block.py | 336 ++++++++++++++++++ .../tests/test_generate_method_a_e2e.py | 262 ++++++++++++++ .../tests/test_jagged_flash_attn_block.py | 259 ++++++++++++++ 4 files changed, 898 insertions(+), 30 deletions(-) create mode 100644 examples/sid_gr/model/jagged_flash_attn_block.py create mode 100644 examples/sid_gr/tests/test_generate_method_a_e2e.py create mode 100644 examples/sid_gr/tests/test_jagged_flash_attn_block.py diff --git a/examples/sid_gr/model/gpt_model.py b/examples/sid_gr/model/gpt_model.py index fbf2243b6..d9c15feac 100644 --- a/examples/sid_gr/model/gpt_model.py +++ b/examples/sid_gr/model/gpt_model.py @@ -42,6 +42,7 @@ padded_causal_mask_with_optional_bos, padded_target_aware_causal_mask, ) +from .jagged_flash_attn_block import JaggedFlashAttnBlock def _padding_to_dense_and_transpose( @@ -791,9 +792,8 @@ def generate_method_a(self, batch: GPTSIDBatch) -> torch.Tensor: tree-shaped arbitrary mask: each beam token sees all history plus its ancestor chain, but not tokens from other 
beam branches. - This method uses the existing dense-mask + pad-to-dense decoder path - for correctness validation. The arbitrary_func tensor is also built - for future use with jagged Flash Attention. + This method uses JaggedFlashAttnBlock with arbitrary_func mask encoding + and jiayus's Flash Attention kernel. """ # 0. prepare history+BOS embeddings (identical to generate()) ( @@ -807,6 +807,21 @@ def generate_method_a(self, batch: GPTSIDBatch) -> torch.Tensor: input_offsets = input_offsets[: batch_size + 1] history_seqlens = torch.diff(input_offsets) # [B] + # Lazy-init the JaggedFlashAttnBlock (owns its own weights) + if not hasattr(self, "_jagged_fa_decoder"): + self._jagged_fa_decoder = JaggedFlashAttnBlock( + num_layers=self.config.num_layers, + hidden_size=self.embedding_dim, + num_attention_heads=self.config.num_attention_heads, + ffn_hidden_size=self.config.ffn_hidden_size, + layernorm_epsilon=getattr( + self.config, "layernorm_epsilon", 1e-5 + ), + ).to( + device=history_embeddings.device, + dtype=self._training_dtype, + ) + self.beam_search.reset() # physical_step_codes[s] stores the codes selected at step s @@ -879,48 +894,45 @@ def generate_method_a(self, batch: GPTSIDBatch) -> torch.Tensor: ) topk_prev_step = self.beam_search.beam_widths[i - 1] - # 2. Build the Method A attention mask. - # ancestor_positions covers the beams from the latest propagate. - # At step i, the sequence has tokens from steps 0..i-1. - # ancestor_positions has shape [B, topk, i] covering steps 0..i-1. - # None when i == 0 (no beams exist yet). + # 2. Build the arbitrary_func mask for Method A layout. 
ancestor_positions = self.beam_search.get_ancestor_positions( history_seqlens ) - - dense_mask = build_incremental_append_dense_mask( + arbitrary_func = build_incremental_append_arbitrary_func( history_seqlens, input_max_seqlen, current_step=i, beam_widths=self.beam_search.beam_widths, ancestor_positions=ancestor_positions, ) - attention_mask = ~dense_mask.unsqueeze(1) # [B, 1, N, N] - # Also build arbitrary_func for future jagged FA integration - _arbitrary_func = build_incremental_append_arbitrary_func( - history_seqlens, - input_max_seqlen, - current_step=i, - beam_widths=self.beam_search.beam_widths, - ancestor_positions=ancestor_positions, - ) # stored for debugging / future use + # 3. Pad jagged → dense [B, S, D], run through JaggedFlashAttnBlock + padded_input = ( + torch.ops.fbgemm.jagged_to_padded_dense( + values=cated_hidden_states, + offsets=[cated_offsets], + max_lengths=[cated_max_seqlen], + padding_value=0.0, + ) + .view(batch_size, cated_max_seqlen, -1) + .to(self._training_dtype) + ) # [B, S, D] - # 3. Decoder step (pad-to-dense path) - jagged_output_hidden_states = self.decoder_step( - cated_hidden_states, - cated_offsets, - cated_max_seqlen, - attention_mask=attention_mask, - padding_to_dense=True, - add_bos_to_history=False, - ) + padded_output = self._jagged_fa_decoder( + padded_input, + arbitrary_func=arbitrary_func if i > 0 else None, + seqlen=cated_max_seqlen, + ) # [B, S, D] + + # Unpad dense → jagged + jagged_output_hidden_states = torch.ops.fbgemm.dense_to_jagged( + padded_output, [cated_offsets] + )[0] # 4. Extract the logit-producing tokens from the output. # At step 0: the last 1 token per sample (BOS). # At step i > 0: the last beam_widths[i-1] tokens (step i-1 beams). 
if i == 0: - # BOS is the last token in history; extract 1 per sample bos_offsets = torch.arange( 0, batch_size + 1, @@ -937,7 +949,6 @@ def generate_method_a(self, batch: GPTSIDBatch) -> torch.Tensor: batch_size, 1, self.embedding_dim ) else: - # Last beam_widths[i-1] tokens per sample logit_token_count = topk_prev_step logit_offsets = torch.arange( 0, diff --git a/examples/sid_gr/model/jagged_flash_attn_block.py b/examples/sid_gr/model/jagged_flash_attn_block.py new file mode 100644 index 000000000..4dd7c4ce8 --- /dev/null +++ b/examples/sid_gr/model/jagged_flash_attn_block.py @@ -0,0 +1,336 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +JaggedFlashAttnBlock: a self-contained GPT Transformer block that uses +jiayus's Flash Attention with arbitrary_func mask encoding. + +This replaces Megatron-Core's TransformerBlock for inference with +Method A (Incremental Append) beam search. 
"""
JaggedFlashAttnBlock: a self-contained GPT Transformer block that uses
jiayus's Flash Attention with arbitrary_func mask encoding.

This replaces Megatron-Core's TransformerBlock for inference with
Method A (Incremental Append) beam search.

Architecture per layer (standard pre-norm GPT):
    Input -> LayerNorm -> QKV Projection -> Flash Attention (arbitrary mask)
          -> Output Projection -> Residual
          -> LayerNorm -> FFN -> Residual -> Output

Reference: examples/hstu/modules/native_hstu_layer.py
"""
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

# flash_attn / create_block_mask_cuda imports are deferred to call time so
# that this module can be imported (and its layers constructed) even when
# flash_attn is not installed.


def build_block_sparsity(
    arbitrary_func: torch.Tensor,
    seqlen_q: int,
    seqlen_k: int,
    headdim: int,
) -> Tuple[Optional[object], Optional[object]]:
    """
    Build forward (Q2K) and backward (K2Q) block sparsity indices from an
    arbitrary_func tensor.

    Args:
        arbitrary_func: interval-encoded visibility mask consumed by the
            flash_attn arbitrary-mask kernel. Assumed shape is
            [batch, 1, n_func, seqlen_k + 256] — TODO confirm against the
            kernel's documentation.
        seqlen_q: query sequence length.
        seqlen_k: key sequence length.
        headdim: per-head dimension, used to select kernel tile sizes.

    Returns:
        (linear_k, linear_q): forward (Q2K) and backward (K2Q) block-sparse
        descriptors, or (None, None) when the CUDA extension is unavailable
        (callers then fall back to dense attention).
    """
    try:
        import create_block_mask_cuda
        from flash_attn.cute.block_sparsity import LinearBlockSparseTensorsTorch
        from flash_attn.utils.tile_size import get_tile_sizes_by_backend, get_arch
    except ImportError:
        # Optional dependency missing: signal "no block sparsity available".
        return None, None

    arch = get_arch()
    # Forward and backward kernels may use different tile geometries, so
    # query the tile sizes separately for each pass.
    fwd_q_block, fwd_kv_block = get_tile_sizes_by_backend(
        backend="dsl",
        pass_type="forward",
        arch=arch,
        headdim=headdim,
        is_causal=False,
        is_local=False,
        is_arbitrary=True,
    )
    bwd_q_block, bwd_kv_block = get_tile_sizes_by_backend(
        backend="dsl",
        pass_type="backward",
        arch=arch,
        headdim=headdim,
        is_causal=False,
        is_local=False,
        is_arbitrary=True,
    )

    # Forward pass: for each Q block, which K blocks are (partially) visible.
    (k_cnt, k_off, k_idx, k_fcnt, k_foff, k_fidx) = (
        create_block_mask_cuda.create_q2k_csr_sparse_from_func(
            arbitrary_func,
            seqlen_q,
            seqlen_k,
            Q_BLOCK_SIZE=fwd_q_block,
            KV_BLOCK_SIZE=fwd_kv_block,
            check_q_boundary=True,
        )
    )
    linear_k = LinearBlockSparseTensorsTorch(
        mask_block_cnt=k_cnt,
        mask_block_offset=k_off,
        mask_block_idx=k_idx,
        full_block_cnt=k_fcnt,
        full_block_offset=k_foff,
        full_block_idx=k_fidx,
    )

    # Backward pass: the transposed (K2Q) mapping.
    (q_cnt, q_off, q_idx, q_fcnt, q_foff, q_fidx) = (
        create_block_mask_cuda.create_k2q_csr_sparse_from_func(
            arbitrary_func,
            seqlen_q,
            seqlen_k,
            Q_BLOCK_SIZE=bwd_q_block,
            KV_BLOCK_SIZE=bwd_kv_block,
        )
    )
    linear_q = LinearBlockSparseTensorsTorch(
        mask_block_cnt=q_cnt,
        mask_block_offset=q_off,
        mask_block_idx=q_idx,
        full_block_cnt=q_fcnt,
        full_block_offset=q_foff,
        full_block_idx=q_fidx,
    )

    return linear_k, linear_q


class JaggedGPTLayer(nn.Module):
    """
    One Transformer layer with jagged Flash Attention.

    Pre-norm GPT structure:
        x = x + Attn(LayerNorm(x))
        x = x + FFN(LayerNorm(x))

    Q/K/V are produced by a single fused linear (same pattern as HSTU's
    ``linear_uvqk``). Flash Attention is called with arbitrary_func for
    tree-shaped beam search masks; with no arbitrary_func, the built-in
    causal mask is used instead.
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        ffn_hidden_size: int,
        layernorm_epsilon: float = 1e-5,
        hidden_dropout: float = 0.0,
        activation: str = "gelu",
    ):
        """
        Args:
            hidden_size: model width; must be divisible by num_attention_heads.
            num_attention_heads: number of attention heads.
            ffn_hidden_size: inner width of the 2-layer FFN.
            layernorm_epsilon: eps for both LayerNorms.
            hidden_dropout: dropout applied after attention and FFN.
            activation: "gelu" or "silu".

        Raises:
            ValueError: if hidden_size is not divisible by
                num_attention_heads, or activation is unrecognized.
        """
        super().__init__()
        # A non-divisible hidden size would silently truncate head_dim and
        # break the qkv.view(B, S, 3, num_heads, head_dim) reshape below.
        if hidden_size % num_attention_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by "
                f"num_attention_heads ({num_attention_heads})"
            )
        self.hidden_size = hidden_size
        self.num_heads = num_attention_heads
        self.head_dim = hidden_size // num_attention_heads
        self.ffn_hidden_size = ffn_hidden_size
        # Hoisted out of forward(): scale is a constant of the layer.
        self.softmax_scale = self.head_dim ** (-0.5)

        # --- Attention sub-layers ---
        self.input_layernorm = nn.LayerNorm(
            hidden_size, eps=layernorm_epsilon
        )
        # Fused QKV projection: hidden_size -> 3 * hidden_size
        self.linear_qkv = nn.Linear(hidden_size, 3 * hidden_size, bias=False)
        # Output projection after attention
        self.linear_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.attn_dropout = nn.Dropout(hidden_dropout)

        # --- FFN sub-layers ---
        self.pre_mlp_layernorm = nn.LayerNorm(
            hidden_size, eps=layernorm_epsilon
        )
        self.mlp_fc1 = nn.Linear(hidden_size, ffn_hidden_size, bias=False)
        self.mlp_fc2 = nn.Linear(ffn_hidden_size, hidden_size, bias=False)
        self.mlp_dropout = nn.Dropout(hidden_dropout)

        if activation == "gelu":
            self.activation_fn = F.gelu
        elif activation == "silu":
            self.activation_fn = F.silu
        else:
            # Previously an unknown name silently fell back to gelu, which
            # hides configuration typos; fail loudly instead.
            raise ValueError(
                f"Unsupported activation {activation!r}; "
                "expected 'gelu' or 'silu'"
            )

    def forward(
        self,
        hidden_states: torch.Tensor,
        arbitrary_func: Optional[torch.Tensor] = None,
        linear_k: Optional[object] = None,
        linear_q: Optional[object] = None,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states: [batch, seqlen, hidden_size]
            arbitrary_func: [batch, 1, n_func, seqlen+256] int32 mask encoding.
                If None, the built-in causal mask is used.
            linear_k: forward block sparsity (Q2K), from build_block_sparsity.
            linear_q: backward block sparsity (K2Q), from build_block_sparsity.

        Returns:
            hidden_states: [batch, seqlen, hidden_size]
        """
        # ---- Attention block ----
        residual = hidden_states
        x = self.input_layernorm(hidden_states)

        # QKV projection: [B, S, H] -> [B, S, 3*H]
        qkv = self.linear_qkv(x)
        # Reshape to [B, S, 3, num_heads, head_dim] and unbind
        B, S, _ = qkv.shape
        qkv = qkv.view(B, S, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.unbind(dim=2)  # each [B, S, num_heads, head_dim]

        # Deferred import: keeps the module importable without flash_attn.
        from flash_attn.cute.interface import flash_attn_func

        # Flash Attention requires fp16/bf16 inputs; remember the original
        # dtype so the residual add stays in the caller's precision.
        input_dtype = q.dtype
        if q.dtype not in (torch.float16, torch.bfloat16):
            q, k, v = q.bfloat16(), k.bfloat16(), v.bfloat16()

        if arbitrary_func is not None:
            attn_out, _ = flash_attn_func(
                q,
                k,
                v,
                softmax_scale=self.softmax_scale,
                causal=False,
                arbitrary=True,
                linear_k_block_sparse_tensors=linear_k,
                linear_q_block_sparse_tensors=linear_q,
                aux_tensors=[arbitrary_func],
            )
        else:
            attn_out, _ = flash_attn_func(
                q,
                k,
                v,
                softmax_scale=self.softmax_scale,
                causal=True,
            )

        if attn_out.dtype != input_dtype:
            attn_out = attn_out.to(input_dtype)

        # attn_out: [B, S, num_heads, head_dim] -> [B, S, hidden_size]
        attn_out = attn_out.reshape(B, S, self.hidden_size)
        attn_out = self.linear_proj(attn_out)
        attn_out = self.attn_dropout(attn_out)
        hidden_states = residual + attn_out

        # ---- FFN block ----
        residual = hidden_states
        x = self.pre_mlp_layernorm(hidden_states)
        x = self.mlp_fc1(x)
        x = self.activation_fn(x)
        x = self.mlp_fc2(x)
        x = self.mlp_dropout(x)
        hidden_states = residual + x

        return hidden_states


class JaggedFlashAttnBlock(nn.Module):
    """
    A stack of JaggedGPTLayers — the GPT decoder block using jiayus's
    Flash Attention with arbitrary_func mask encoding.

    This module owns its own weights (not shared with Megatron-Core).
    It is used in place of Megatron's TransformerBlock for inference
    with Method A beam search.

    Usage::

        block = JaggedFlashAttnBlock(
            num_layers=4,
            hidden_size=256,
            num_attention_heads=4,
            ffn_hidden_size=1024,
        )
        # padded input: [B, S, D]
        output = block(hidden_states, arbitrary_func=af, seqlen=S)
    """

    def __init__(
        self,
        num_layers: int,
        hidden_size: int,
        num_attention_heads: int,
        ffn_hidden_size: int,
        layernorm_epsilon: float = 1e-5,
        hidden_dropout: float = 0.0,
        activation: str = "gelu",
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.head_dim = hidden_size // num_attention_heads
        self.layers = nn.ModuleList(
            [
                JaggedGPTLayer(
                    hidden_size=hidden_size,
                    num_attention_heads=num_attention_heads,
                    ffn_hidden_size=ffn_hidden_size,
                    layernorm_epsilon=layernorm_epsilon,
                    hidden_dropout=hidden_dropout,
                    activation=activation,
                )
                for _ in range(num_layers)
            ]
        )
        self.final_layernorm = nn.LayerNorm(
            hidden_size, eps=layernorm_epsilon
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        arbitrary_func: Optional[torch.Tensor] = None,
        seqlen: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states: [batch, seqlen, hidden_size] padded input.
            arbitrary_func: [batch, 1, n_func, seqlen+256] int32 mask tensor.
                If None, uses standard causal attention.
            seqlen: sequence length (used for block sparsity construction).
                If None, inferred from hidden_states.shape[1].

        Returns:
            hidden_states: [batch, seqlen, hidden_size]
        """
        if seqlen is None:
            seqlen = hidden_states.shape[1]

        # Block sparsity is a function of the mask only, so build it once
        # per forward and share it across all layers.
        linear_k, linear_q = None, None
        if arbitrary_func is not None:
            linear_k, linear_q = build_block_sparsity(
                arbitrary_func, seqlen, seqlen, self.head_dim
            )

        for layer in self.layers:
            hidden_states = layer(
                hidden_states,
                arbitrary_func=arbitrary_func,
                linear_k=linear_k,
                linear_q=linear_q,
            )

        hidden_states = self.final_layernorm(hidden_states)
        return hidden_states
"""
End-to-end test for generate_method_a().

Constructs a full SIDGRModel, runs generate_method_a() with a random batch,
and verifies:
    1. It runs without errors (smoke test)
    2. Output shapes are correct
    3. Generated SIDs are within valid codebook ranges
"""
from typing import List

import pytest
import torch

# Project dependencies are optional at collection time: if they are missing,
# the tests below are skipped rather than erroring on import.
try:
    import commons.utils as init
    from commons.checkpoint import get_unwrapped_module
    from commons.datasets.gpt_sid_batch import FeatureConfig, GPTSIDBatch
    from commons.modules.embedding import ShardedEmbeddingConfig
    from commons.ops.length_to_offsets import length_to_complete_offsets
    from tests.test_utils import create_sid_gr_model_and_optimizer
    HAS_SIDGR_DEPS = True
except ImportError as e:
    HAS_SIDGR_DEPS = False
    _SIDGR_IMPORT_ERR = str(e)

# flash_attn is required by the Method A decoder (JaggedFlashAttnBlock).
try:
    from flash_attn.cute.interface import flash_attn_func  # noqa: F401
    HAS_FLASH_ATTN = True
except ImportError:
    HAS_FLASH_ATTN = False

# Human-readable skip reason; None when nothing is missing.
_SKIP_REASON = (
    "flash_attn not installed" if not HAS_FLASH_ATTN
    else f"SID-GR deps unavailable: {_SIDGR_IMPORT_ERR}" if not HAS_SIDGR_DEPS
    else None
)
_SHOULD_SKIP = not (HAS_FLASH_ATTN and HAS_SIDGR_DEPS)


def _generate_batch(
    batchsize: int,
    max_history_length: int,
    codebook_sizes: List[int],
    history_feature_name: str,
    candidate_feature_name: str,
) -> "GPTSIDBatch":
    """Build a random GPTSIDBatch with jagged history SIDs and one candidate.

    Each hierarchy level h gets its own disjoint id range
    [cum_sum[h], cum_sum[h+1]) inside the shared codebook table.
    """
    num_hierarchies = len(codebook_sizes)
    codebook_sizes_t = torch.tensor(codebook_sizes)
    cum_sum = length_to_complete_offsets(codebook_sizes_t)
    # Per-hierarchy id ranges: level h draws from [cum_sum[h], cum_sum[h+1]).
    max_item_ids = cum_sum[1:]
    min_item_ids = cum_sum[:-1]
    raw_hist_names = [f"hist_sid_{i}" for i in range(num_hierarchies)]
    raw_cand_names = [f"cand_sid_{i}" for i in range(num_hierarchies)]
    feature_configs = [
        # Jagged history: variable-length sequences up to max_history_length.
        FeatureConfig(
            feature_names=raw_hist_names,
            max_item_ids=max_item_ids,
            min_item_ids=min_item_ids,
            max_sequence_length=max_history_length,
            is_jagged=True,
        ),
        # Dense candidate: exactly one item per sample.
        FeatureConfig(
            feature_names=raw_cand_names,
            max_item_ids=max_item_ids,
            min_item_ids=min_item_ids,
            max_sequence_length=1,
            is_jagged=False,
        ),
    ]
    return GPTSIDBatch.random(
        batch_size=batchsize,
        feature_configs=feature_configs,
        raw_hist_sid_names=raw_hist_names,
        raw_cand_sid_names=raw_cand_names,
        combined_history_feature_name=history_feature_name,
        combined_candidate_feature_name=candidate_feature_name,
        contextual_feature_names=[],
        device=torch.cuda.current_device(),
    )


@pytest.mark.skipif(_SHOULD_SKIP, reason=_SKIP_REASON or "")
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [256])
@pytest.mark.parametrize("num_attention_heads", [4])
@pytest.mark.parametrize("kv_channels", [64])
@pytest.mark.parametrize("num_layers", [1])
@pytest.mark.parametrize("max_history_length", [64])
@pytest.mark.parametrize("codebook_sizes", [[128, 128, 128]])
@pytest.mark.parametrize("batchsize", [4])
def test_generate_method_a_smoke(
    dtype,
    hidden_size,
    num_attention_heads,
    kv_channels,
    num_layers,
    max_history_length,
    codebook_sizes,
    batchsize,
):
    """generate_method_a() runs end-to-end without errors."""
    num_hierarchies = len(codebook_sizes)
    init.initialize_distributed()
    init.initialize_model_parallel(1)
    init.set_random_seed(42)

    hist_name = "hist_sids"
    cand_name = "cand_sids"
    # One shared codebook table covering every hierarchy's id range.
    codebook_embedding_config = ShardedEmbeddingConfig(
        feature_names=[hist_name, cand_name],
        table_name="codebook",
        vocab_size=sum(codebook_sizes),
        dim=hidden_size,
        sharding_type="data_parallel",
    )

    with init.auto_destroy_global_state():
        model, optimizer = create_sid_gr_model_and_optimizer(
            dtype=dtype,
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            kv_channels=kv_channels,
            num_layers=num_layers,
            num_hierarchies=num_hierarchies,
            codebook_embedding_config=codebook_embedding_config,
            codebook_sizes=codebook_sizes,
        )
        optimizer.reload_model_params()
        model_unwrapped = get_unwrapped_module(model)
        model_unwrapped.eval()

        batch = _generate_batch(
            batchsize=batchsize,
            max_history_length=max_history_length,
            codebook_sizes=codebook_sizes,
            history_feature_name=hist_name,
            candidate_feature_name=cand_name,
        )
        batch.to(torch.cuda.current_device())

        generated_sids, log_probs = model_unwrapped.generate_method_a(batch)

        # Shape checks
        actual_bs = batch.actual_batch_size
        top_k = model_unwrapped.top_k_for_generation
        assert generated_sids.shape == (
            actual_bs,
            top_k,
            num_hierarchies,
        ), f"Expected ({actual_bs}, {top_k}, {num_hierarchies}), got {generated_sids.shape}"

        assert log_probs.shape == (
            actual_bs,
            top_k,
        ), f"Expected ({actual_bs}, {top_k}), got {log_probs.shape}"

        # SIDs should be within codebook ranges (per-hierarchy local ids).
        for h in range(num_hierarchies):
            assert torch.all(generated_sids[:, :, h] >= 0)
            assert torch.all(generated_sids[:, :, h] < codebook_sizes[h])

        # Log probs should be negative (log of probabilities)
        assert torch.all(log_probs <= 0)


@pytest.mark.skipif(_SHOULD_SKIP, reason=_SKIP_REASON or "")
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [256])
@pytest.mark.parametrize("num_attention_heads", [4])
@pytest.mark.parametrize("kv_channels", [64])
@pytest.mark.parametrize("num_layers", [1])
@pytest.mark.parametrize("max_history_length", [64])
@pytest.mark.parametrize("codebook_sizes", [[128, 128, 128]])
def test_generate_method_a_vs_original(
    dtype,
    hidden_size,
    num_attention_heads,
    kv_channels,
    num_layers,
    max_history_length,
    codebook_sizes,
):
    """
    generate_method_a() and generate() should produce valid outputs for the
    same input. They use different decoders (JaggedFlashAttnBlock vs Megatron)
    with different weights, so outputs won't be identical, but both should
    produce valid SIDs and log_probs.
    """
    num_hierarchies = len(codebook_sizes)
    init.initialize_distributed()
    init.initialize_model_parallel(1)
    init.set_random_seed(42)

    hist_name = "hist_sids"
    cand_name = "cand_sids"
    codebook_embedding_config = ShardedEmbeddingConfig(
        feature_names=[hist_name, cand_name],
        table_name="codebook",
        vocab_size=sum(codebook_sizes),
        dim=hidden_size,
        sharding_type="data_parallel",
    )

    with init.auto_destroy_global_state():
        model, optimizer = create_sid_gr_model_and_optimizer(
            dtype=dtype,
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            kv_channels=kv_channels,
            num_layers=num_layers,
            num_hierarchies=num_hierarchies,
            codebook_embedding_config=codebook_embedding_config,
            codebook_sizes=codebook_sizes,
        )
        optimizer.reload_model_params()
        model_unwrapped = get_unwrapped_module(model)
        model_unwrapped.eval()

        batch = _generate_batch(
            batchsize=4,
            max_history_length=max_history_length,
            codebook_sizes=codebook_sizes,
            history_feature_name=hist_name,
            candidate_feature_name=cand_name,
        )
        batch.to(torch.cuda.current_device())

        # Run original generate
        orig_sids, orig_probs = model_unwrapped.generate(batch)

        # Run Method A generate on the same batch
        method_a_sids, method_a_probs = model_unwrapped.generate_method_a(batch)

        # Both should have valid shapes
        actual_bs = batch.actual_batch_size
        top_k = model_unwrapped.top_k_for_generation
        assert orig_sids.shape == method_a_sids.shape == (
            actual_bs, top_k, num_hierarchies
        )
        assert orig_probs.shape == method_a_probs.shape == (actual_bs, top_k)

        # Both should produce valid SIDs
        for h in range(num_hierarchies):
            assert torch.all(orig_sids[:, :, h] >= 0)
            assert torch.all(orig_sids[:, :, h] < codebook_sizes[h])
            assert torch.all(method_a_sids[:, :, h] >= 0)
            assert torch.all(method_a_sids[:, :, h] < codebook_sizes[h])
"""
Tests for JaggedFlashAttnBlock / JaggedGPTLayer.

Test strategy (same pattern as HSTU's test_hstu_layer.py):
    1. Build a reference implementation using PyTorch's scaled_dot_product_attention
    2. Build JaggedGPTLayer with the same weights
    3. Run the same input, compare outputs

Tests:
    - Smoke: forward pass runs without error, output shape correct
    - Causal: causal mask produces correct output vs PyTorch reference
    - Backward: gradients flow correctly
"""

import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F

import sys
import os
# The module under test lives next to this test dir; make it importable
# without requiring a package install.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "model"))
from jagged_flash_attn_block import (
    JaggedFlashAttnBlock,
    JaggedGPTLayer,
)
sys.path.pop(0)

try:
    from flash_attn.cute.interface import flash_attn_func  # noqa: F401
    HAS_FLASH_ATTN = True
except ImportError:
    HAS_FLASH_ATTN = False


# ---------------------------------------------------------------------------
# PyTorch reference: single GPT layer with standard attention
# ---------------------------------------------------------------------------
class ReferenceGPTLayer(nn.Module):
    """Minimal GPT layer using PyTorch's SDPA for correctness comparison.

    Mirrors JaggedGPTLayer's parameter names so state_dicts are
    interchangeable between the two implementations.
    """

    def __init__(self, hidden_size, num_heads, ffn_hidden_size, eps=1e-5):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        self.input_layernorm = nn.LayerNorm(hidden_size, eps=eps)
        self.linear_qkv = nn.Linear(hidden_size, 3 * hidden_size, bias=False)
        self.linear_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.pre_mlp_layernorm = nn.LayerNorm(hidden_size, eps=eps)
        self.mlp_fc1 = nn.Linear(hidden_size, ffn_hidden_size, bias=False)
        self.mlp_fc2 = nn.Linear(ffn_hidden_size, hidden_size, bias=False)

    def forward(self, hidden_states, attn_mask=None, is_causal=False):
        # Pre-norm attention sub-block.
        residual = hidden_states
        x = self.input_layernorm(hidden_states)
        B, S, _ = x.shape

        qkv = self.linear_qkv(x).view(B, S, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.unbind(dim=2)
        # SDPA expects [B, H, S, D]; JaggedGPTLayer keeps [B, S, H, D].
        q = q.transpose(1, 2)  # [B, H, S, D]
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        attn_out = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, is_causal=is_causal
        )
        attn_out = attn_out.transpose(1, 2).reshape(B, S, self.hidden_size)
        attn_out = self.linear_proj(attn_out)
        hidden_states = residual + attn_out

        # Pre-norm FFN sub-block.
        residual = hidden_states
        x = self.pre_mlp_layernorm(hidden_states)
        x = self.mlp_fc1(x)
        x = F.gelu(x)
        x = self.mlp_fc2(x)
        hidden_states = residual + x

        return hidden_states


def _copy_weights(src: nn.Module, dst: nn.Module):
    """Copy all matching weights from src to dst.

    NOTE(review): the tests below call ``load_state_dict`` directly; this
    helper appears unused — kept for ad-hoc use / future tests.
    """
    dst.load_state_dict(src.state_dict(), strict=False)
    src_dict = dict(src.named_parameters())
    for name, param in dst.named_parameters():
        if name in src_dict:
            param.data.copy_(src_dict[name].data)


# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
@pytest.mark.skipif(not HAS_FLASH_ATTN, reason="flash_attn not installed")
class TestJaggedGPTLayerSmoke:
    """Shape smoke tests (forward passes go through flash_attn's causal path)."""

    @pytest.mark.parametrize("batch_size", [1, 2])
    @pytest.mark.parametrize("seqlen", [16, 64])
    @pytest.mark.parametrize("hidden_size,num_heads", [(256, 4), (512, 8)])
    def test_forward_shape(self, batch_size, seqlen, hidden_size, num_heads):
        """Output shape should match input shape."""
        ffn_size = hidden_size * 4
        layer = JaggedGPTLayer(
            hidden_size=hidden_size,
            num_attention_heads=num_heads,
            ffn_hidden_size=ffn_size,
        ).cuda()

        x = torch.randn(batch_size, seqlen, hidden_size, device="cuda")
        out = layer(x)
        assert out.shape == (batch_size, seqlen, hidden_size)

    def test_block_forward_shape(self):
        """JaggedFlashAttnBlock stacks layers correctly."""
        block = JaggedFlashAttnBlock(
            num_layers=2,
            hidden_size=256,
            num_attention_heads=4,
            ffn_hidden_size=1024,
        ).cuda().bfloat16()

        x = torch.randn(2, 32, 256, device="cuda", dtype=torch.bfloat16)
        out = block(x)
        assert out.shape == (2, 32, 256)


@pytest.mark.skipif(not HAS_FLASH_ATTN, reason="flash_attn not installed")
class TestJaggedGPTLayerCorrectness:
    """Compare JaggedGPTLayer (FA) against PyTorch reference."""

    @pytest.mark.parametrize("hidden_size,num_heads", [(256, 4)])
    @pytest.mark.parametrize("seqlen", [16, 32])
    def test_causal_matches_reference(self, hidden_size, num_heads, seqlen):
        """
        With causal=True (no arbitrary_func), JaggedGPTLayer should produce
        the same output as the PyTorch reference (within bf16 precision).
        """
        B = 2
        ffn_size = hidden_size * 4
        torch.manual_seed(42)

        ref_layer = ReferenceGPTLayer(hidden_size, num_heads, ffn_size).cuda().bfloat16()
        test_layer = JaggedGPTLayer(
            hidden_size=hidden_size,
            num_attention_heads=num_heads,
            ffn_hidden_size=ffn_size,
            hidden_dropout=0.0,
        ).cuda().bfloat16()

        # Identical weights so only the attention backend differs.
        test_layer.load_state_dict(ref_layer.state_dict())

        x = torch.randn(B, seqlen, hidden_size, device="cuda", dtype=torch.bfloat16)

        with torch.no_grad():
            ref_out = ref_layer(x, is_causal=True)
            test_out = test_layer(x, arbitrary_func=None)

        # Loose tolerances: bf16 plus different accumulation orders.
        torch.testing.assert_close(test_out, ref_out, atol=5e-2, rtol=5e-2)

    @pytest.mark.parametrize("hidden_size,num_heads", [(256, 4)])
    def test_arbitrary_causal_matches_standard_causal(self, hidden_size, num_heads):
        """
        An arbitrary_func encoding a causal mask should produce the same
        result as the built-in causal=True path.
        """
        B, S = 1, 32
        ffn_size = hidden_size * 4
        torch.manual_seed(42)

        layer = JaggedGPTLayer(
            hidden_size=hidden_size,
            num_attention_heads=num_heads,
            ffn_hidden_size=ffn_size,
            hidden_dropout=0.0,
        ).cuda().bfloat16()

        x = torch.randn(B, S, hidden_size, device="cuda", dtype=torch.bfloat16)

        # Build causal arbitrary_func: F0[i] = i+1, i.e. row i sees keys [0, i+1).
        # The trailing 256 entries are kernel-required padding.
        n_func = 1
        af = torch.zeros(B, 1, n_func, S + 256, dtype=torch.int32, device="cuda")
        for i in range(S):
            af[:, :, 0, i] = i + 1

        with torch.no_grad():
            out_causal = layer(x, arbitrary_func=None)
            out_arb = layer(x, arbitrary_func=af)

        torch.testing.assert_close(out_arb, out_causal, atol=5e-2, rtol=5e-2)


@pytest.mark.skipif(not HAS_FLASH_ATTN, reason="flash_attn not installed")
class TestJaggedGPTLayerBackward:
    """Verify gradients flow correctly."""

    def test_backward_runs(self):
        """Forward + backward should not error."""
        layer = JaggedGPTLayer(
            hidden_size=256,
            num_attention_heads=4,
            ffn_hidden_size=1024,
            hidden_dropout=0.0,
        ).cuda().bfloat16()

        x = torch.randn(2, 16, 256, device="cuda", dtype=torch.bfloat16, requires_grad=True)
        out = layer(x)
        loss = out.sum()
        loss.backward()

        assert x.grad is not None
        assert x.grad.shape == x.shape

    def test_backward_gradient_correctness(self):
        """
        Gradient of JaggedGPTLayer (causal) should match the PyTorch reference.
        """
        hidden_size, num_heads = 256, 4
        ffn_size = 1024
        B, S = 2, 16
        torch.manual_seed(42)

        ref_layer = ReferenceGPTLayer(hidden_size, num_heads, ffn_size).cuda().bfloat16()
        test_layer = JaggedGPTLayer(
            hidden_size=hidden_size,
            num_attention_heads=num_heads,
            ffn_hidden_size=ffn_size,
            hidden_dropout=0.0,
        ).cuda().bfloat16()
        test_layer.load_state_dict(ref_layer.state_dict())

        # Independent leaves with identical values so grads are comparable.
        x_ref = torch.randn(B, S, hidden_size, device="cuda", dtype=torch.bfloat16, requires_grad=True)
        x_test = x_ref.detach().clone().requires_grad_(True)

        ref_out = ref_layer(x_ref, is_causal=True)
        test_out = test_layer(x_test, arbitrary_func=None)

        # Same upstream gradient for both paths.
        dout = torch.randn_like(ref_out)
        ref_out.backward(dout)
        test_out.backward(dout)

        torch.testing.assert_close(
            x_test.grad, x_ref.grad, atol=5e-2, rtol=5e-2
        )
--- examples/sid_gr/beam_search/beam_search.py | 78 ---- examples/sid_gr/model/attention_mask.py | 226 ++++-------- examples/sid_gr/model/gpt_model.py | 341 +++++------------- .../test_dense_mask_to_arbitrary_func.py | 111 ++++++ .../tests/test_generate_method_a_e2e.py | 262 -------------- .../sid_gr/tests/test_method_a_generate.py | 330 ----------------- 6 files changed, 261 insertions(+), 1087 deletions(-) create mode 100644 examples/sid_gr/tests/test_dense_mask_to_arbitrary_func.py delete mode 100644 examples/sid_gr/tests/test_generate_method_a_e2e.py delete mode 100644 examples/sid_gr/tests/test_method_a_generate.py diff --git a/examples/sid_gr/beam_search/beam_search.py b/examples/sid_gr/beam_search/beam_search.py index 0f8c8f55a..e22b4bf43 100644 --- a/examples/sid_gr/beam_search/beam_search.py +++ b/examples/sid_gr/beam_search/beam_search.py @@ -46,8 +46,6 @@ def __init__( self.history_accumulate_topk_probs: List[torch.Tensor] = [] self.history_probs: List[torch.Tensor] = [] - # parent beam indices per step for Method A ancestor tracking - self.parent_indices: List[torch.Tensor] = [] self.reset() def propagate( @@ -116,7 +114,6 @@ def propagate( self.history_topk_sids.append(generated_sids) self.history_accumulate_topk_probs.append(torch.exp(topk_probs)) self.history_probs.append(torch.exp(log_probs_this_step)) - self.parent_indices.append(last_step_indices) self.generated_sids = generated_sids self.accumulated_log_probs = topk_probs self.step += 1 @@ -130,7 +127,6 @@ def reset(self): self.history_topk_sids = [] self.history_accumulate_topk_probs = [] self.history_probs = [] - self.parent_indices = [] def get_sids( self, @@ -148,80 +144,6 @@ def get_sids( else: raise ValueError(f"Step {step} is not valid, current step is {self.step}") - def get_ancestor_positions( - self, - hist_len: torch.Tensor, - ) -> torch.Tensor: - """ - Compute physical positions of each current beam's ancestors in a - Method A (Incremental Append) sequence layout. 
- - In Method A, the sequence is laid out as: - [history..., step0_beam0, step0_beam1, ..., step1_beam0, step1_beam1, ...] - - The physical position of the token generated at step s for beam b is: - hist_len + sum(beam_widths[:s]) + b - - This method traces back through parent_indices to find, for each - current beam, its ancestor beam index at every previous step. - - Args: - hist_len: [B] or scalar, the history length per sample (including BOS). - - Returns: - ancestor_positions: [B, topk_current, step] int64 tensor. - ancestor_positions[b, k, s] = physical position of beam k's - ancestor at step s in the Method A sequence. - Returns None if step == 0 (no ancestors yet). - """ - if self.step == 0: - return None - - batch_size = self.parent_indices[0].shape[0] - topk_current = self.parent_indices[-1].shape[1] - device = self.parent_indices[0].device - - # beam_index_at_step[s] has shape [B, topk_current]: - # the beam index that each current beam occupied at step s - beam_index_at_step: List[torch.Tensor] = [None] * self.step - - # Start from the current beams and trace backwards - current_beam_idx = torch.arange( - topk_current, device=device - ).unsqueeze(0).expand(batch_size, -1) - beam_index_at_step[self.step - 1] = current_beam_idx - - for s in range(self.step - 1, 0, -1): - parent_idx = self.parent_indices[s] # [B, topk_at_step_s] - current_beam_idx = torch.gather( - parent_idx, dim=1, index=current_beam_idx - ) - beam_index_at_step[s - 1] = current_beam_idx - - # Convert beam indices to physical positions - # physical_pos(step=s, beam=b) = hist_len + sum(beam_widths[:s]) + b - if not isinstance(hist_len, torch.Tensor): - hist_len = torch.tensor( - [hist_len] * batch_size, device=device, dtype=torch.long - ) - hist_len = hist_len.unsqueeze(1) # [B, 1] - - step_offsets = [] - cumulative = 0 - for s in range(self.step): - step_offsets.append(cumulative) - cumulative += self.beam_widths[s] - - ancestor_positions = torch.stack( - [ - hist_len + step_offsets[s] + 
beam_index_at_step[s] - for s in range(self.step) - ], - dim=-1, - ) # [B, topk_current, step] - - return ancestor_positions - def generate_valid_mask(self) -> torch.Tensor: """ update the valid mask between current step and previous step, diff --git a/examples/sid_gr/model/attention_mask.py b/examples/sid_gr/model/attention_mask.py index 5ef1301e4..ac5d01897 100644 --- a/examples/sid_gr/model/attention_mask.py +++ b/examples/sid_gr/model/attention_mask.py @@ -12,8 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List - import torch @@ -209,179 +207,81 @@ def padded_target_aware_causal_mask( return ~valid_attn_mask -def build_incremental_append_arbitrary_func( - history_seqlens: torch.Tensor, - max_history_seqlen: int, - current_step: int, - beam_widths: List[int], - ancestor_positions: torch.Tensor, +def dense_mask_to_arbitrary_func( + valid_mask: torch.Tensor, + seqlen: int, + padding: int = 256, ) -> torch.Tensor: """ - Build an arbitrary_func tensor for the Method A (Incremental Append) beam - search layout, compatible with flash_attn's interval-based mask encoding. - - Method A sequence layout (at the time of predicting step ``current_step``): - [history..., step0_beams, step1_beams, ..., step_{current_step-1}_beams] + Convert a dense bool attention mask to flash_attn's interval-based + arbitrary_func tensor. - Mask semantics: - - History tokens: causal within history (token i sees [0, i+1)) - - Generated token at (step s, beam b): sees - [0, hist_len) (all history) - plus its ancestor chain encoded via ancestor_positions. - - The last step's tokens (step current_step-1) use precise tree mask. - - Earlier steps' tokens use a causal-like approximation. + For each query position q, the arbitrary_func encodes visible key + positions as a union of intervals: + visible(q) = [0, F0) ∪ [F1, F2) ∪ [F3, F4) ∪ ... 
Args: - history_seqlens: [B] per-sample history lengths. - max_history_seqlen: max history length (for padding). - current_step: the hierarchy step being predicted (0-based). - The sequence contains tokens from steps 0..current_step-1. - beam_widths: list of beam widths per step. - ancestor_positions: [B, topk_current, current_step] physical positions - of each beam's ancestors at steps 0..current_step-1. - None if current_step == 0. + valid_mask: [B, N, N] or [B, 1, N, N] bool tensor (True = can attend). + seqlen: sequence length N. + padding: extra padding on last dim (FA convention, default 256). Returns: - arbitrary_func: [B, 1, n_func, total_seqlen + 256] int32 tensor. + arbitrary_func: [B, 1, n_func, seqlen + padding] int32 tensor. """ - device = history_seqlens.device - B = history_seqlens.shape[0] - total_generated = sum(beam_widths[:current_step]) - total_seqlen = max_history_seqlen + total_generated - - if current_step == 0: - # Pure causal mask over history, n_func=1 - af = torch.zeros( - B, 1, 1, total_seqlen + 256, dtype=torch.int32, device=device - ) - hist_positions = torch.arange(max_history_seqlen, device=device) - hist_f0 = (hist_positions + 1).unsqueeze(0).expand(B, -1) - hist_f0 = torch.min( - hist_f0, history_seqlens.unsqueeze(1).expand_as(hist_f0) - ) - af[:, 0, 0, :max_history_seqlen] = hist_f0.to(torch.int32) - return af - - topk_current = beam_widths[current_step - 1] - - # n_func must be odd. The last-step tokens need at most - # (current_step + 1) intervals: 1 (history) + current_step (ancestors incl self). 
- max_intervals = current_step + 1 - n_func = max(2 * max_intervals - 1, 3) - - af = torch.zeros( - B, 1, n_func, total_seqlen + 256, dtype=torch.int32, device=device - ) - - # --- History tokens: causal mask, F0 = pos + 1 --- - hist_positions = torch.arange(max_history_seqlen, device=device) - hist_f0 = (hist_positions + 1).unsqueeze(0).expand(B, -1) - hist_f0 = torch.min( - hist_f0, history_seqlens.unsqueeze(1).expand_as(hist_f0) - ) - af[:, 0, 0, :max_history_seqlen] = hist_f0.to(torch.int32) - - # --- Generated tokens --- - gen_offset = 0 - for s in range(current_step): - bw_s = beam_widths[s] - gen_start = max_history_seqlen + gen_offset - is_last_step = s == current_step - 1 - - for b in range(bw_s): - pos = gen_start + b - - # F0: base interval = [0, hist_len) — see all history - af[:, 0, 0, pos] = history_seqlens.to(torch.int32) - - if is_last_step and ancestor_positions is not None: - # Tree mask: precise ancestor intervals from ancestor_positions. - # ancestor_positions[:, b, :] has positions at steps 0..current_step-1. - # Steps 0..current_step-2 are true ancestors; step current_step-1 is self. 
- for anc_idx in range(current_step): - anc_pos = ancestor_positions[:, b, anc_idx] # [B] - af[:, 0, 2 * anc_idx + 1, pos] = anc_pos.to(torch.int32) - af[:, 0, 2 * anc_idx + 2, pos] = (anc_pos + 1).to( - torch.int32 - ) - else: - # Earlier steps' tokens: causal over [hist_len, pos+1) - af[:, 0, 1, pos] = max_history_seqlen - af[:, 0, 2, pos] = pos + 1 - - gen_offset += bw_s + if valid_mask.dim() == 4: + valid_mask = valid_mask.squeeze(1) + assert valid_mask.dim() == 3, f"Expected [B, N, N], got {valid_mask.shape}" + + B, N, _ = valid_mask.shape + device = valid_mask.device + + # Detect interval boundaries via transitions + shifted = torch.zeros_like(valid_mask) + shifted[:, :, 1:] = valid_mask[:, :, :-1] + starts = valid_mask & ~shifted # start of each True run + max_intervals = int(starts.sum(dim=-1).max().item()) + n_func = max(2 * max_intervals - 1, 1) + if n_func % 2 == 0: + n_func += 1 + + # When first interval doesn't start at 0, it needs an extra slot. + # Recount: base interval [0, F0) is free only if first run starts at 0. + # Worst case: all intervals need explicit [F_start, F_end) pairs. + # n_func = 2*max_intervals + 1 covers all cases. + n_func = 2 * max_intervals + 1 + if n_func % 2 == 0: + n_func += 1 + + af = torch.zeros(B, 1, n_func, seqlen + padding, dtype=torch.int32, device=device) + + ends_shifted = torch.zeros_like(valid_mask) + ends_shifted[:, :, :-1] = valid_mask[:, :, 1:] + ends = valid_mask & ~ends_shifted # last True position in each run + + for b in range(B): + for q in range(N): + row = valid_mask[b, q] + if not row.any(): + continue + start_pos = starts[b, q].nonzero(as_tuple=False).squeeze(-1) + end_pos = ends[b, q].nonzero(as_tuple=False).squeeze(-1) + 1 + + # F0 encodes [0, F0). If first interval starts at 0, use F0. + # Otherwise F0 stays 0 (empty base interval) and all intervals + # go into the extra slots. 
+ extra_idx = 0 + for iv in range(len(start_pos)): + s, e = start_pos[iv].item(), end_pos[iv].item() + if iv == 0 and s == 0: + af[b, 0, 0, q] = e + else: + af[b, 0, 2 * extra_idx + 1, q] = s + af[b, 0, 2 * extra_idx + 2, q] = e + extra_idx += 1 return af -def build_incremental_append_dense_mask( - history_seqlens: torch.Tensor, - max_history_seqlen: int, - current_step: int, - beam_widths: List[int], - ancestor_positions: torch.Tensor, -) -> torch.Tensor: - """ - Build a dense [B, N, N] bool mask for Method A layout — for testing and - validation against arbitrary_func encoding. - - Args: - current_step: the hierarchy step being predicted (0-based). - The sequence contains tokens from steps 0..current_step-1. - ancestor_positions: [B, topk, current_step] or None. - Covers steps 0..current_step-1. The last entry is the self position. - - Returns: - mask: [B, N, N] bool tensor where True = can attend. - """ - device = history_seqlens.device - B = history_seqlens.shape[0] - total_generated = sum(beam_widths[:current_step]) - N = max_history_seqlen + total_generated - - mask = torch.zeros(B, N, N, dtype=torch.bool, device=device) - - # History region: causal - for b_idx in range(B): - hl = history_seqlens[b_idx].item() - for i in range(hl): - mask[b_idx, i, : i + 1] = True - - if current_step == 0: - return mask - - topk_current = beam_widths[current_step - 1] - - # Generated tokens from earlier steps (steps 0..current_step-2): causal - gen_offset = 0 - for s in range(current_step - 1): - bw_s = beam_widths[s] - gen_start = max_history_seqlen + gen_offset - for b in range(bw_s): - pos = gen_start + b - for b_idx in range(B): - hl = history_seqlens[b_idx].item() - mask[b_idx, pos, :hl] = True - mask[b_idx, pos, max_history_seqlen : pos + 1] = True - gen_offset += bw_s - - # Last step tokens (step current_step-1): precise tree mask - gen_start = max_history_seqlen + gen_offset - for b in range(topk_current): - pos = gen_start + b - for b_idx in range(B): - hl = 
history_seqlens[b_idx].item() - mask[b_idx, pos, :hl] = True # history - if ancestor_positions is not None: - for anc_step in range(current_step): - anc_pos = ancestor_positions[b_idx, b, anc_step].item() - mask[b_idx, pos, anc_pos] = True - else: - mask[b_idx, pos, pos] = True # self only - - return mask - - if __name__ == "__main__": history_seqlen = torch.tensor([4, 3]).cuda() max_history_seqlen = 6 diff --git a/examples/sid_gr/model/gpt_model.py b/examples/sid_gr/model/gpt_model.py index d9c15feac..f17395428 100644 --- a/examples/sid_gr/model/gpt_model.py +++ b/examples/sid_gr/model/gpt_model.py @@ -37,12 +37,9 @@ from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor from .attention_mask import ( - build_incremental_append_arbitrary_func, - build_incremental_append_dense_mask, padded_causal_mask_with_optional_bos, padded_target_aware_causal_mask, ) -from .jagged_flash_attn_block import JaggedFlashAttnBlock def _padding_to_dense_and_transpose( @@ -518,59 +515,113 @@ def decoder_step( attention_mask: Optional[torch.Tensor] = None, padding_to_dense: bool = True, add_bos_to_history: bool = False, + use_flash_attn: bool = False, ) -> torch.Tensor: """ Input and Output are both jagged. attention_mask is used only when padding_to_dense is True. When attention mask is None, we will construct a causal attention mask if padding_to_dense is True. - We now only support dense input. + When use_flash_attn is True, uses JaggedFlashAttnBlock with jiayus's + Flash Attention + arbitrary_func mask encoding instead of Megatron's + TransformerBlock. The dense attention_mask is automatically converted + to arbitrary_func format. """ if add_bos_to_history: assert ( attention_mask is None ), "attention mask should be None when adding bos to history" - # TODO, remove the padding. 
input_offsets[-1].item() - if padding_to_dense: - decoder_input_hidden_states = _padding_to_dense_and_transpose( - input_hidden_states, + batch_size = input_offsets.size(0) - 1 + + # Build attention_mask if not provided (needed by both paths) + if attention_mask is None and padding_to_dense: + attention_mask = padded_causal_mask_with_optional_bos( input_offsets, input_max_seqlen, + add_bos_to_history=add_bos_to_history, + bos_interval=self._num_hierarchies, ) - packed_seq_params = None - if attention_mask is None: - attention_mask = padded_causal_mask_with_optional_bos( - input_offsets, - input_max_seqlen, - add_bos_to_history=add_bos_to_history, - bos_interval=self._num_hierarchies, + + if use_flash_attn: + from .attention_mask import dense_mask_to_arbitrary_func + from .jagged_flash_attn_block import JaggedFlashAttnBlock + + # Lazy-init the FA decoder block + if not hasattr(self, "_jagged_fa_decoder"): + self._jagged_fa_decoder = JaggedFlashAttnBlock( + num_layers=self.config.num_layers, + hidden_size=self.embedding_dim, + num_attention_heads=self.config.num_attention_heads, + ffn_hidden_size=self.config.ffn_hidden_size, + layernorm_epsilon=getattr( + self.config, "layernorm_epsilon", 1e-5 + ), + ).to( + device=input_hidden_states.device, + dtype=self._training_dtype, ) - else: - # THD still needs batch dimension - # we need to unsqueeze the hidden states to [T, 1, hidden_size] and unsqueeze back after decoder - assert input_hidden_states.dim() == 2, "input_hidden_states should be 2D" - decoder_input_hidden_states = input_hidden_states.unsqueeze(1) - attention_mask = None - packed_seq_params = to_packed_seq_params( - input_offsets, - input_max_seqlen, + + # Pad jagged → dense [B, S, D] + padded_input = ( + torch.ops.fbgemm.jagged_to_padded_dense( + values=input_hidden_states, + offsets=[input_offsets], + max_lengths=[input_max_seqlen], + padding_value=0.0, + ) + .view(batch_size, input_max_seqlen, -1) + .to(self._training_dtype) ) - decoder_output_hidden_states = 
self.decoder( - hidden_states=decoder_input_hidden_states, # input_hidden_states, - attention_mask=attention_mask, - packed_seq_params=packed_seq_params, # we now enforce arbitrary attention mask + dense padding - ) - if padding_to_dense: - output_hidden_states = _transpose_dense_to_jagged( - decoder_output_hidden_states, - input_offsets, - input_max_seqlen, + # Convert dense mask → arbitrary_func + # attention_mask convention: True = masked out → invert for valid mask + valid_mask = ~attention_mask # [B, 1, N, N] + arbitrary_func = dense_mask_to_arbitrary_func( + valid_mask, input_max_seqlen + ) + + padded_output = self._jagged_fa_decoder( + padded_input, + arbitrary_func=arbitrary_func, + seqlen=input_max_seqlen, ) + + # Unpad dense → jagged + output_hidden_states = torch.ops.fbgemm.dense_to_jagged( + padded_output, [input_offsets] + )[0] else: - # remove batch dim if THD - output_hidden_states = decoder_output_hidden_states.squeeze(1) + # Original Megatron path + if padding_to_dense: + decoder_input_hidden_states = _padding_to_dense_and_transpose( + input_hidden_states, + input_offsets, + input_max_seqlen, + ) + packed_seq_params = None + else: + assert input_hidden_states.dim() == 2, "input_hidden_states should be 2D" + decoder_input_hidden_states = input_hidden_states.unsqueeze(1) + attention_mask = None + packed_seq_params = to_packed_seq_params( + input_offsets, + input_max_seqlen, + ) + decoder_output_hidden_states = self.decoder( + hidden_states=decoder_input_hidden_states, + attention_mask=attention_mask, + packed_seq_params=packed_seq_params, + ) + + if padding_to_dense: + output_hidden_states = _transpose_dense_to_jagged( + decoder_output_hidden_states, + input_offsets, + input_max_seqlen, + ) + else: + output_hidden_states = decoder_output_hidden_states.squeeze(1) return output_hidden_states def forward( @@ -780,221 +831,3 @@ def generate(self, batch: GPTSIDBatch) -> torch.Tensor: log_probs = self.beam_search.get_log_probs() return generated_sids, 
log_probs - @torch.no_grad - def generate_method_a(self, batch: GPTSIDBatch) -> torch.Tensor: - """ - Generate using Method A (Incremental Append) token layout. - - Sequence layout at step i: - [history+BOS, step0_beams, step1_beams, ..., step_{i-1}_beams] - - Tokens are appended in generation order. The attention mask is a - tree-shaped arbitrary mask: each beam token sees all history plus - its ancestor chain, but not tokens from other beam branches. - - This method uses JaggedFlashAttnBlock with arbitrary_func mask encoding - and jiayus's Flash Attention kernel. - """ - # 0. prepare history+BOS embeddings (identical to generate()) - ( - history_embeddings, - input_offsets, - input_max_seqlen, - ) = self._prepare_embeddings( - batch, add_bos_to_history=False, is_generation=True - ) - batch_size = batch.actual_batch_size - input_offsets = input_offsets[: batch_size + 1] - history_seqlens = torch.diff(input_offsets) # [B] - - # Lazy-init the JaggedFlashAttnBlock (owns its own weights) - if not hasattr(self, "_jagged_fa_decoder"): - self._jagged_fa_decoder = JaggedFlashAttnBlock( - num_layers=self.config.num_layers, - hidden_size=self.embedding_dim, - num_attention_heads=self.config.num_attention_heads, - ffn_hidden_size=self.config.ffn_hidden_size, - layernorm_epsilon=getattr( - self.config, "layernorm_epsilon", 1e-5 - ), - ).to( - device=history_embeddings.device, - dtype=self._training_dtype, - ) - - self.beam_search.reset() - - # physical_step_codes[s] stores the codes selected at step s - # shape: [batch_size, beam_widths[s]] - physical_step_codes: List[torch.Tensor] = [] - - for i in range(self._num_hierarchies): - # 1. 
Build the Method A sequence - if i == 0: - # No generated tokens yet; sequence is just [history+BOS] - cated_hidden_states = history_embeddings - cated_offsets = input_offsets - cated_max_seqlen = input_max_seqlen - topk_prev_step = 1 - else: - # Embed all previous steps' codes in Method A order: - # [step0_beam0, step0_beam1, ..., step1_beam0, ...] - all_gen_codes = torch.cat( - [sc.view(-1) for sc in physical_step_codes] - ) # [total_gen_tokens] - total_gen_per_sample = sum( - self.beam_search.beam_widths[s] for s in range(i) - ) - - gen_sids_kjt = KeyedJaggedTensor.from_lengths_sync( - keys=[ - batch.candidate_feature_name, - batch.history_feature_name, - ], - values=all_gen_codes, - lengths=torch.cat( - [ - torch.full( - (batch_size,), - total_gen_per_sample, - device=all_gen_codes.device, - dtype=torch.long, - ), - torch.zeros( - (batch_size,), - device=all_gen_codes.device, - dtype=torch.long, - ), - ] - ), - ) - gen_embeddings = ( - self._codebooks_collection(gen_sids_kjt)[ - batch.candidate_feature_name - ] - .values() - .to(self._training_dtype) - ) # [batch_size * total_gen_per_sample, D] - - gen_offsets = torch.arange( - 0, - batch_size + 1, - device=input_offsets.device, - dtype=input_offsets.dtype, - ) * total_gen_per_sample - - ( - cated_hidden_states, - cated_offsets, - cated_max_seqlen, - ) = self._concat_jagged( - [history_embeddings, gen_embeddings], - [input_offsets, gen_offsets], - [input_max_seqlen, total_gen_per_sample], - ) - topk_prev_step = self.beam_search.beam_widths[i - 1] - - # 2. Build the arbitrary_func mask for Method A layout. - ancestor_positions = self.beam_search.get_ancestor_positions( - history_seqlens - ) - arbitrary_func = build_incremental_append_arbitrary_func( - history_seqlens, - input_max_seqlen, - current_step=i, - beam_widths=self.beam_search.beam_widths, - ancestor_positions=ancestor_positions, - ) - - # 3. 
Pad jagged → dense [B, S, D], run through JaggedFlashAttnBlock - padded_input = ( - torch.ops.fbgemm.jagged_to_padded_dense( - values=cated_hidden_states, - offsets=[cated_offsets], - max_lengths=[cated_max_seqlen], - padding_value=0.0, - ) - .view(batch_size, cated_max_seqlen, -1) - .to(self._training_dtype) - ) # [B, S, D] - - padded_output = self._jagged_fa_decoder( - padded_input, - arbitrary_func=arbitrary_func if i > 0 else None, - seqlen=cated_max_seqlen, - ) # [B, S, D] - - # Unpad dense → jagged - jagged_output_hidden_states = torch.ops.fbgemm.dense_to_jagged( - padded_output, [cated_offsets] - )[0] - - # 4. Extract the logit-producing tokens from the output. - # At step 0: the last 1 token per sample (BOS). - # At step i > 0: the last beam_widths[i-1] tokens (step i-1 beams). - if i == 0: - bos_offsets = torch.arange( - 0, - batch_size + 1, - device=input_offsets.device, - dtype=input_offsets.dtype, - ).clamp(max=batch_size) - _, candidate_hidden_states = triton_split_2D_jagged( - jagged_output_hidden_states, - max_seq_len=cated_max_seqlen, - offsets_a=cated_offsets - bos_offsets, - offsets_b=bos_offsets, - ) - candidate_hidden_states = candidate_hidden_states.view( - batch_size, 1, self.embedding_dim - ) - else: - logit_token_count = topk_prev_step - logit_offsets = torch.arange( - 0, - batch_size + 1, - device=input_offsets.device, - dtype=input_offsets.dtype, - ) * logit_token_count - - _, candidate_hidden_states = triton_split_2D_jagged( - jagged_output_hidden_states, - max_seq_len=cated_max_seqlen, - offsets_a=cated_offsets - logit_offsets, - offsets_b=logit_offsets, - ) - candidate_hidden_states = candidate_hidden_states.view( - batch_size, topk_prev_step, self.embedding_dim - ) - - # 5. 
MLP projection → logits → log_softmax - mlp = ( - self._decoder_mlp[i] - if not self.share_lm_head_across_hierarchies - else self._decoder_mlp - ) - tuple_or_tensor: Union[ - Tuple[torch.Tensor, torch.Tensor], torch.Tensor - ] = mlp(candidate_hidden_states) - candidates_logits = ( - tuple_or_tensor[0] - if isinstance(tuple_or_tensor, tuple) - else tuple_or_tensor - ) - probs_this_step: torch.Tensor = torch.nn.functional.log_softmax( - candidates_logits.float(), dim=-1 - ) - - # 6. Beam search propagate - self.beam_search.propagate(probs_this_step) - - # 7. Record the physical step codes for the next iteration. - # history_topk_sids[i][:, :, i] gives the codes selected at step i - # for the beams that survived step i's top-k. - physical_step_codes.append( - self.beam_search.history_topk_sids[-1][:, :, i] - ) - - generated_sids = self.beam_search.get_sids() - log_probs = self.beam_search.get_log_probs() - return generated_sids, log_probs diff --git a/examples/sid_gr/tests/test_dense_mask_to_arbitrary_func.py b/examples/sid_gr/tests/test_dense_mask_to_arbitrary_func.py new file mode 100644 index 000000000..56a5b13b3 --- /dev/null +++ b/examples/sid_gr/tests/test_dense_mask_to_arbitrary_func.py @@ -0,0 +1,111 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Tests for dense_mask_to_arbitrary_func(): verifies that converting a dense +[B,N,N] mask to arbitrary_func interval encoding preserves mask semantics. +""" +import os +import sys + +import pytest +import torch + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "model")) +from attention_mask import ( + dense_mask_to_arbitrary_func, + padded_target_aware_causal_mask, +) +sys.path.pop(0) + + +def arbitrary_func_to_dense(af, seqlen_q, seqlen_k): + """Expand arbitrary_func back to dense [B, seqlen_q, seqlen_k] bool mask.""" + B, n_func = af.shape[0], af.shape[2] + mask = torch.zeros(B, seqlen_q, seqlen_k, dtype=torch.bool, device=af.device) + kv_idx = torch.arange(seqlen_k, device=af.device) + for b in range(B): + for q in range(seqlen_q): + f0 = af[b, 0, 0, q].item() + row_mask = kv_idx < f0 + for iv in range(n_func // 2): + f_start = af[b, 0, 2 * iv + 1, q].item() + f_end = af[b, 0, 2 * iv + 2, q].item() + row_mask = row_mask | ((kv_idx >= f_start) & (kv_idx < f_end)) + mask[b, q] = row_mask + return mask + + +class TestDenseMaskToArbitraryFunc: + def test_causal_mask(self): + N, B = 16, 1 + valid = torch.tril(torch.ones(B, N, N, dtype=torch.bool, device="cuda")) + af = dense_mask_to_arbitrary_func(valid, N) + assert torch.equal(valid, arbitrary_func_to_dense(af, N, N)) + + def test_full_attention(self): + N, B = 8, 2 + valid = torch.ones(B, N, N, dtype=torch.bool, device="cuda") + af = dense_mask_to_arbitrary_func(valid, N) + assert torch.equal(valid, arbitrary_func_to_dense(af, N, N)) + + def test_empty_mask(self): + N, B = 8, 1 + valid = torch.zeros(B, N, N, dtype=torch.bool, device="cuda") + af = dense_mask_to_arbitrary_func(valid, N) + assert torch.equal(valid, arbitrary_func_to_dense(af, N, N)) + + def test_block_diagonal(self): + N, B = 8, 1 + valid = torch.zeros(B, N, N, dtype=torch.bool, device="cuda") + valid[0, :4, :4] = True + valid[0, 4:, 4:] = True + af = dense_mask_to_arbitrary_func(valid, N) + assert torch.equal(valid, 
arbitrary_func_to_dense(af, N, N)) + + @pytest.mark.parametrize("beam_width", [2, 3]) + @pytest.mark.parametrize("candidate_len", [1, 3]) + def test_target_aware_causal_mask(self, beam_width, candidate_len): + B = 2 + hist_lens = torch.tensor([6, 4], device="cuda") + inverted = padded_target_aware_causal_mask(hist_lens, 6, beam_width, candidate_len) + valid = ~inverted + N = valid.shape[-1] + af = dense_mask_to_arbitrary_func(valid, N) + assert torch.equal(valid.squeeze(1), arbitrary_func_to_dense(af, N, N)) + + def test_mask_with_gaps(self): + N, B = 10, 1 + valid = torch.zeros(B, N, N, dtype=torch.bool, device="cuda") + valid[0, 5, 0:3] = True + valid[0, 5, 5:7] = True + valid[0, 5, 9] = True + af = dense_mask_to_arbitrary_func(valid, N) + assert torch.equal(valid, arbitrary_func_to_dense(af, N, N)) + + def test_4d_input(self): + N, B = 8, 1 + valid_4d = torch.tril(torch.ones(B, 1, N, N, dtype=torch.bool, device="cuda")) + af = dense_mask_to_arbitrary_func(valid_4d, N) + assert torch.equal(valid_4d.squeeze(1), arbitrary_func_to_dense(af, N, N)) + + def test_batch_independence(self): + N, B = 8, 2 + valid = torch.zeros(B, N, N, dtype=torch.bool, device="cuda") + valid[0] = torch.tril(torch.ones(N, N, dtype=torch.bool, device="cuda")) + valid[1] = torch.ones(N, N, dtype=torch.bool, device="cuda") + af = dense_mask_to_arbitrary_func(valid, N) + recon = arbitrary_func_to_dense(af, N, N) + assert torch.equal(valid[0], recon[0]) + assert torch.equal(valid[1], recon[1]) diff --git a/examples/sid_gr/tests/test_generate_method_a_e2e.py b/examples/sid_gr/tests/test_generate_method_a_e2e.py deleted file mode 100644 index 3e3fc7e98..000000000 --- a/examples/sid_gr/tests/test_generate_method_a_e2e.py +++ /dev/null @@ -1,262 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -End-to-end test for generate_method_a(). - -Constructs a full SIDGRModel, runs generate_method_a() with a random batch, -and verifies: - 1. It runs without errors (smoke test) - 2. Output shapes are correct - 3. Generated SIDs are within valid codebook ranges -""" -from typing import List - -import pytest -import torch - -try: - import commons.utils as init - from commons.checkpoint import get_unwrapped_module - from commons.datasets.gpt_sid_batch import FeatureConfig, GPTSIDBatch - from commons.modules.embedding import ShardedEmbeddingConfig - from commons.ops.length_to_offsets import length_to_complete_offsets - from tests.test_utils import create_sid_gr_model_and_optimizer - HAS_SIDGR_DEPS = True -except ImportError as e: - HAS_SIDGR_DEPS = False - _SIDGR_IMPORT_ERR = str(e) - -try: - from flash_attn.cute.interface import flash_attn_func # noqa: F401 - HAS_FLASH_ATTN = True -except ImportError: - HAS_FLASH_ATTN = False - -_SKIP_REASON = ( - "flash_attn not installed" if not HAS_FLASH_ATTN - else f"SID-GR deps unavailable: {_SIDGR_IMPORT_ERR}" if not HAS_SIDGR_DEPS - else None -) -_SHOULD_SKIP = not (HAS_FLASH_ATTN and HAS_SIDGR_DEPS) - - -def _generate_batch( - batchsize: int, - max_history_length: int, - codebook_sizes: List[int], - history_feature_name: str, - candidate_feature_name: str, -) -> GPTSIDBatch: - num_hierarchies = len(codebook_sizes) - codebook_sizes_t = 
torch.tensor(codebook_sizes) - cum_sum = length_to_complete_offsets(codebook_sizes_t) - max_item_ids = cum_sum[1:] - min_item_ids = cum_sum[:-1] - raw_hist_names = [f"hist_sid_{i}" for i in range(num_hierarchies)] - raw_cand_names = [f"cand_sid_{i}" for i in range(num_hierarchies)] - feature_configs = [ - FeatureConfig( - feature_names=raw_hist_names, - max_item_ids=max_item_ids, - min_item_ids=min_item_ids, - max_sequence_length=max_history_length, - is_jagged=True, - ), - FeatureConfig( - feature_names=raw_cand_names, - max_item_ids=max_item_ids, - min_item_ids=min_item_ids, - max_sequence_length=1, - is_jagged=False, - ), - ] - return GPTSIDBatch.random( - batch_size=batchsize, - feature_configs=feature_configs, - raw_hist_sid_names=raw_hist_names, - raw_cand_sid_names=raw_cand_names, - combined_history_feature_name=history_feature_name, - combined_candidate_feature_name=candidate_feature_name, - contextual_feature_names=[], - device=torch.cuda.current_device(), - ) - - -@pytest.mark.skipif(_SHOULD_SKIP, reason=_SKIP_REASON or "") -@pytest.mark.parametrize("dtype", [torch.bfloat16]) -@pytest.mark.parametrize("hidden_size", [256]) -@pytest.mark.parametrize("num_attention_heads", [4]) -@pytest.mark.parametrize("kv_channels", [64]) -@pytest.mark.parametrize("num_layers", [1]) -@pytest.mark.parametrize("max_history_length", [64]) -@pytest.mark.parametrize("codebook_sizes", [[128, 128, 128]]) -@pytest.mark.parametrize("batchsize", [4]) -def test_generate_method_a_smoke( - dtype, - hidden_size, - num_attention_heads, - kv_channels, - num_layers, - max_history_length, - codebook_sizes, - batchsize, -): - """generate_method_a() runs end-to-end without errors.""" - num_hierarchies = len(codebook_sizes) - init.initialize_distributed() - init.initialize_model_parallel(1) - init.set_random_seed(42) - - hist_name = "hist_sids" - cand_name = "cand_sids" - codebook_embedding_config = ShardedEmbeddingConfig( - feature_names=[hist_name, cand_name], - table_name="codebook", - 
vocab_size=sum(codebook_sizes), - dim=hidden_size, - sharding_type="data_parallel", - ) - - with init.auto_destroy_global_state(): - model, optimizer = create_sid_gr_model_and_optimizer( - dtype=dtype, - hidden_size=hidden_size, - num_attention_heads=num_attention_heads, - kv_channels=kv_channels, - num_layers=num_layers, - num_hierarchies=num_hierarchies, - codebook_embedding_config=codebook_embedding_config, - codebook_sizes=codebook_sizes, - ) - optimizer.reload_model_params() - model_unwrapped = get_unwrapped_module(model) - model_unwrapped.eval() - - batch = _generate_batch( - batchsize=batchsize, - max_history_length=max_history_length, - codebook_sizes=codebook_sizes, - history_feature_name=hist_name, - candidate_feature_name=cand_name, - ) - batch.to(torch.cuda.current_device()) - - generated_sids, log_probs = model_unwrapped.generate_method_a(batch) - - # Shape checks - actual_bs = batch.actual_batch_size - top_k = model_unwrapped.top_k_for_generation - assert generated_sids.shape == ( - actual_bs, - top_k, - num_hierarchies, - ), f"Expected ({actual_bs}, {top_k}, {num_hierarchies}), got {generated_sids.shape}" - - assert log_probs.shape == ( - actual_bs, - top_k, - ), f"Expected ({actual_bs}, {top_k}), got {log_probs.shape}" - - # SIDs should be within codebook ranges - for h in range(num_hierarchies): - assert torch.all(generated_sids[:, :, h] >= 0) - assert torch.all(generated_sids[:, :, h] < codebook_sizes[h]) - - # Log probs should be negative (log of probabilities) - assert torch.all(log_probs <= 0) - - -@pytest.mark.skipif(_SHOULD_SKIP, reason=_SKIP_REASON or "") -@pytest.mark.parametrize("dtype", [torch.bfloat16]) -@pytest.mark.parametrize("hidden_size", [256]) -@pytest.mark.parametrize("num_attention_heads", [4]) -@pytest.mark.parametrize("kv_channels", [64]) -@pytest.mark.parametrize("num_layers", [1]) -@pytest.mark.parametrize("max_history_length", [64]) -@pytest.mark.parametrize("codebook_sizes", [[128, 128, 128]]) -def 
test_generate_method_a_vs_original( - dtype, - hidden_size, - num_attention_heads, - kv_channels, - num_layers, - max_history_length, - codebook_sizes, -): - """ - generate_method_a() and generate() should produce valid outputs for the - same input. They use different decoders (JaggedFlashAttnBlock vs Megatron) - with different weights, so outputs won't be identical, but both should - produce valid SIDs and log_probs. - """ - num_hierarchies = len(codebook_sizes) - init.initialize_distributed() - init.initialize_model_parallel(1) - init.set_random_seed(42) - - hist_name = "hist_sids" - cand_name = "cand_sids" - codebook_embedding_config = ShardedEmbeddingConfig( - feature_names=[hist_name, cand_name], - table_name="codebook", - vocab_size=sum(codebook_sizes), - dim=hidden_size, - sharding_type="data_parallel", - ) - - with init.auto_destroy_global_state(): - model, optimizer = create_sid_gr_model_and_optimizer( - dtype=dtype, - hidden_size=hidden_size, - num_attention_heads=num_attention_heads, - kv_channels=kv_channels, - num_layers=num_layers, - num_hierarchies=num_hierarchies, - codebook_embedding_config=codebook_embedding_config, - codebook_sizes=codebook_sizes, - ) - optimizer.reload_model_params() - model_unwrapped = get_unwrapped_module(model) - model_unwrapped.eval() - - batch = _generate_batch( - batchsize=4, - max_history_length=max_history_length, - codebook_sizes=codebook_sizes, - history_feature_name=hist_name, - candidate_feature_name=cand_name, - ) - batch.to(torch.cuda.current_device()) - - # Run original generate - orig_sids, orig_probs = model_unwrapped.generate(batch) - - # Run Method A generate - method_a_sids, method_a_probs = model_unwrapped.generate_method_a(batch) - - # Both should have valid shapes - actual_bs = batch.actual_batch_size - top_k = model_unwrapped.top_k_for_generation - assert orig_sids.shape == method_a_sids.shape == ( - actual_bs, top_k, num_hierarchies - ) - assert orig_probs.shape == method_a_probs.shape == (actual_bs, 
top_k) - - # Both should produce valid SIDs - for h in range(num_hierarchies): - assert torch.all(orig_sids[:, :, h] >= 0) - assert torch.all(orig_sids[:, :, h] < codebook_sizes[h]) - assert torch.all(method_a_sids[:, :, h] >= 0) - assert torch.all(method_a_sids[:, :, h] < codebook_sizes[h]) diff --git a/examples/sid_gr/tests/test_method_a_generate.py b/examples/sid_gr/tests/test_method_a_generate.py deleted file mode 100644 index c44bc8ed0..000000000 --- a/examples/sid_gr/tests/test_method_a_generate.py +++ /dev/null @@ -1,330 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Tests for Method A (Incremental Append) beam search: - 1. Beam search ancestor tracking. - 2. Dense mask structure for Method A. - 3. arbitrary_func matches the dense reference mask. -""" - -import os -import sys - -import pytest -import torch - -from beam_search.beam_search import BeamSearch - -# Import attention_mask directly to avoid model/__init__.py which pulls in -# heavy dependencies (dynamicemb, megatron, torchrec). 
-sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "model")) -from attention_mask import ( - build_incremental_append_arbitrary_func, - build_incremental_append_dense_mask, -) -sys.path.pop(0) - - -# --------------------------------------------------------------------------- -# Helper: expand arbitrary_func to dense bool mask for comparison -# --------------------------------------------------------------------------- -def arbitrary_func_to_dense( - af: torch.Tensor, - seqlen_q: int, - seqlen_k: int, -) -> torch.Tensor: - """ - Expand an arbitrary_func tensor [B, 1, n_func, seqlen_q+256] into a dense - [B, seqlen_q, seqlen_k] bool mask using the interval semantics: - valid(q, k) = (k < F0[q]) OR (F1[q] <= k < F2[q]) OR ... - """ - B = af.shape[0] - n_func = af.shape[2] - mask = torch.zeros(B, seqlen_q, seqlen_k, dtype=torch.bool, device=af.device) - - kv_idx = torch.arange(seqlen_k, device=af.device) - - for b in range(B): - for q in range(seqlen_q): - f0 = af[b, 0, 0, q].item() - row_mask = kv_idx < f0 - for interval in range(n_func // 2): - f_start = af[b, 0, 2 * interval + 1, q].item() - f_end = af[b, 0, 2 * interval + 2, q].item() - row_mask = row_mask | ((kv_idx >= f_start) & (kv_idx < f_end)) - mask[b, q] = row_mask - - return mask - - -def _run_beam_search(batch_size, beam_width, num_hierarchies, codebook_sizes): - """Run a full beam search with random logits and return the BeamSearch object.""" - bs = BeamSearch(beam_width, num_hierarchies, codebook_sizes, record_history=True) - topk_prev = 1 - for step in range(num_hierarchies): - log_probs = torch.randn( - batch_size, topk_prev, codebook_sizes[step], device="cuda" - ) - bs.propagate(log_probs) - topk_prev = beam_width - return bs - - -# --------------------------------------------------------------------------- -# Test: beam search ancestor tracking -# --------------------------------------------------------------------------- -class TestBeamSearchAncestorTracking: - 
@pytest.mark.parametrize("batch_size", [1, 3]) - @pytest.mark.parametrize("beam_width", [2, 4]) - def test_ancestor_positions_shape(self, batch_size, beam_width): - """Verify shape and range of ancestor positions.""" - num_h = 3 - codebook_sizes = [10] * num_h - bs = _run_beam_search(batch_size, beam_width, num_h, codebook_sizes) - - hist_len = torch.full((batch_size,), 6, device="cuda", dtype=torch.long) - anc_pos = bs.get_ancestor_positions(hist_len) - - assert anc_pos is not None - assert anc_pos.shape == (batch_size, beam_width, num_h) - - for s in range(num_h): - offset = sum(bs.beam_widths[:s]) - for b_idx in range(batch_size): - hl = hist_len[b_idx].item() - positions = anc_pos[b_idx, :, s] - assert torch.all(positions >= hl + offset) - assert torch.all(positions < hl + offset + bs.beam_widths[s]) - - def test_ancestor_positions_step_zero(self): - """At step 0, get_ancestor_positions should return None.""" - bs = BeamSearch(2, 3, [10, 10, 10], record_history=True) - assert bs.get_ancestor_positions(torch.tensor([6], device="cuda")) is None - - def test_parent_indices_stored(self): - """Verify parent_indices are stored during propagate.""" - bs = _run_beam_search(1, 2, 3, [10, 10, 10]) - assert len(bs.parent_indices) == 3 - for pi in bs.parent_indices: - assert pi.shape[0] == 1 - assert pi.shape[1] == 2 - - def test_reset_clears_parent_indices(self): - """Verify reset() clears parent_indices.""" - bs = _run_beam_search(1, 2, 2, [10, 10]) - assert len(bs.parent_indices) == 2 - bs.reset() - assert len(bs.parent_indices) == 0 - - def test_ancestor_self_position(self): - """The last entry in ancestor_positions should be the token's own position.""" - bs = _run_beam_search(1, 2, 3, [10, 10, 10]) - hist_len = torch.tensor([6], device="cuda", dtype=torch.long) - anc_pos = bs.get_ancestor_positions(hist_len) - - last_step_offset = sum(bs.beam_widths[:2]) - for b in range(2): - expected_self_pos = 6 + last_step_offset + b - assert anc_pos[0, b, 2].item() == 
expected_self_pos - - -# --------------------------------------------------------------------------- -# Test: dense mask structure for Method A -# --------------------------------------------------------------------------- -class TestMethodADenseMask: - def test_step0_pure_causal(self): - """At step 0, mask is pure causal over history.""" - B, hist_len, max_hist = 1, 4, 4 - history_seqlens = torch.tensor([hist_len], device="cuda") - - mask = build_incremental_append_dense_mask( - history_seqlens, max_hist, current_step=0, - beam_widths=[2], ancestor_positions=None, - ) - assert mask.shape == (B, max_hist, max_hist) - expected = torch.tril(torch.ones(hist_len, hist_len, dtype=torch.bool, device="cuda")) - assert torch.equal(mask[0, :hist_len, :hist_len], expected) - - def test_step1_beam_isolation(self): - """At step 1, step-0 beams should not see each other.""" - B, hist_len, max_hist = 1, 4, 4 - beam_width = 2 - history_seqlens = torch.tensor([hist_len], device="cuda") - - bs = _run_beam_search(B, beam_width, 1, [10]) - anc_pos = bs.get_ancestor_positions(history_seqlens) - - mask = build_incremental_append_dense_mask( - history_seqlens, max_hist, current_step=1, - beam_widths=[beam_width, beam_width], - ancestor_positions=anc_pos, - ) - N = max_hist + beam_width - assert mask.shape == (B, N, N) - - beam_a_pos = max_hist - beam_b_pos = max_hist + 1 - - # Beams should NOT see each other - assert not mask[0, beam_a_pos, beam_b_pos].item() - assert not mask[0, beam_b_pos, beam_a_pos].item() - - # Both should see self - assert mask[0, beam_a_pos, beam_a_pos].item() - assert mask[0, beam_b_pos, beam_b_pos].item() - - # Both should see all history - assert torch.all(mask[0, beam_a_pos, :hist_len]) - assert torch.all(mask[0, beam_b_pos, :hist_len]) - - def test_tree_ancestry_step2(self): - """At step 2, each beam sees its ancestor chain but not other branches.""" - B, hist_len, max_hist = 1, 4, 4 - beam_width = 2 - history_seqlens = torch.tensor([hist_len], 
device="cuda") - - bs = _run_beam_search(B, beam_width, 2, [10, 10]) - anc_pos = bs.get_ancestor_positions(history_seqlens) - - mask = build_incremental_append_dense_mask( - history_seqlens, max_hist, current_step=2, - beam_widths=[beam_width, beam_width, beam_width], - ancestor_positions=anc_pos, - ) - N = max_hist + beam_width * 2 - assert mask.shape == (B, N, N) - - step1_start = max_hist + beam_width - for b in range(beam_width): - token_pos = step1_start + b - # Should see all history - assert torch.all(mask[0, token_pos, :hist_len]) - # Should see self - assert mask[0, token_pos, token_pos].item() - # Should see its step-0 ancestor - step0_ancestor = anc_pos[0, b, 0].item() - assert mask[0, token_pos, step0_ancestor].item() - - def test_padding_handled(self): - """Samples with different history lengths should have correct padding.""" - B = 2 - history_seqlens = torch.tensor([4, 2], device="cuda") - max_hist = 4 - - mask = build_incremental_append_dense_mask( - history_seqlens, max_hist, current_step=0, - beam_widths=[2], ancestor_positions=None, - ) - # Sample 0: 4 history tokens - assert mask[0, 3, 3].item() # pos 3 sees self - # Sample 1: only 2 history tokens - assert mask[1, 1, 1].item() # pos 1 sees self - assert not mask[1, 2, 2].item() # pos 2 is padding, all zeros - assert not mask[1, 3, 3].item() # pos 3 is padding - - -# --------------------------------------------------------------------------- -# Test: arbitrary_func matches dense mask for last-step tokens -# --------------------------------------------------------------------------- -class TestArbitraryFuncMatchesDense: - @pytest.mark.parametrize("beam_width", [2, 3]) - @pytest.mark.parametrize("num_hierarchies", [1, 2, 3]) - def test_last_step_tokens_match(self, beam_width, num_hierarchies): - """ - The arbitrary_func encoding should produce the same mask as the dense - reference for the last step's tokens (the logit-producing tokens). 
- """ - B = 2 - hist_len = 6 - max_hist = hist_len - codebook_sizes = [10] * num_hierarchies - history_seqlens = torch.tensor([hist_len] * B, device="cuda") - beam_widths = [beam_width] * num_hierarchies - - bs = _run_beam_search(B, beam_width, num_hierarchies, codebook_sizes) - - # Test mask at the point of predicting the LAST hierarchy - test_step = num_hierarchies - anc_pos = bs.get_ancestor_positions(history_seqlens) - - dense = build_incremental_append_dense_mask( - history_seqlens, max_hist, current_step=test_step, - beam_widths=beam_widths, ancestor_positions=anc_pos, - ) - - af = build_incremental_append_arbitrary_func( - history_seqlens, max_hist, current_step=test_step, - beam_widths=beam_widths, ancestor_positions=anc_pos, - ) - - total_gen = sum(beam_widths[:test_step]) - N = max_hist + total_gen - af_dense = arbitrary_func_to_dense(af, N, N) - - # Check that last-step tokens match exactly - step_start = max_hist + sum(beam_widths[: test_step - 1]) - for b_idx in range(B): - for pos in range(step_start, step_start + beam_width): - dense_row = dense[b_idx, pos, :N] - af_row = af_dense[b_idx, pos, :N] - assert torch.equal(dense_row, af_row), ( - f"Mismatch at batch={b_idx}, pos={pos}\n" - f" dense: {dense_row.int().tolist()}\n" - f" af: {af_row.int().tolist()}" - ) - - def test_step0_causal_match(self): - """At step 0, arbitrary_func should produce a pure causal mask.""" - B = 1 - hist_len = 6 - history_seqlens = torch.tensor([hist_len], device="cuda") - - af = build_incremental_append_arbitrary_func( - history_seqlens, hist_len, current_step=0, - beam_widths=[2], ancestor_positions=None, - ) - - af_dense = arbitrary_func_to_dense(af, hist_len, hist_len) - expected = torch.tril(torch.ones(hist_len, hist_len, dtype=torch.bool, device="cuda")) - assert torch.equal(af_dense[0], expected) - - def test_history_region_matches(self): - """History region should be identical between dense and arbitrary_func.""" - B = 1 - hist_len = 6 - max_hist = hist_len - 
beam_widths = [2, 2] - history_seqlens = torch.tensor([hist_len], device="cuda") - - bs = _run_beam_search(B, 2, 2, [10, 10]) - anc_pos = bs.get_ancestor_positions(history_seqlens) - - dense = build_incremental_append_dense_mask( - history_seqlens, max_hist, current_step=2, - beam_widths=beam_widths, ancestor_positions=anc_pos, - ) - af = build_incremental_append_arbitrary_func( - history_seqlens, max_hist, current_step=2, - beam_widths=beam_widths, ancestor_positions=anc_pos, - ) - - N = max_hist + sum(beam_widths[:2]) - af_dense = arbitrary_func_to_dense(af, N, N) - - # History region [0:hist_len, 0:hist_len] should match - assert torch.equal( - dense[0, :hist_len, :hist_len], - af_dense[0, :hist_len, :hist_len], - ) From 61399ef5dba3b1bcef443ca38421962ff12de3ff Mon Sep 17 00:00:00 2001 From: Runchu Zhao Date: Wed, 1 Apr 2026 09:04:22 +0000 Subject: [PATCH 04/10] Remove mcore transformerblock. --- examples/sid_gr/model/gpt_model.py | 153 +++++------------- .../sid_gr/model/jagged_flash_attn_block.py | 82 ++++++++++ examples/sid_gr/model/mcore_model_specs.py | 2 +- 3 files changed, 120 insertions(+), 117 deletions(-) diff --git a/examples/sid_gr/model/gpt_model.py b/examples/sid_gr/model/gpt_model.py index f17395428..9407b0b85 100644 --- a/examples/sid_gr/model/gpt_model.py +++ b/examples/sid_gr/model/gpt_model.py @@ -16,7 +16,7 @@ import torch from beam_search.beam_search import BeamSearch -from commons.datasets.gpt_sid_batch import GPTSIDBatch, to_packed_seq_params +from commons.datasets.gpt_sid_batch import GPTSIDBatch from commons.modules.embedding import ShardedEmbedding, ShardedEmbeddingConfig from commons.ops.cuda_ops.JaggedTensorOpFunction import jagged_2D_tensor_concat from commons.ops.length_to_offsets import length_to_complete_offsets @@ -27,11 +27,9 @@ from megatron.core.models.common.embeddings.relative_pos_embedding import ( RelativePositionEmbedding, ) -from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.transformer import 
TransformerConfig from megatron.core.transformer.module import MegatronModule from megatron.core.transformer.spec_utils import ModuleSpec -from megatron.core.transformer.transformer_block import TransformerBlock from modules.eval_metrics import SIDRetrievalEvaluator from modules.gpt_loss_module import GPTSIDLossModule from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor @@ -40,6 +38,7 @@ padded_causal_mask_with_optional_bos, padded_target_aware_causal_mask, ) +from .jagged_flash_attn_block import JaggedTransformerBlock def _padding_to_dense_and_transpose( @@ -117,36 +116,40 @@ def __init__( relative_attention_num_buckets=relative_attention_num_buckets, relative_attention_max_distance=relative_attention_max_distance, ) - self.decoder = TransformerBlock( - config=self.config, - spec=self.transformer_decoder_layer_spec, + self.decoder = JaggedTransformerBlock( + num_layers=self.config.num_layers, + hidden_size=self.config.hidden_size, + num_attention_heads=self.config.num_attention_heads, + ffn_hidden_size=self.config.ffn_hidden_size, + layernorm_epsilon=getattr(self.config, "layernorm_epsilon", 1e-5), ) def forward( self, - hidden_states, - attention_mask: Optional[ - torch.Tensor - ] = None, # decoder attention mask, always causal + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, *, - packed_seq_params: Optional[PackedSeqParams] = None, + offsets: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + arbitrary_func: Optional[torch.Tensor] = None, **kwargs: Any, ) -> torch.Tensor: - attention_bias = None - # if self.position_embedding_type == 'relative': - # # attention bias is supported by cudnn, but not fa. 
- # # TODO@junzhang add jagged support once we have attention kernels - # query_seq_length = input_max_seqlen - # key_seq_length = query_seq_length - # attention_bias = self.decoder_relative_pos_emb(query_seq_length, key_seq_length) - output = self.decoder( - hidden_states=hidden_states, # query - attention_mask=attention_mask, # attention mask - packed_seq_params=packed_seq_params, # query and kv seqlens - attention_bias=attention_bias, - **kwargs, + """ + Args: + hidden_states: jagged [total_tokens, hidden_size]. + attention_mask: [B, 1, N, N] dense mask (True = masked out). + offsets: [B+1] cumulative sequence-length offsets. + max_seqlen: maximum sequence length in the batch. + arbitrary_func: pre-built arbitrary_func tensor; if provided, + takes precedence over attention_mask. + """ + return self.decoder( + hidden_states=hidden_states, + offsets=offsets, + max_seqlen=max_seqlen, + attention_mask=attention_mask, + arbitrary_func=arbitrary_func, ) - return output class SIDGRModel(MegatronModule): @@ -513,29 +516,21 @@ def decoder_step( input_offsets: torch.Tensor, input_max_seqlen: int, attention_mask: Optional[torch.Tensor] = None, - padding_to_dense: bool = True, add_bos_to_history: bool = False, - use_flash_attn: bool = False, ) -> torch.Tensor: """ Input and Output are both jagged. - attention_mask is used only when padding_to_dense is True. - When attention mask is None, we will construct a causal attention mask if padding_to_dense is True. - When use_flash_attn is True, uses JaggedFlashAttnBlock with jiayus's - Flash Attention + arbitrary_func mask encoding instead of Megatron's - TransformerBlock. The dense attention_mask is automatically converted - to arbitrary_func format. + The underlying JaggedTransformerBlock handles jagged↔dense conversion + and dense-mask → arbitrary_func conversion internally. 
""" if add_bos_to_history: assert ( attention_mask is None ), "attention mask should be None when adding bos to history" input_offsets[-1].item() - batch_size = input_offsets.size(0) - 1 - # Build attention_mask if not provided (needed by both paths) - if attention_mask is None and padding_to_dense: + if attention_mask is None: attention_mask = padded_causal_mask_with_optional_bos( input_offsets, input_max_seqlen, @@ -543,85 +538,12 @@ def decoder_step( bos_interval=self._num_hierarchies, ) - if use_flash_attn: - from .attention_mask import dense_mask_to_arbitrary_func - from .jagged_flash_attn_block import JaggedFlashAttnBlock - - # Lazy-init the FA decoder block - if not hasattr(self, "_jagged_fa_decoder"): - self._jagged_fa_decoder = JaggedFlashAttnBlock( - num_layers=self.config.num_layers, - hidden_size=self.embedding_dim, - num_attention_heads=self.config.num_attention_heads, - ffn_hidden_size=self.config.ffn_hidden_size, - layernorm_epsilon=getattr( - self.config, "layernorm_epsilon", 1e-5 - ), - ).to( - device=input_hidden_states.device, - dtype=self._training_dtype, - ) - - # Pad jagged → dense [B, S, D] - padded_input = ( - torch.ops.fbgemm.jagged_to_padded_dense( - values=input_hidden_states, - offsets=[input_offsets], - max_lengths=[input_max_seqlen], - padding_value=0.0, - ) - .view(batch_size, input_max_seqlen, -1) - .to(self._training_dtype) - ) - - # Convert dense mask → arbitrary_func - # attention_mask convention: True = masked out → invert for valid mask - valid_mask = ~attention_mask # [B, 1, N, N] - arbitrary_func = dense_mask_to_arbitrary_func( - valid_mask, input_max_seqlen - ) - - padded_output = self._jagged_fa_decoder( - padded_input, - arbitrary_func=arbitrary_func, - seqlen=input_max_seqlen, - ) - - # Unpad dense → jagged - output_hidden_states = torch.ops.fbgemm.dense_to_jagged( - padded_output, [input_offsets] - )[0] - else: - # Original Megatron path - if padding_to_dense: - decoder_input_hidden_states = 
_padding_to_dense_and_transpose( - input_hidden_states, - input_offsets, - input_max_seqlen, - ) - packed_seq_params = None - else: - assert input_hidden_states.dim() == 2, "input_hidden_states should be 2D" - decoder_input_hidden_states = input_hidden_states.unsqueeze(1) - attention_mask = None - packed_seq_params = to_packed_seq_params( - input_offsets, - input_max_seqlen, - ) - decoder_output_hidden_states = self.decoder( - hidden_states=decoder_input_hidden_states, - attention_mask=attention_mask, - packed_seq_params=packed_seq_params, - ) - - if padding_to_dense: - output_hidden_states = _transpose_dense_to_jagged( - decoder_output_hidden_states, - input_offsets, - input_max_seqlen, - ) - else: - output_hidden_states = decoder_output_hidden_states.squeeze(1) + output_hidden_states = self.decoder( + hidden_states=input_hidden_states, + attention_mask=attention_mask, + offsets=input_offsets, + max_seqlen=input_max_seqlen, + ) return output_hidden_states def forward( @@ -793,7 +715,6 @@ def generate(self, batch: GPTSIDBatch) -> torch.Tensor: cated_offsets, cated_max_seqlen, attention_mask=attention_mask, - padding_to_dense=True, add_bos_to_history=False, ) # remove history[batchsize * topk_last_step * max(1,i), embedding_dim] diff --git a/examples/sid_gr/model/jagged_flash_attn_block.py b/examples/sid_gr/model/jagged_flash_attn_block.py index 4dd7c4ce8..21c1dcb9c 100644 --- a/examples/sid_gr/model/jagged_flash_attn_block.py +++ b/examples/sid_gr/model/jagged_flash_attn_block.py @@ -334,3 +334,85 @@ def forward( hidden_states = self.final_layernorm(hidden_states) return hidden_states + + +class JaggedTransformerBlock(nn.Module): + """ + Wrapper that accepts jagged (variable-length) hidden states and either a + dense attention mask or a pre-built arbitrary_func tensor. + + Internally handles: + 1. dense_mask → arbitrary_func conversion (when dense mask is provided) + 2. jagged → padded-dense conversion + 3. 
forward through JaggedFlashAttnBlock (FA with arbitrary mask) + 4. padded-dense → jagged conversion + + This is intended to replace Megatron-Core's TransformerBlock in + SIDGRDecoder so that all padding / mask-format logic is encapsulated + in a single module initialised at construction time. + """ + + def __init__( + self, + num_layers: int, + hidden_size: int, + num_attention_heads: int, + ffn_hidden_size: int, + layernorm_epsilon: float = 1e-5, + hidden_dropout: float = 0.0, + activation: str = "gelu", + ): + super().__init__() + self.block = JaggedFlashAttnBlock( + num_layers=num_layers, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + ffn_hidden_size=ffn_hidden_size, + layernorm_epsilon=layernorm_epsilon, + hidden_dropout=hidden_dropout, + activation=activation, + ) + + def forward( + self, + hidden_states: torch.Tensor, + offsets: torch.Tensor, + max_seqlen: int, + attention_mask: Optional[torch.Tensor] = None, + arbitrary_func: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Args: + hidden_states: jagged [total_tokens, hidden_size]. + offsets: [B+1] cumulative sequence-length offsets. + max_seqlen: maximum sequence length in the batch. + attention_mask: optional [B, 1, N, N] bool tensor where + True = masked-out (Megatron convention). Ignored when + *arbitrary_func* is provided. + arbitrary_func: optional [B, 1, n_func, N+pad] int32 tensor + in flash_attn interval encoding. Takes precedence over + *attention_mask* when both are supplied. + + Returns: + jagged output [total_tokens, hidden_size]. 
+ """ + from .attention_mask import dense_mask_to_arbitrary_func + + batch_size = offsets.size(0) - 1 + + if arbitrary_func is None and attention_mask is not None: + valid_mask = ~attention_mask # Megatron: True=masked → invert + arbitrary_func = dense_mask_to_arbitrary_func(valid_mask, max_seqlen) + + padded = torch.ops.fbgemm.jagged_to_padded_dense( + values=hidden_states, + offsets=[offsets], + max_lengths=[max_seqlen], + padding_value=0.0, + ).view(batch_size, max_seqlen, -1) + + output = self.block( + padded, arbitrary_func=arbitrary_func, seqlen=max_seqlen + ) + + return torch.ops.fbgemm.dense_to_jagged(output, [offsets])[0] diff --git a/examples/sid_gr/model/mcore_model_specs.py b/examples/sid_gr/model/mcore_model_specs.py index d5a57392a..78ee7bc62 100644 --- a/examples/sid_gr/model/mcore_model_specs.py +++ b/examples/sid_gr/model/mcore_model_specs.py @@ -76,7 +76,6 @@ warnings.warn("Apex is not installed. Falling back to Torch Norm") LNImpl = WrappedTorchNorm - def _get_gpt_layer_with_transformer_engine_spec( num_experts: Optional[int] = None, moe_grouped_gemm: Optional[bool] = False, @@ -375,6 +374,7 @@ def get_gpt_decoder_block_spec( ), "arbitrary attention mask is only supported with Megatron-Core modules" # Layer specs. + # dense_layer_spec = ( _get_gpt_layer_with_transformer_engine_spec( num_experts=None, From 74e07c8144b8961fca64f62a9476cb10556ff585 Mon Sep 17 00:00:00 2001 From: Runchu Zhao Date: Thu, 2 Apr 2026 03:36:16 +0000 Subject: [PATCH 05/10] Add use_jagged_flash_attn to control FA version. 
--- examples/sid_gr/model/__init__.py | 3 +- examples/sid_gr/model/attention_mask.py | 119 ++++++++++++++ examples/sid_gr/model/gpt_model.py | 145 ++++++++++++++---- .../sid_gr/model/jagged_flash_attn_block.py | 58 +++---- .../test_dense_mask_to_arbitrary_func.py | 88 +++++++++++ 5 files changed, 347 insertions(+), 66 deletions(-) diff --git a/examples/sid_gr/model/__init__.py b/examples/sid_gr/model/__init__.py index 46499783e..3fb4bb5f6 100644 --- a/examples/sid_gr/model/__init__.py +++ b/examples/sid_gr/model/__init__.py @@ -18,6 +18,7 @@ def get_sid_gr_model( top_k_for_generation: int = 10, eval_metrics: Tuple[str, ...] = (), share_lm_head_across_hierarchies: bool = True, + use_jagged_flash_attn: bool = True, ) -> SIDGRModel: sid_gr_model = SIDGRModel( decoder_config=decoder_config, @@ -25,7 +26,6 @@ def get_sid_gr_model( codebook_sizes=codebook_sizes, num_hierarchies=num_hierarchies, transformer_decoder_layer_spec=get_gpt_decoder_block_spec( - # padding + arbitrary attention mask + Megatron-Core decoder_config, use_transformer_engine=False, arbitrary_attention_mask=True, @@ -35,6 +35,7 @@ def get_sid_gr_model( top_k_for_generation=top_k_for_generation, eval_metrics=eval_metrics, share_lm_head_across_hierarchies=share_lm_head_across_hierarchies, + use_jagged_flash_attn=use_jagged_flash_attn, ) return sid_gr_model diff --git a/examples/sid_gr/model/attention_mask.py b/examples/sid_gr/model/attention_mask.py index ac5d01897..014276477 100644 --- a/examples/sid_gr/model/attention_mask.py +++ b/examples/sid_gr/model/attention_mask.py @@ -282,6 +282,125 @@ def dense_mask_to_arbitrary_func( return af +def build_jagged_causal_arbitrary_func( + offsets: torch.Tensor, + total_tokens: int, + padding: int = 256, +) -> torch.Tensor: + """ + Build arbitrary_func for flattened jagged causal attention (B=1). + + All batch sequences are concatenated into a single sequence of length + *total_tokens*. 
Each query at global position *q* in batch element *b* + can attend to keys in ``[offset[b], q+1)`` — standard causal within its + own sequence, invisible to other sequences. + + Args: + offsets: [B+1] cumulative sequence-length offsets. + total_tokens: ``offsets[-1].item()`` — total number of tokens. + padding: FA convention padding on the last dim (default 256). + + Returns: + arbitrary_func: [1, 1, 3, total_tokens + padding] int32 tensor. + """ + device = offsets.device + n_func = 3 # F0=0, single interval [F1, F2) + + af = torch.zeros( + 1, 1, n_func, total_tokens + padding, dtype=torch.int32, device=device + ) + + positions = torch.arange(total_tokens, device=device) + batch_ids = torch.searchsorted(offsets[1:], positions, right=True) + batch_starts = offsets[batch_ids] + + # visible(q) = [0, 0) ∪ [batch_start, q+1) + af[0, 0, 1, :total_tokens] = batch_starts.to(torch.int32) + af[0, 0, 2, :total_tokens] = (positions + 1).to(torch.int32) + + return af + + +def dense_mask_to_jagged_arbitrary_func( + valid_mask: torch.Tensor, + offsets: torch.Tensor, + total_tokens: int, + padding: int = 256, +) -> torch.Tensor: + """ + Convert per-batch dense bool mask to a flattened (B=1) arbitrary_func. + + The dense mask is in per-batch padded coordinates ``[B, N, N]``. This + function maps each row to global (flattened) coordinates and encodes the + visible intervals into a single arbitrary_func tensor of shape + ``[1, 1, n_func, total_tokens + padding]``. + + Use this when a pre-built dense mask is available (e.g. from + ``padded_target_aware_causal_mask``). For pure causal masks prefer + :func:`build_jagged_causal_arbitrary_func` which avoids the dense mask + entirely. + + Args: + valid_mask: [B, N, N] or [B, 1, N, N] bool (True = can attend). + offsets: [B+1] cumulative offsets. + total_tokens: ``offsets[-1].item()``. + padding: FA convention padding (default 256). + + Returns: + arbitrary_func: [1, 1, n_func, total_tokens + padding] int32 tensor. 
+ """ + if valid_mask.dim() == 4: + valid_mask = valid_mask.squeeze(1) + assert valid_mask.dim() == 3, f"Expected [B, N, N], got {valid_mask.shape}" + + B, N, _ = valid_mask.shape + device = valid_mask.device + + # Detect interval boundaries + shifted = torch.zeros_like(valid_mask) + shifted[:, :, 1:] = valid_mask[:, :, :-1] + starts = valid_mask & ~shifted + + ends_shifted = torch.zeros_like(valid_mask) + ends_shifted[:, :, :-1] = valid_mask[:, :, 1:] + ends = valid_mask & ~ends_shifted + + max_intervals = int(starts.sum(dim=-1).max().item()) + n_func = max(2 * max_intervals + 1, 3) + if n_func % 2 == 0: + n_func += 1 + + af = torch.zeros( + 1, 1, n_func, total_tokens + padding, dtype=torch.int32, device=device + ) + + for b in range(B): + batch_start = offsets[b].item() + batch_end = offsets[b + 1].item() + seq_len = batch_end - batch_start + + for local_q in range(seq_len): + global_q = batch_start + local_q + row = valid_mask[b, local_q, :seq_len] + + if not row.any(): + continue + + start_pos = starts[b, local_q, :seq_len].nonzero(as_tuple=False).squeeze(-1) + end_pos = ends[b, local_q, :seq_len].nonzero(as_tuple=False).squeeze(-1) + 1 + + # In flattened coordinates, the first visible key is at + # batch_start (not 0), so F0 is always 0. All intervals go + # into the explicit (F1,F2), (F3,F4), ... slots. 
+ for iv in range(len(start_pos)): + s = start_pos[iv].item() + batch_start + e = end_pos[iv].item() + batch_start + af[0, 0, 2 * iv + 1, global_q] = s + af[0, 0, 2 * iv + 2, global_q] = e + + return af + + if __name__ == "__main__": history_seqlen = torch.tensor([4, 3]).cuda() max_history_seqlen = 6 diff --git a/examples/sid_gr/model/gpt_model.py b/examples/sid_gr/model/gpt_model.py index 9407b0b85..1ba2b9c35 100644 --- a/examples/sid_gr/model/gpt_model.py +++ b/examples/sid_gr/model/gpt_model.py @@ -30,11 +30,14 @@ from megatron.core.transformer import TransformerConfig from megatron.core.transformer.module import MegatronModule from megatron.core.transformer.spec_utils import ModuleSpec +from megatron.core.transformer.transformer_block import TransformerBlock from modules.eval_metrics import SIDRetrievalEvaluator from modules.gpt_loss_module import GPTSIDLossModule from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor from .attention_mask import ( + build_jagged_causal_arbitrary_func, + dense_mask_to_jagged_arbitrary_func, padded_causal_mask_with_optional_bos, padded_target_aware_causal_mask, ) @@ -88,7 +91,15 @@ def _transpose_dense_to_jagged( class SIDGRDecoder(MegatronModule): """ - Don't support PP currently. Does not inclu de embedding + Don't support PP currently. Does not include embedding. + + Supports two backend modes controlled by *use_jagged_flash_attn*: + + * **True** (default) — ``JaggedTransformerBlock``: flattens all batch + sequences into one (B=1) and uses jiayus's Flash Attention with + arbitrary_func mask encoding. Zero padding. + * **False** — Megatron-Core ``TransformerBlock``: pads jagged to dense, + uses ``DotProductAttention`` with a dense arbitrary attention mask. 
""" def __init__( @@ -100,15 +111,16 @@ def __init__( ] = "learned_absolute", relative_attention_num_buckets: int = 32, relative_attention_max_distance: int = 128, + use_jagged_flash_attn: bool = True, ): super().__init__(config=decoder_config) self.config: TransformerConfig = decoder_config self.transformer_decoder_layer_spec: ModuleSpec = transformer_decoder_layer_spec - # TODO, add position encoder self.model_type = ModelType.encoder_or_decoder self.position_embedding_type = position_embedding_type + self.use_jagged_flash_attn = use_jagged_flash_attn self.decoder_relative_pos_emb = RelativePositionEmbedding( bidirectional=False, init_method=self.config.init_method, @@ -116,13 +128,22 @@ def __init__( relative_attention_num_buckets=relative_attention_num_buckets, relative_attention_max_distance=relative_attention_max_distance, ) - self.decoder = JaggedTransformerBlock( - num_layers=self.config.num_layers, - hidden_size=self.config.hidden_size, - num_attention_heads=self.config.num_attention_heads, - ffn_hidden_size=self.config.ffn_hidden_size, - layernorm_epsilon=getattr(self.config, "layernorm_epsilon", 1e-5), - ) + + if use_jagged_flash_attn: + self.decoder = JaggedTransformerBlock( + num_layers=self.config.num_layers, + hidden_size=self.config.hidden_size, + num_attention_heads=self.config.num_attention_heads, + ffn_hidden_size=self.config.ffn_hidden_size, + layernorm_epsilon=getattr( + self.config, "layernorm_epsilon", 1e-5 + ), + ) + else: + self.decoder = TransformerBlock( + config=self.config, + spec=self.transformer_decoder_layer_spec, + ) def forward( self, @@ -134,22 +155,21 @@ def forward( arbitrary_func: Optional[torch.Tensor] = None, **kwargs: Any, ) -> torch.Tensor: - """ - Args: - hidden_states: jagged [total_tokens, hidden_size]. - attention_mask: [B, 1, N, N] dense mask (True = masked out). - offsets: [B+1] cumulative sequence-length offsets. - max_seqlen: maximum sequence length in the batch. 
- arbitrary_func: pre-built arbitrary_func tensor; if provided, - takes precedence over attention_mask. - """ - return self.decoder( - hidden_states=hidden_states, - offsets=offsets, - max_seqlen=max_seqlen, - attention_mask=attention_mask, - arbitrary_func=arbitrary_func, - ) + if self.use_jagged_flash_attn: + return self.decoder( + hidden_states=hidden_states, + arbitrary_func=arbitrary_func, + ) + else: + # mcore path: expects dense [S, B, D] input + padded = _padding_to_dense_and_transpose( + hidden_states, offsets, max_seqlen + ) + output = self.decoder( + hidden_states=padded, + attention_mask=attention_mask, + ) + return _transpose_dense_to_jagged(output, offsets, max_seqlen) class SIDGRModel(MegatronModule): @@ -174,6 +194,7 @@ def __init__( top_k_for_generation: int = 10, # this is used for eval eval_metrics: Tuple[str, ...] = (), # this is used for eval share_lm_head_across_hierarchies: bool = True, + use_jagged_flash_attn: bool = True, ): super(SIDGRModel, self).__init__(config=decoder_config) assert ( @@ -202,6 +223,7 @@ def __init__( decoder_config, transformer_decoder_layer_spec, position_embedding_type="relative", + use_jagged_flash_attn=use_jagged_flash_attn, ) self.codebook_sizes = codebook_sizes assert codebook_embedding_config.vocab_size >= sum( @@ -521,8 +543,9 @@ def decoder_step( """ Input and Output are both jagged. - The underlying JaggedTransformerBlock handles jagged↔dense conversion - and dense-mask → arbitrary_func conversion internally. + Routes to either the JaggedFlashAttn path (B=1 flatten + + arbitrary_func) or the mcore path (pad-to-dense + dense mask) + depending on ``self.decoder.use_jagged_flash_attn``. 
""" if add_bos_to_history: assert ( @@ -530,6 +553,71 @@ def decoder_step( ), "attention mask should be None when adding bos to history" input_offsets[-1].item() + if self.decoder.use_jagged_flash_attn: + return self._decoder_step_jagged_fa( + input_hidden_states, + input_offsets, + input_max_seqlen, + attention_mask, + add_bos_to_history, + ) + else: + return self._decoder_step_mcore( + input_hidden_states, + input_offsets, + input_max_seqlen, + attention_mask, + add_bos_to_history, + ) + + def _decoder_step_jagged_fa( + self, + input_hidden_states: torch.Tensor, + input_offsets: torch.Tensor, + input_max_seqlen: int, + attention_mask: Optional[torch.Tensor], + add_bos_to_history: bool, + ) -> torch.Tensor: + """Flatten to B=1, build arbitrary_func, forward through FA block.""" + total_tokens = int(input_offsets[-1].item()) + + if attention_mask is not None: + valid_mask = ~attention_mask + arbitrary_func = dense_mask_to_jagged_arbitrary_func( + valid_mask, input_offsets, total_tokens + ) + elif add_bos_to_history: + # TODO: build flattened arbitrary_func directly from offsets + # to avoid materialising the dense mask. 
+ dense_mask = padded_causal_mask_with_optional_bos( + input_offsets, + input_max_seqlen, + add_bos_to_history=True, + bos_interval=self._num_hierarchies, + ) + valid_mask = ~dense_mask + arbitrary_func = dense_mask_to_jagged_arbitrary_func( + valid_mask, input_offsets, total_tokens + ) + else: + arbitrary_func = build_jagged_causal_arbitrary_func( + input_offsets, total_tokens + ) + + return self.decoder( + hidden_states=input_hidden_states, + arbitrary_func=arbitrary_func, + ) + + def _decoder_step_mcore( + self, + input_hidden_states: torch.Tensor, + input_offsets: torch.Tensor, + input_max_seqlen: int, + attention_mask: Optional[torch.Tensor], + add_bos_to_history: bool, + ) -> torch.Tensor: + """Pad to dense, build dense mask, forward through mcore block.""" if attention_mask is None: attention_mask = padded_causal_mask_with_optional_bos( input_offsets, @@ -538,13 +626,12 @@ def decoder_step( bos_interval=self._num_hierarchies, ) - output_hidden_states = self.decoder( + return self.decoder( hidden_states=input_hidden_states, attention_mask=attention_mask, offsets=input_offsets, max_seqlen=input_max_seqlen, ) - return output_hidden_states def forward( self, diff --git a/examples/sid_gr/model/jagged_flash_attn_block.py b/examples/sid_gr/model/jagged_flash_attn_block.py index 21c1dcb9c..244cc882e 100644 --- a/examples/sid_gr/model/jagged_flash_attn_block.py +++ b/examples/sid_gr/model/jagged_flash_attn_block.py @@ -338,18 +338,21 @@ def forward( class JaggedTransformerBlock(nn.Module): """ - Wrapper that accepts jagged (variable-length) hidden states and either a - dense attention mask or a pre-built arbitrary_func tensor. + Wrapper that accepts jagged (variable-length) hidden states and a + pre-built arbitrary_func tensor in the flattened (B=1) coordinate space. - Internally handles: - 1. dense_mask → arbitrary_func conversion (when dense mask is provided) - 2. jagged → padded-dense conversion - 3. 
forward through JaggedFlashAttnBlock (FA with arbitrary mask) - 4. padded-dense → jagged conversion + All batch sequences are concatenated into a single sequence of length + *total_tokens* (no padding). The arbitrary_func encodes both the + block-diagonal batch isolation and the desired attention pattern + (causal, target-grouped, etc.). + + Internally: + 1. Reshape jagged [total_tokens, D] → [1, total_tokens, D] + 2. Forward through JaggedFlashAttnBlock (FA with arbitrary mask) + 3. Reshape [1, total_tokens, D] → [total_tokens, D] This is intended to replace Megatron-Core's TransformerBlock in - SIDGRDecoder so that all padding / mask-format logic is encapsulated - in a single module initialised at construction time. + SIDGRDecoder. """ def __init__( @@ -376,43 +379,26 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - offsets: torch.Tensor, - max_seqlen: int, - attention_mask: Optional[torch.Tensor] = None, - arbitrary_func: Optional[torch.Tensor] = None, + arbitrary_func: torch.Tensor, ) -> torch.Tensor: """ Args: hidden_states: jagged [total_tokens, hidden_size]. - offsets: [B+1] cumulative sequence-length offsets. - max_seqlen: maximum sequence length in the batch. - attention_mask: optional [B, 1, N, N] bool tensor where - True = masked-out (Megatron convention). Ignored when - *arbitrary_func* is provided. - arbitrary_func: optional [B, 1, n_func, N+pad] int32 tensor - in flash_attn interval encoding. Takes precedence over - *attention_mask* when both are supplied. + arbitrary_func: [1, 1, n_func, total_tokens + pad] int32 tensor + in flattened (B=1) coordinate space, encoding both batch + isolation and the attention pattern. Returns: jagged output [total_tokens, hidden_size]. 
""" - from .attention_mask import dense_mask_to_arbitrary_func - - batch_size = offsets.size(0) - 1 - - if arbitrary_func is None and attention_mask is not None: - valid_mask = ~attention_mask # Megatron: True=masked → invert - arbitrary_func = dense_mask_to_arbitrary_func(valid_mask, max_seqlen) + total_tokens = hidden_states.shape[0] - padded = torch.ops.fbgemm.jagged_to_padded_dense( - values=hidden_states, - offsets=[offsets], - max_lengths=[max_seqlen], - padding_value=0.0, - ).view(batch_size, max_seqlen, -1) + # [total_tokens, D] → [1, total_tokens, D] + flat_input = hidden_states.unsqueeze(0) output = self.block( - padded, arbitrary_func=arbitrary_func, seqlen=max_seqlen + flat_input, arbitrary_func=arbitrary_func, seqlen=total_tokens ) - return torch.ops.fbgemm.dense_to_jagged(output, [offsets])[0] + # [1, total_tokens, D] → [total_tokens, D] + return output.squeeze(0) diff --git a/examples/sid_gr/tests/test_dense_mask_to_arbitrary_func.py b/examples/sid_gr/tests/test_dense_mask_to_arbitrary_func.py index 56a5b13b3..bee414b1f 100644 --- a/examples/sid_gr/tests/test_dense_mask_to_arbitrary_func.py +++ b/examples/sid_gr/tests/test_dense_mask_to_arbitrary_func.py @@ -24,7 +24,9 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "model")) from attention_mask import ( + build_jagged_causal_arbitrary_func, dense_mask_to_arbitrary_func, + dense_mask_to_jagged_arbitrary_func, padded_target_aware_causal_mask, ) sys.path.pop(0) @@ -109,3 +111,89 @@ def test_batch_independence(self): recon = arbitrary_func_to_dense(af, N, N) assert torch.equal(valid[0], recon[0]) assert torch.equal(valid[1], recon[1]) + + +class TestJaggedFlattenedArbitraryFunc: + """Tests for the B=1 flattened arbitrary_func builders.""" + + @staticmethod + def _build_expected_jagged_causal(offsets): + """Build the expected [1, total, total] causal block-diagonal mask.""" + total = offsets[-1].item() + device = offsets.device + expected = torch.zeros(total, total, dtype=torch.bool, 
device=device) + B = offsets.size(0) - 1 + for b in range(B): + s, e = offsets[b].item(), offsets[b + 1].item() + block = torch.tril(torch.ones(e - s, e - s, dtype=torch.bool, device=device)) + expected[s:e, s:e] = block + return expected + + def test_jagged_causal_basic(self): + offsets = torch.tensor([0, 4, 7, 10], device="cuda") + total = 10 + af = build_jagged_causal_arbitrary_func(offsets, total) + recon = arbitrary_func_to_dense(af, total, total).squeeze(0) + expected = self._build_expected_jagged_causal(offsets) + assert torch.equal(expected, recon) + + def test_jagged_causal_single_batch(self): + offsets = torch.tensor([0, 6], device="cuda") + total = 6 + af = build_jagged_causal_arbitrary_func(offsets, total) + recon = arbitrary_func_to_dense(af, total, total).squeeze(0) + expected = torch.tril(torch.ones(6, 6, dtype=torch.bool, device="cuda")) + assert torch.equal(expected, recon) + + def test_jagged_causal_uneven_lengths(self): + offsets = torch.tensor([0, 2, 8, 9], device="cuda") + total = 9 + af = build_jagged_causal_arbitrary_func(offsets, total) + recon = arbitrary_func_to_dense(af, total, total).squeeze(0) + expected = self._build_expected_jagged_causal(offsets) + assert torch.equal(expected, recon) + + def test_dense_to_jagged_causal(self): + """dense_mask_to_jagged_arbitrary_func should match build_jagged_causal for causal masks.""" + offsets = torch.tensor([0, 3, 7], device="cuda") + B, total = 2, 7 + max_seqlen = 4 + per_batch = torch.zeros(B, max_seqlen, max_seqlen, dtype=torch.bool, device="cuda") + for b in range(B): + sl = (offsets[b + 1] - offsets[b]).item() + per_batch[b, :sl, :sl] = torch.tril( + torch.ones(sl, sl, dtype=torch.bool, device="cuda") + ) + af = dense_mask_to_jagged_arbitrary_func(per_batch, offsets, total) + recon = arbitrary_func_to_dense(af, total, total).squeeze(0) + expected = self._build_expected_jagged_causal(offsets) + assert torch.equal(expected, recon) + + @pytest.mark.parametrize("beam_width", [2, 3]) + 
@pytest.mark.parametrize("candidate_len", [1, 3]) + def test_dense_to_jagged_target_grouped(self, beam_width, candidate_len): + """Verify target-grouped masks survive the jagged conversion roundtrip.""" + B = 2 + hist_lens = torch.tensor([5, 3], device="cuda") + max_hist = 5 + inverted = padded_target_aware_causal_mask(hist_lens, max_hist, beam_width, candidate_len) + valid = ~inverted # [B, 1, N, N] + N = valid.shape[-1] + total_per_batch = (hist_lens + beam_width * candidate_len).tolist() + offsets = torch.tensor( + [0] + [sum(total_per_batch[:i + 1]) for i in range(B)], + device="cuda", + ) + total = offsets[-1].item() + af = dense_mask_to_jagged_arbitrary_func(valid, offsets, total) + recon = arbitrary_func_to_dense(af, total, total).squeeze(0) + + # Build expected flattened mask from per-batch dense mask + expected = torch.zeros(total, total, dtype=torch.bool, device="cuda") + valid_3d = valid.squeeze(1) + for b in range(B): + s = offsets[b].item() + sl = total_per_batch[b] + expected[s:s + sl, s:s + sl] = valid_3d[b, :sl, :sl] + + assert torch.equal(expected, recon) From 528cf7760390f4faa4a71b049c8908cb64627ccc Mon Sep 17 00:00:00 2001 From: Runchu Zhao Date: Thu, 2 Apr 2026 07:34:52 +0000 Subject: [PATCH 06/10] Move build func to tests. 
--- examples/sid_gr/model/attention_mask.py | 80 --------- examples/sid_gr/model/gpt_model.py | 133 ++++---------- .../test_dense_mask_to_arbitrary_func.py | 61 ++++++- .../tests/test_jagged_flash_attn_block.py | 163 ++++++++++++++++++ 4 files changed, 253 insertions(+), 184 deletions(-) diff --git a/examples/sid_gr/model/attention_mask.py b/examples/sid_gr/model/attention_mask.py index 014276477..17a63d60c 100644 --- a/examples/sid_gr/model/attention_mask.py +++ b/examples/sid_gr/model/attention_mask.py @@ -321,86 +321,6 @@ def build_jagged_causal_arbitrary_func( return af -def dense_mask_to_jagged_arbitrary_func( - valid_mask: torch.Tensor, - offsets: torch.Tensor, - total_tokens: int, - padding: int = 256, -) -> torch.Tensor: - """ - Convert per-batch dense bool mask to a flattened (B=1) arbitrary_func. - - The dense mask is in per-batch padded coordinates ``[B, N, N]``. This - function maps each row to global (flattened) coordinates and encodes the - visible intervals into a single arbitrary_func tensor of shape - ``[1, 1, n_func, total_tokens + padding]``. - - Use this when a pre-built dense mask is available (e.g. from - ``padded_target_aware_causal_mask``). For pure causal masks prefer - :func:`build_jagged_causal_arbitrary_func` which avoids the dense mask - entirely. - - Args: - valid_mask: [B, N, N] or [B, 1, N, N] bool (True = can attend). - offsets: [B+1] cumulative offsets. - total_tokens: ``offsets[-1].item()``. - padding: FA convention padding (default 256). - - Returns: - arbitrary_func: [1, 1, n_func, total_tokens + padding] int32 tensor. 
- """ - if valid_mask.dim() == 4: - valid_mask = valid_mask.squeeze(1) - assert valid_mask.dim() == 3, f"Expected [B, N, N], got {valid_mask.shape}" - - B, N, _ = valid_mask.shape - device = valid_mask.device - - # Detect interval boundaries - shifted = torch.zeros_like(valid_mask) - shifted[:, :, 1:] = valid_mask[:, :, :-1] - starts = valid_mask & ~shifted - - ends_shifted = torch.zeros_like(valid_mask) - ends_shifted[:, :, :-1] = valid_mask[:, :, 1:] - ends = valid_mask & ~ends_shifted - - max_intervals = int(starts.sum(dim=-1).max().item()) - n_func = max(2 * max_intervals + 1, 3) - if n_func % 2 == 0: - n_func += 1 - - af = torch.zeros( - 1, 1, n_func, total_tokens + padding, dtype=torch.int32, device=device - ) - - for b in range(B): - batch_start = offsets[b].item() - batch_end = offsets[b + 1].item() - seq_len = batch_end - batch_start - - for local_q in range(seq_len): - global_q = batch_start + local_q - row = valid_mask[b, local_q, :seq_len] - - if not row.any(): - continue - - start_pos = starts[b, local_q, :seq_len].nonzero(as_tuple=False).squeeze(-1) - end_pos = ends[b, local_q, :seq_len].nonzero(as_tuple=False).squeeze(-1) + 1 - - # In flattened coordinates, the first visible key is at - # batch_start (not 0), so F0 is always 0. All intervals go - # into the explicit (F1,F2), (F3,F4), ... slots. 
- for iv in range(len(start_pos)): - s = start_pos[iv].item() + batch_start - e = end_pos[iv].item() + batch_start - af[0, 0, 2 * iv + 1, global_q] = s - af[0, 0, 2 * iv + 2, global_q] = e - - return af - - if __name__ == "__main__": history_seqlen = torch.tensor([4, 3]).cuda() max_history_seqlen = 6 diff --git a/examples/sid_gr/model/gpt_model.py b/examples/sid_gr/model/gpt_model.py index 1ba2b9c35..2fc59fb22 100644 --- a/examples/sid_gr/model/gpt_model.py +++ b/examples/sid_gr/model/gpt_model.py @@ -35,12 +35,6 @@ from modules.gpt_loss_module import GPTSIDLossModule from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor -from .attention_mask import ( - build_jagged_causal_arbitrary_func, - dense_mask_to_jagged_arbitrary_func, - padded_causal_mask_with_optional_bos, - padded_target_aware_causal_mask, -) from .jagged_flash_attn_block import JaggedTransformerBlock @@ -538,104 +532,39 @@ def decoder_step( input_offsets: torch.Tensor, input_max_seqlen: int, attention_mask: Optional[torch.Tensor] = None, - add_bos_to_history: bool = False, + arbitrary_func: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ - Input and Output are both jagged. + Input and Output are both jagged. This method only routes — the + caller is responsible for building the mask / arbitrary_func. - Routes to either the JaggedFlashAttn path (B=1 flatten + - arbitrary_func) or the mcore path (pad-to-dense + dense mask) - depending on ``self.decoder.use_jagged_flash_attn``. + * FA path: pass *arbitrary_func* (caller-built). + * mcore path: pass *attention_mask* (dense ``[B, 1, N, N]``). 
""" - if add_bos_to_history: - assert ( - attention_mask is None - ), "attention mask should be None when adding bos to history" - input_offsets[-1].item() - if self.decoder.use_jagged_flash_attn: - return self._decoder_step_jagged_fa( - input_hidden_states, - input_offsets, - input_max_seqlen, - attention_mask, - add_bos_to_history, + assert arbitrary_func is not None, ( + "FA path requires arbitrary_func; caller should build it" ) - else: - return self._decoder_step_mcore( - input_hidden_states, - input_offsets, - input_max_seqlen, - attention_mask, - add_bos_to_history, - ) - - def _decoder_step_jagged_fa( - self, - input_hidden_states: torch.Tensor, - input_offsets: torch.Tensor, - input_max_seqlen: int, - attention_mask: Optional[torch.Tensor], - add_bos_to_history: bool, - ) -> torch.Tensor: - """Flatten to B=1, build arbitrary_func, forward through FA block.""" - total_tokens = int(input_offsets[-1].item()) - - if attention_mask is not None: - valid_mask = ~attention_mask - arbitrary_func = dense_mask_to_jagged_arbitrary_func( - valid_mask, input_offsets, total_tokens - ) - elif add_bos_to_history: - # TODO: build flattened arbitrary_func directly from offsets - # to avoid materialising the dense mask. 
- dense_mask = padded_causal_mask_with_optional_bos( - input_offsets, - input_max_seqlen, - add_bos_to_history=True, - bos_interval=self._num_hierarchies, - ) - valid_mask = ~dense_mask - arbitrary_func = dense_mask_to_jagged_arbitrary_func( - valid_mask, input_offsets, total_tokens + return self.decoder( + hidden_states=input_hidden_states, + arbitrary_func=arbitrary_func, ) else: - arbitrary_func = build_jagged_causal_arbitrary_func( - input_offsets, total_tokens + assert attention_mask is not None, ( + "mcore path requires attention_mask; caller should build it" ) - - return self.decoder( - hidden_states=input_hidden_states, - arbitrary_func=arbitrary_func, - ) - - def _decoder_step_mcore( - self, - input_hidden_states: torch.Tensor, - input_offsets: torch.Tensor, - input_max_seqlen: int, - attention_mask: Optional[torch.Tensor], - add_bos_to_history: bool, - ) -> torch.Tensor: - """Pad to dense, build dense mask, forward through mcore block.""" - if attention_mask is None: - attention_mask = padded_causal_mask_with_optional_bos( - input_offsets, - input_max_seqlen, - add_bos_to_history=add_bos_to_history, - bos_interval=self._num_hierarchies, + return self.decoder( + hidden_states=input_hidden_states, + attention_mask=attention_mask, + offsets=input_offsets, + max_seqlen=input_max_seqlen, ) - return self.decoder( - hidden_states=input_hidden_states, - attention_mask=attention_mask, - offsets=input_offsets, - max_seqlen=input_max_seqlen, - ) - def forward( self, batch: GPTSIDBatch, + attention_mask: Optional[torch.Tensor] = None, + arbitrary_func: Optional[torch.Tensor] = None, ) -> torch.Tensor: # 1. prepare embeddings: embedding lookup + history, bos and candidate concat ( @@ -649,13 +578,13 @@ def forward( ) history_offsets = batch.features[batch.history_feature_name].offsets() - # 2. decoder step + # 2. 
decoder step — caller provides the mask / arbitrary_func jagged_output_hidden_states = self.decoder_step( input_hidden_states, input_offsets, input_max_seqlen, - attention_mask=None, - add_bos_to_history=self.add_bos_to_history_for_training, + attention_mask=attention_mask, + arbitrary_func=arbitrary_func, ) # 3. postprocess: only keep the candidate hidden states candidate_hidden_states = self._postprocess_output( @@ -699,7 +628,12 @@ def forward( return merged_losses, merged_logits @torch.no_grad - def generate(self, batch: GPTSIDBatch) -> torch.Tensor: + def generate( + self, + batch: GPTSIDBatch, + attention_mask: Optional[torch.Tensor] = None, + arbitrary_func: Optional[torch.Tensor] = None, + ) -> torch.Tensor: """ Generate the output sids for the given batch. The generation will autogressively generate the output sids with a constrained fixed-width beam search strategy. Args: @@ -789,20 +723,13 @@ def generate(self, batch: GPTSIDBatch) -> torch.Tensor: dtype=input_offsets.dtype, ) - # 2. prepare the attention mask - attention_mask = padded_target_aware_causal_mask( - torch.diff(input_offsets), - input_max_seqlen, - 0 if i == 0 else topk_prev_step, - candidate_length, - ) - # 3. we need a decoder step with the concatenated hidden states and offsets. Note that we do not add bos to history for generation. + # 2. 
decoder step — caller provides attention_mask or arbitrary_func jagged_output_hidden_states = self.decoder_step( cated_hidden_states, cated_offsets, cated_max_seqlen, attention_mask=attention_mask, - add_bos_to_history=False, + arbitrary_func=arbitrary_func, ) # remove history[batchsize * topk_last_step * max(1,i), embedding_dim] _, candidate_hidden_states = triton_split_2D_jagged( diff --git a/examples/sid_gr/tests/test_dense_mask_to_arbitrary_func.py b/examples/sid_gr/tests/test_dense_mask_to_arbitrary_func.py index bee414b1f..acbe364fd 100644 --- a/examples/sid_gr/tests/test_dense_mask_to_arbitrary_func.py +++ b/examples/sid_gr/tests/test_dense_mask_to_arbitrary_func.py @@ -26,12 +26,71 @@ from attention_mask import ( build_jagged_causal_arbitrary_func, dense_mask_to_arbitrary_func, - dense_mask_to_jagged_arbitrary_func, + padded_causal_mask_with_optional_bos, padded_target_aware_causal_mask, ) sys.path.pop(0) +def dense_mask_to_jagged_arbitrary_func( + valid_mask: torch.Tensor, + offsets: torch.Tensor, + total_tokens: int, + padding: int = 256, +) -> torch.Tensor: + """ + Test utility: convert per-batch dense mask [B, N, N] to a flattened + (B=1) arbitrary_func [1, 1, n_func, total_tokens + padding]. + + Maps each row from per-batch padded coordinates to global (flattened) + coordinates. Used in tests to build arbitrary_func from dense masks + so the FA path can be compared against the mcore/reference path. 
+ """ + if valid_mask.dim() == 4: + valid_mask = valid_mask.squeeze(1) + assert valid_mask.dim() == 3, f"Expected [B, N, N], got {valid_mask.shape}" + + B, N, _ = valid_mask.shape + device = valid_mask.device + + shifted = torch.zeros_like(valid_mask) + shifted[:, :, 1:] = valid_mask[:, :, :-1] + starts = valid_mask & ~shifted + + ends_shifted = torch.zeros_like(valid_mask) + ends_shifted[:, :, :-1] = valid_mask[:, :, 1:] + ends = valid_mask & ~ends_shifted + + max_intervals = int(starts.sum(dim=-1).max().item()) + n_func = max(2 * max_intervals + 1, 3) + if n_func % 2 == 0: + n_func += 1 + + af = torch.zeros( + 1, 1, n_func, total_tokens + padding, dtype=torch.int32, device=device + ) + + for b in range(B): + batch_start = offsets[b].item() + batch_end = offsets[b + 1].item() + seq_len = batch_end - batch_start + + for local_q in range(seq_len): + global_q = batch_start + local_q + row = valid_mask[b, local_q, :seq_len] + if not row.any(): + continue + start_pos = starts[b, local_q, :seq_len].nonzero(as_tuple=False).squeeze(-1) + end_pos = ends[b, local_q, :seq_len].nonzero(as_tuple=False).squeeze(-1) + 1 + for iv in range(len(start_pos)): + s = start_pos[iv].item() + batch_start + e = end_pos[iv].item() + batch_start + af[0, 0, 2 * iv + 1, global_q] = s + af[0, 0, 2 * iv + 2, global_q] = e + + return af + + def arbitrary_func_to_dense(af, seqlen_q, seqlen_k): """Expand arbitrary_func back to dense [B, seqlen_q, seqlen_k] bool mask.""" B, n_func = af.shape[0], af.shape[2] diff --git a/examples/sid_gr/tests/test_jagged_flash_attn_block.py b/examples/sid_gr/tests/test_jagged_flash_attn_block.py index ca89b3fef..2f0e27efa 100644 --- a/examples/sid_gr/tests/test_jagged_flash_attn_block.py +++ b/examples/sid_gr/tests/test_jagged_flash_attn_block.py @@ -37,7 +37,9 @@ from jagged_flash_attn_block import ( JaggedFlashAttnBlock, JaggedGPTLayer, + JaggedTransformerBlock, ) +from attention_mask import build_jagged_causal_arbitrary_func sys.path.pop(0) try: @@ -257,3 +259,164 
@@ def test_backward_gradient_correctness(self): torch.testing.assert_close( x_test.grad, x_ref.grad, atol=5e-2, rtol=5e-2 ) + + +# --------------------------------------------------------------------------- +# JaggedTransformerBlock (B=1 flatten) vs padded reference +# --------------------------------------------------------------------------- +class ReferenceBlock(nn.Module): + """Stack of ReferenceGPTLayers + final layernorm for comparison.""" + + def __init__(self, num_layers, hidden_size, num_heads, ffn_hidden_size, eps=1e-5): + super().__init__() + self.layers = nn.ModuleList( + [ReferenceGPTLayer(hidden_size, num_heads, ffn_hidden_size, eps) + for _ in range(num_layers)] + ) + self.final_layernorm = nn.LayerNorm(hidden_size, eps=eps) + + def forward(self, x, attn_mask=None, is_causal=False): + for layer in self.layers: + x = layer(x, attn_mask=attn_mask, is_causal=is_causal) + return self.final_layernorm(x) + + +def _build_padded_causal_block_diagonal_mask(offsets, max_seqlen): + """Build [B, 1, S, S] causal + padding mask (True = can attend).""" + B = offsets.size(0) - 1 + device = offsets.device + mask = torch.zeros(B, 1, max_seqlen, max_seqlen, dtype=torch.bool, device=device) + for b in range(B): + sl = (offsets[b + 1] - offsets[b]).item() + mask[b, 0, :sl, :sl] = torch.tril( + torch.ones(sl, sl, dtype=torch.bool, device=device) + ) + return mask + + +@pytest.mark.skipif(not HAS_FLASH_ATTN, reason="flash_attn not installed") +class TestJaggedTransformerBlockVsPadded: + """ + Compare the B=1 flattened FA path (JaggedTransformerBlock) against the + padded PyTorch reference to make sure the flatten + arbitrary_func + approach produces correct outputs. + """ + + @pytest.mark.parametrize("num_layers", [1, 2]) + def test_causal_flatten_vs_padded(self, num_layers): + """ + For jagged causal attention, the flattened (B=1) FA path should + produce the same per-sequence output as the padded reference. 
+ """ + H, NH, FFN = 256, 4, 1024 + seq_lens = [5, 8, 3] + offsets = torch.tensor( + [0] + list(torch.cumsum(torch.tensor(seq_lens), 0).tolist()), + device="cuda", + ) + total_tokens = offsets[-1].item() + max_seqlen = max(seq_lens) + B = len(seq_lens) + torch.manual_seed(123) + + ref_block = ReferenceBlock(num_layers, H, NH, FFN).cuda().bfloat16() + test_block = JaggedTransformerBlock( + num_layers=num_layers, + hidden_size=H, + num_attention_heads=NH, + ffn_hidden_size=FFN, + ).cuda().bfloat16() + + # Copy weights: test_block.block.layers[i] ↔ ref_block.layers[i] + for i in range(num_layers): + test_block.block.layers[i].load_state_dict( + ref_block.layers[i].state_dict() + ) + test_block.block.final_layernorm.load_state_dict( + ref_block.final_layernorm.state_dict() + ) + + # Build jagged input [total_tokens, H] + jagged_input = torch.randn( + total_tokens, H, device="cuda", dtype=torch.bfloat16 + ) + + # --- FA path: flatten to B=1, build arbitrary_func --- + arbitrary_func = build_jagged_causal_arbitrary_func(offsets, total_tokens) + with torch.no_grad(): + fa_output = test_block(jagged_input, arbitrary_func=arbitrary_func) + + # --- Reference path: pad each sequence, run per-batch --- + with torch.no_grad(): + for b in range(B): + sl = seq_lens[b] + s = offsets[b].item() + padded_in = torch.zeros( + 1, max_seqlen, H, device="cuda", dtype=torch.bfloat16 + ) + padded_in[0, :sl] = jagged_input[s:s + sl] + ref_out = ref_block(padded_in, is_causal=True) + fa_seq = fa_output[s:s + sl] + ref_seq = ref_out[0, :sl] + torch.testing.assert_close( + fa_seq, ref_seq, atol=5e-2, rtol=5e-2, + msg=f"Mismatch in batch element {b}", + ) + + def test_flatten_output_shape(self): + """Output shape should be [total_tokens, hidden_size].""" + H, NH, FFN = 128, 2, 512 + offsets = torch.tensor([0, 4, 9], device="cuda") + total = 9 + + block = JaggedTransformerBlock( + num_layers=1, hidden_size=H, + num_attention_heads=NH, ffn_hidden_size=FFN, + ).cuda().bfloat16() + + x = 
torch.randn(total, H, device="cuda", dtype=torch.bfloat16) + af = build_jagged_causal_arbitrary_func(offsets, total) + + with torch.no_grad(): + out = block(x, arbitrary_func=af) + + assert out.shape == (total, H) + + def test_batch_isolation(self): + """ + Sequences in different batch elements must not affect each other. + Running [A, B] flattened should give the same per-sequence output + as running [A] and [B] separately. + """ + H, NH, FFN = 256, 4, 1024 + torch.manual_seed(77) + + block = JaggedTransformerBlock( + num_layers=1, hidden_size=H, + num_attention_heads=NH, ffn_hidden_size=FFN, + ).cuda().bfloat16() + + seq_a = torch.randn(5, H, device="cuda", dtype=torch.bfloat16) + seq_b = torch.randn(7, H, device="cuda", dtype=torch.bfloat16) + + # Run combined [A, B] + combined = torch.cat([seq_a, seq_b], dim=0) + offsets_ab = torch.tensor([0, 5, 12], device="cuda") + af_ab = build_jagged_causal_arbitrary_func(offsets_ab, 12) + with torch.no_grad(): + out_ab = block(combined, arbitrary_func=af_ab) + + # Run A alone + offsets_a = torch.tensor([0, 5], device="cuda") + af_a = build_jagged_causal_arbitrary_func(offsets_a, 5) + with torch.no_grad(): + out_a = block(seq_a, arbitrary_func=af_a) + + # Run B alone + offsets_b = torch.tensor([0, 7], device="cuda") + af_b = build_jagged_causal_arbitrary_func(offsets_b, 7) + with torch.no_grad(): + out_b = block(seq_b, arbitrary_func=af_b) + + torch.testing.assert_close(out_ab[:5], out_a, atol=1e-3, rtol=1e-3) + torch.testing.assert_close(out_ab[5:], out_b, atol=1e-3, rtol=1e-3) From 98e61a38270dff739752d75bab8824298bfb2156 Mon Sep 17 00:00:00 2001 From: Runchu Zhao Date: Thu, 2 Apr 2026 09:05:16 +0000 Subject: [PATCH 07/10] Fix precommit errors. 
--- examples/sid_gr/model/gpt_model.py | 23 +-- .../sid_gr/model/jagged_flash_attn_block.py | 60 +++--- examples/sid_gr/model/mcore_model_specs.py | 3 +- .../test_dense_mask_to_arbitrary_func.py | 25 ++- .../tests/test_jagged_flash_attn_block.py | 176 +++++++++++------- 5 files changed, 170 insertions(+), 117 deletions(-) diff --git a/examples/sid_gr/model/gpt_model.py b/examples/sid_gr/model/gpt_model.py index 2fc59fb22..a01b5cceb 100644 --- a/examples/sid_gr/model/gpt_model.py +++ b/examples/sid_gr/model/gpt_model.py @@ -129,9 +129,7 @@ def __init__( hidden_size=self.config.hidden_size, num_attention_heads=self.config.num_attention_heads, ffn_hidden_size=self.config.ffn_hidden_size, - layernorm_epsilon=getattr( - self.config, "layernorm_epsilon", 1e-5 - ), + layernorm_epsilon=getattr(self.config, "layernorm_epsilon", 1e-5), ) else: self.decoder = TransformerBlock( @@ -156,9 +154,8 @@ def forward( ) else: # mcore path: expects dense [S, B, D] input - padded = _padding_to_dense_and_transpose( - hidden_states, offsets, max_seqlen - ) + assert offsets is not None and max_seqlen is not None + padded = _padding_to_dense_and_transpose(hidden_states, offsets, max_seqlen) output = self.decoder( hidden_states=padded, attention_mask=attention_mask, @@ -542,17 +539,17 @@ def decoder_step( * mcore path: pass *attention_mask* (dense ``[B, 1, N, N]``). 
""" if self.decoder.use_jagged_flash_attn: - assert arbitrary_func is not None, ( - "FA path requires arbitrary_func; caller should build it" - ) + assert ( + arbitrary_func is not None + ), "FA path requires arbitrary_func; caller should build it" return self.decoder( hidden_states=input_hidden_states, arbitrary_func=arbitrary_func, ) else: - assert attention_mask is not None, ( - "mcore path requires attention_mask; caller should build it" - ) + assert ( + attention_mask is not None + ), "mcore path requires attention_mask; caller should build it" return self.decoder( hidden_states=input_hidden_states, attention_mask=attention_mask, @@ -642,7 +639,6 @@ def generate( torch.Tensor: The generated sids. """ - attention_mask: Optional[torch.Tensor] = None # 0. prepare history and bos embeddings. Note that we do not append bos to history. ( history_embeddings, @@ -765,4 +761,3 @@ def generate( generated_sids = self.beam_search.get_sids() log_probs = self.beam_search.get_log_probs() return generated_sids, log_probs - diff --git a/examples/sid_gr/model/jagged_flash_attn_block.py b/examples/sid_gr/model/jagged_flash_attn_block.py index 244cc882e..2c569bd08 100644 --- a/examples/sid_gr/model/jagged_flash_attn_block.py +++ b/examples/sid_gr/model/jagged_flash_attn_block.py @@ -26,7 +26,7 @@ Reference: examples/hstu/modules/native_hstu_layer.py """ -from typing import List, Optional, Tuple +from typing import Optional, Tuple import torch import torch.nn as nn @@ -52,7 +52,7 @@ def build_block_sparsity( try: import create_block_mask_cuda from flash_attn.cute.block_sparsity import LinearBlockSparseTensorsTorch - from flash_attn.utils.tile_size import get_tile_sizes_by_backend, get_arch + from flash_attn.utils.tile_size import get_arch, get_tile_sizes_by_backend except ImportError: return None, None @@ -76,15 +76,20 @@ def build_block_sparsity( is_arbitrary=True, ) - (k_cnt, k_off, k_idx, k_fcnt, k_foff, k_fidx) = ( - create_block_mask_cuda.create_q2k_csr_sparse_from_func( - 
arbitrary_func, - seqlen_q, - seqlen_k, - Q_BLOCK_SIZE=fwd_q_block, - KV_BLOCK_SIZE=fwd_kv_block, - check_q_boundary=True, - ) + ( + k_cnt, + k_off, + k_idx, + k_fcnt, + k_foff, + k_fidx, + ) = create_block_mask_cuda.create_q2k_csr_sparse_from_func( + arbitrary_func, + seqlen_q, + seqlen_k, + Q_BLOCK_SIZE=fwd_q_block, + KV_BLOCK_SIZE=fwd_kv_block, + check_q_boundary=True, ) linear_k = LinearBlockSparseTensorsTorch( mask_block_cnt=k_cnt, @@ -95,14 +100,19 @@ def build_block_sparsity( full_block_idx=k_fidx, ) - (q_cnt, q_off, q_idx, q_fcnt, q_foff, q_fidx) = ( - create_block_mask_cuda.create_k2q_csr_sparse_from_func( - arbitrary_func, - seqlen_q, - seqlen_k, - Q_BLOCK_SIZE=bwd_q_block, - KV_BLOCK_SIZE=bwd_kv_block, - ) + ( + q_cnt, + q_off, + q_idx, + q_fcnt, + q_foff, + q_fidx, + ) = create_block_mask_cuda.create_k2q_csr_sparse_from_func( + arbitrary_func, + seqlen_q, + seqlen_k, + Q_BLOCK_SIZE=bwd_q_block, + KV_BLOCK_SIZE=bwd_kv_block, ) linear_q = LinearBlockSparseTensorsTorch( mask_block_cnt=q_cnt, @@ -145,9 +155,7 @@ def __init__( self.ffn_hidden_size = ffn_hidden_size # --- Attention sub-layers --- - self.input_layernorm = nn.LayerNorm( - hidden_size, eps=layernorm_epsilon - ) + self.input_layernorm = nn.LayerNorm(hidden_size, eps=layernorm_epsilon) # Fused QKV projection: hidden_size → 3 * hidden_size self.linear_qkv = nn.Linear(hidden_size, 3 * hidden_size, bias=False) # Output projection after attention @@ -155,9 +163,7 @@ def __init__( self.attn_dropout = nn.Dropout(hidden_dropout) # --- FFN sub-layers --- - self.pre_mlp_layernorm = nn.LayerNorm( - hidden_size, eps=layernorm_epsilon - ) + self.pre_mlp_layernorm = nn.LayerNorm(hidden_size, eps=layernorm_epsilon) self.mlp_fc1 = nn.Linear(hidden_size, ffn_hidden_size, bias=False) self.mlp_fc2 = nn.Linear(ffn_hidden_size, hidden_size, bias=False) self.mlp_dropout = nn.Dropout(hidden_dropout) @@ -293,9 +299,7 @@ def __init__( for _ in range(num_layers) ] ) - self.final_layernorm = nn.LayerNorm( - hidden_size, 
eps=layernorm_epsilon - ) + self.final_layernorm = nn.LayerNorm(hidden_size, eps=layernorm_epsilon) def forward( self, diff --git a/examples/sid_gr/model/mcore_model_specs.py b/examples/sid_gr/model/mcore_model_specs.py index 78ee7bc62..b4ad2ed02 100644 --- a/examples/sid_gr/model/mcore_model_specs.py +++ b/examples/sid_gr/model/mcore_model_specs.py @@ -76,6 +76,7 @@ warnings.warn("Apex is not installed. Falling back to Torch Norm") LNImpl = WrappedTorchNorm + def _get_gpt_layer_with_transformer_engine_spec( num_experts: Optional[int] = None, moe_grouped_gemm: Optional[bool] = False, @@ -374,7 +375,7 @@ def get_gpt_decoder_block_spec( ), "arbitrary attention mask is only supported with Megatron-Core modules" # Layer specs. - # + # dense_layer_spec = ( _get_gpt_layer_with_transformer_engine_spec( num_experts=None, diff --git a/examples/sid_gr/tests/test_dense_mask_to_arbitrary_func.py b/examples/sid_gr/tests/test_dense_mask_to_arbitrary_func.py index acbe364fd..d32ca1882 100644 --- a/examples/sid_gr/tests/test_dense_mask_to_arbitrary_func.py +++ b/examples/sid_gr/tests/test_dense_mask_to_arbitrary_func.py @@ -26,9 +26,9 @@ from attention_mask import ( build_jagged_causal_arbitrary_func, dense_mask_to_arbitrary_func, - padded_causal_mask_with_optional_bos, padded_target_aware_causal_mask, ) + sys.path.pop(0) @@ -138,9 +138,10 @@ def test_block_diagonal(self): @pytest.mark.parametrize("beam_width", [2, 3]) @pytest.mark.parametrize("candidate_len", [1, 3]) def test_target_aware_causal_mask(self, beam_width, candidate_len): - B = 2 hist_lens = torch.tensor([6, 4], device="cuda") - inverted = padded_target_aware_causal_mask(hist_lens, 6, beam_width, candidate_len) + inverted = padded_target_aware_causal_mask( + hist_lens, 6, beam_width, candidate_len + ) valid = ~inverted N = valid.shape[-1] af = dense_mask_to_arbitrary_func(valid, N) @@ -184,7 +185,9 @@ def _build_expected_jagged_causal(offsets): B = offsets.size(0) - 1 for b in range(B): s, e = offsets[b].item(), 
offsets[b + 1].item() - block = torch.tril(torch.ones(e - s, e - s, dtype=torch.bool, device=device)) + block = torch.tril( + torch.ones(e - s, e - s, dtype=torch.bool, device=device) + ) expected[s:e, s:e] = block return expected @@ -217,7 +220,9 @@ def test_dense_to_jagged_causal(self): offsets = torch.tensor([0, 3, 7], device="cuda") B, total = 2, 7 max_seqlen = 4 - per_batch = torch.zeros(B, max_seqlen, max_seqlen, dtype=torch.bool, device="cuda") + per_batch = torch.zeros( + B, max_seqlen, max_seqlen, dtype=torch.bool, device="cuda" + ) for b in range(B): sl = (offsets[b + 1] - offsets[b]).item() per_batch[b, :sl, :sl] = torch.tril( @@ -235,12 +240,14 @@ def test_dense_to_jagged_target_grouped(self, beam_width, candidate_len): B = 2 hist_lens = torch.tensor([5, 3], device="cuda") max_hist = 5 - inverted = padded_target_aware_causal_mask(hist_lens, max_hist, beam_width, candidate_len) + inverted = padded_target_aware_causal_mask( + hist_lens, max_hist, beam_width, candidate_len + ) valid = ~inverted # [B, 1, N, N] - N = valid.shape[-1] + valid.shape[-1] total_per_batch = (hist_lens + beam_width * candidate_len).tolist() offsets = torch.tensor( - [0] + [sum(total_per_batch[:i + 1]) for i in range(B)], + [0] + [sum(total_per_batch[: i + 1]) for i in range(B)], device="cuda", ) total = offsets[-1].item() @@ -253,6 +260,6 @@ def test_dense_to_jagged_target_grouped(self, beam_width, candidate_len): for b in range(B): s = offsets[b].item() sl = total_per_batch[b] - expected[s:s + sl, s:s + sl] = valid_3d[b, :sl, :sl] + expected[s : s + sl, s : s + sl] = valid_3d[b, :sl, :sl] assert torch.equal(expected, recon) diff --git a/examples/sid_gr/tests/test_jagged_flash_attn_block.py b/examples/sid_gr/tests/test_jagged_flash_attn_block.py index 2f0e27efa..d3979409b 100644 --- a/examples/sid_gr/tests/test_jagged_flash_attn_block.py +++ b/examples/sid_gr/tests/test_jagged_flash_attn_block.py @@ -26,24 +26,27 @@ - Backward: gradients flow correctly """ +import os +import sys + 
import pytest import torch import torch.nn as nn import torch.nn.functional as F -import sys -import os sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "model")) +from attention_mask import build_jagged_causal_arbitrary_func from jagged_flash_attn_block import ( JaggedFlashAttnBlock, JaggedGPTLayer, JaggedTransformerBlock, ) -from attention_mask import build_jagged_causal_arbitrary_func + sys.path.pop(0) try: from flash_attn.cute.interface import flash_attn_func # noqa: F401 + HAS_FLASH_ATTN = True except ImportError: HAS_FLASH_ATTN = False @@ -130,12 +133,16 @@ def test_forward_shape(self, batch_size, seqlen, hidden_size, num_heads): def test_block_forward_shape(self): """JaggedFlashAttnBlock stacks layers correctly.""" - block = JaggedFlashAttnBlock( - num_layers=2, - hidden_size=256, - num_attention_heads=4, - ffn_hidden_size=1024, - ).cuda().bfloat16() + block = ( + JaggedFlashAttnBlock( + num_layers=2, + hidden_size=256, + num_attention_heads=4, + ffn_hidden_size=1024, + ) + .cuda() + .bfloat16() + ) x = torch.randn(2, 32, 256, device="cuda", dtype=torch.bfloat16) out = block(x) @@ -157,13 +164,19 @@ def test_causal_matches_reference(self, hidden_size, num_heads, seqlen): ffn_size = hidden_size * 4 torch.manual_seed(42) - ref_layer = ReferenceGPTLayer(hidden_size, num_heads, ffn_size).cuda().bfloat16() - test_layer = JaggedGPTLayer( - hidden_size=hidden_size, - num_attention_heads=num_heads, - ffn_hidden_size=ffn_size, - hidden_dropout=0.0, - ).cuda().bfloat16() + ref_layer = ( + ReferenceGPTLayer(hidden_size, num_heads, ffn_size).cuda().bfloat16() + ) + test_layer = ( + JaggedGPTLayer( + hidden_size=hidden_size, + num_attention_heads=num_heads, + ffn_hidden_size=ffn_size, + hidden_dropout=0.0, + ) + .cuda() + .bfloat16() + ) test_layer.load_state_dict(ref_layer.state_dict()) @@ -185,12 +198,16 @@ def test_arbitrary_causal_matches_standard_causal(self, hidden_size, num_heads): ffn_size = hidden_size * 4 torch.manual_seed(42) - layer = 
JaggedGPTLayer( - hidden_size=hidden_size, - num_attention_heads=num_heads, - ffn_hidden_size=ffn_size, - hidden_dropout=0.0, - ).cuda().bfloat16() + layer = ( + JaggedGPTLayer( + hidden_size=hidden_size, + num_attention_heads=num_heads, + ffn_hidden_size=ffn_size, + hidden_dropout=0.0, + ) + .cuda() + .bfloat16() + ) x = torch.randn(B, S, hidden_size, device="cuda", dtype=torch.bfloat16) @@ -213,14 +230,20 @@ class TestJaggedGPTLayerBackward: def test_backward_runs(self): """Forward + backward should not error.""" - layer = JaggedGPTLayer( - hidden_size=256, - num_attention_heads=4, - ffn_hidden_size=1024, - hidden_dropout=0.0, - ).cuda().bfloat16() + layer = ( + JaggedGPTLayer( + hidden_size=256, + num_attention_heads=4, + ffn_hidden_size=1024, + hidden_dropout=0.0, + ) + .cuda() + .bfloat16() + ) - x = torch.randn(2, 16, 256, device="cuda", dtype=torch.bfloat16, requires_grad=True) + x = torch.randn( + 2, 16, 256, device="cuda", dtype=torch.bfloat16, requires_grad=True + ) out = layer(x) loss = out.sum() loss.backward() @@ -237,16 +260,24 @@ def test_backward_gradient_correctness(self): B, S = 2, 16 torch.manual_seed(42) - ref_layer = ReferenceGPTLayer(hidden_size, num_heads, ffn_size).cuda().bfloat16() - test_layer = JaggedGPTLayer( - hidden_size=hidden_size, - num_attention_heads=num_heads, - ffn_hidden_size=ffn_size, - hidden_dropout=0.0, - ).cuda().bfloat16() + ref_layer = ( + ReferenceGPTLayer(hidden_size, num_heads, ffn_size).cuda().bfloat16() + ) + test_layer = ( + JaggedGPTLayer( + hidden_size=hidden_size, + num_attention_heads=num_heads, + ffn_hidden_size=ffn_size, + hidden_dropout=0.0, + ) + .cuda() + .bfloat16() + ) test_layer.load_state_dict(ref_layer.state_dict()) - x_ref = torch.randn(B, S, hidden_size, device="cuda", dtype=torch.bfloat16, requires_grad=True) + x_ref = torch.randn( + B, S, hidden_size, device="cuda", dtype=torch.bfloat16, requires_grad=True + ) x_test = x_ref.detach().clone().requires_grad_(True) ref_out = ref_layer(x_ref, 
is_causal=True) @@ -256,9 +287,7 @@ def test_backward_gradient_correctness(self): ref_out.backward(dout) test_out.backward(dout) - torch.testing.assert_close( - x_test.grad, x_ref.grad, atol=5e-2, rtol=5e-2 - ) + torch.testing.assert_close(x_test.grad, x_ref.grad, atol=5e-2, rtol=5e-2) # --------------------------------------------------------------------------- @@ -270,8 +299,10 @@ class ReferenceBlock(nn.Module): def __init__(self, num_layers, hidden_size, num_heads, ffn_hidden_size, eps=1e-5): super().__init__() self.layers = nn.ModuleList( - [ReferenceGPTLayer(hidden_size, num_heads, ffn_hidden_size, eps) - for _ in range(num_layers)] + [ + ReferenceGPTLayer(hidden_size, num_heads, ffn_hidden_size, eps) + for _ in range(num_layers) + ] ) self.final_layernorm = nn.LayerNorm(hidden_size, eps=eps) @@ -320,26 +351,26 @@ def test_causal_flatten_vs_padded(self, num_layers): torch.manual_seed(123) ref_block = ReferenceBlock(num_layers, H, NH, FFN).cuda().bfloat16() - test_block = JaggedTransformerBlock( - num_layers=num_layers, - hidden_size=H, - num_attention_heads=NH, - ffn_hidden_size=FFN, - ).cuda().bfloat16() + test_block = ( + JaggedTransformerBlock( + num_layers=num_layers, + hidden_size=H, + num_attention_heads=NH, + ffn_hidden_size=FFN, + ) + .cuda() + .bfloat16() + ) # Copy weights: test_block.block.layers[i] ↔ ref_block.layers[i] for i in range(num_layers): - test_block.block.layers[i].load_state_dict( - ref_block.layers[i].state_dict() - ) + test_block.block.layers[i].load_state_dict(ref_block.layers[i].state_dict()) test_block.block.final_layernorm.load_state_dict( ref_block.final_layernorm.state_dict() ) # Build jagged input [total_tokens, H] - jagged_input = torch.randn( - total_tokens, H, device="cuda", dtype=torch.bfloat16 - ) + jagged_input = torch.randn(total_tokens, H, device="cuda", dtype=torch.bfloat16) # --- FA path: flatten to B=1, build arbitrary_func --- arbitrary_func = build_jagged_causal_arbitrary_func(offsets, total_tokens) @@ -354,12 
+385,15 @@ def test_causal_flatten_vs_padded(self, num_layers): padded_in = torch.zeros( 1, max_seqlen, H, device="cuda", dtype=torch.bfloat16 ) - padded_in[0, :sl] = jagged_input[s:s + sl] + padded_in[0, :sl] = jagged_input[s : s + sl] ref_out = ref_block(padded_in, is_causal=True) - fa_seq = fa_output[s:s + sl] + fa_seq = fa_output[s : s + sl] ref_seq = ref_out[0, :sl] torch.testing.assert_close( - fa_seq, ref_seq, atol=5e-2, rtol=5e-2, + fa_seq, + ref_seq, + atol=5e-2, + rtol=5e-2, msg=f"Mismatch in batch element {b}", ) @@ -369,10 +403,16 @@ def test_flatten_output_shape(self): offsets = torch.tensor([0, 4, 9], device="cuda") total = 9 - block = JaggedTransformerBlock( - num_layers=1, hidden_size=H, - num_attention_heads=NH, ffn_hidden_size=FFN, - ).cuda().bfloat16() + block = ( + JaggedTransformerBlock( + num_layers=1, + hidden_size=H, + num_attention_heads=NH, + ffn_hidden_size=FFN, + ) + .cuda() + .bfloat16() + ) x = torch.randn(total, H, device="cuda", dtype=torch.bfloat16) af = build_jagged_causal_arbitrary_func(offsets, total) @@ -391,10 +431,16 @@ def test_batch_isolation(self): H, NH, FFN = 256, 4, 1024 torch.manual_seed(77) - block = JaggedTransformerBlock( - num_layers=1, hidden_size=H, - num_attention_heads=NH, ffn_hidden_size=FFN, - ).cuda().bfloat16() + block = ( + JaggedTransformerBlock( + num_layers=1, + hidden_size=H, + num_attention_heads=NH, + ffn_hidden_size=FFN, + ) + .cuda() + .bfloat16() + ) seq_a = torch.randn(5, H, device="cuda", dtype=torch.bfloat16) seq_b = torch.randn(7, H, device="cuda", dtype=torch.bfloat16) From ad536dd5ed3effa94485c1f7e975eb8c3b7a88a6 Mon Sep 17 00:00:00 2001 From: Runchu Zhao Date: Tue, 7 Apr 2026 02:45:18 +0000 Subject: [PATCH 08/10] update dockerfile and readme.md. 
--- docker/Dockerfile | 4 ++++ examples/sid_gr/training/README.md | 20 ++++++++++++++------ 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index 772ea3ca0..90a702046 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -46,6 +46,10 @@ RUN pip install --no-deps tensordict orjson && \ RUN pip install nvidia-cutlass-dsl==4.3.0 +RUN git clone -b arbitrary_mask https://github.com/jiayus-nvidia/flash-attention.git flash-attention && \ + cd flash-attention && \ + pip install --no-deps -e . + # for dev RUN apt update -y --fix-missing && \ diff --git a/examples/sid_gr/training/README.md b/examples/sid_gr/training/README.md index 0261f1e1e..cb9639e4d 100644 --- a/examples/sid_gr/training/README.md +++ b/examples/sid_gr/training/README.md @@ -75,12 +75,20 @@ PYTHONPATH=${PYTHONPATH}:$(realpath ../) torchrun \ **Note:** Ensure your current working directory is `examples/sid_gr` before running the command. -## Known Limitations +## Attention Backend + +The decoder supports two attention backends, controlled by `use_jagged_flash_attn`: + +| Backend | Flag | Input Format | Mask Format | Dependency | +|---------|------|-------------|-------------|------------| +| Megatron-Core `TransformerBlock` | `False` | Padded dense `[S, B, D]` | Dense `[B, 1, N, N]` | megatron-core | +| `JaggedTransformerBlock` (FA) | `True` | Flattened `[1, total_tokens, D]` (zero padding) | `arbitrary_func` interval encoding | [jiayus/flash-attention (arbitrary_mask branch)](https://github.com/jiayus-nvidia/flash-attention/tree/arbitrary_mask) | -⚠️ **This implementation is under active development.** The current version has not been fully optimized for performance. Known limitations include: +The FA backend flattens all batch sequences into a single sequence (B=1) and encodes the attention pattern via `arbitrary_func`. Block sparsity skips masked regions automatically. 
The caller is responsible for building the `arbitrary_func` or `attention_mask` tensor. + +## Known Limitations -- **Attention mechanism**: Currently using padded local SDPA (Scaled Dot-Product Attention) implementation in Megatron-Core with explicit attention masks -- **Beam search**: The beam search used during evaluation does not yet support KV cache optimization -- **Performance**: The model performance has not reached optimal levels +**This implementation is under active development.** -We are actively working on addressing these limitations and improving overall efficiency. \ No newline at end of file +- **Beam search**: Does not yet support KV cache optimization +- **Performance**: Not fully optimized \ No newline at end of file From e3421db3f20b23b5ead291df7fe6d9aba990c33f Mon Sep 17 00:00:00 2001 From: Runchu Zhao Date: Tue, 7 Apr 2026 07:09:08 +0000 Subject: [PATCH 09/10] Fix use_jagged_flash_attn default value. --- examples/sid_gr/model/__init__.py | 2 +- examples/sid_gr/model/gpt_model.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/sid_gr/model/__init__.py b/examples/sid_gr/model/__init__.py index 3fb4bb5f6..a96982609 100644 --- a/examples/sid_gr/model/__init__.py +++ b/examples/sid_gr/model/__init__.py @@ -18,7 +18,7 @@ def get_sid_gr_model( top_k_for_generation: int = 10, eval_metrics: Tuple[str, ...] 
= (), share_lm_head_across_hierarchies: bool = True, - use_jagged_flash_attn: bool = True, + use_jagged_flash_attn: bool = False, ) -> SIDGRModel: sid_gr_model = SIDGRModel( decoder_config=decoder_config, diff --git a/examples/sid_gr/model/gpt_model.py b/examples/sid_gr/model/gpt_model.py index a01b5cceb..bfeccb661 100644 --- a/examples/sid_gr/model/gpt_model.py +++ b/examples/sid_gr/model/gpt_model.py @@ -105,7 +105,7 @@ def __init__( ] = "learned_absolute", relative_attention_num_buckets: int = 32, relative_attention_max_distance: int = 128, - use_jagged_flash_attn: bool = True, + use_jagged_flash_attn: bool = False, ): super().__init__(config=decoder_config) @@ -185,7 +185,7 @@ def __init__( top_k_for_generation: int = 10, # this is used for eval eval_metrics: Tuple[str, ...] = (), # this is used for eval share_lm_head_across_hierarchies: bool = True, - use_jagged_flash_attn: bool = True, + use_jagged_flash_attn: bool = False, ): super(SIDGRModel, self).__init__(config=decoder_config) assert ( From a796b1888f40c0912d375e330e75d92305b5cc03 Mon Sep 17 00:00:00 2001 From: Runchu Zhao Date: Wed, 8 Apr 2026 03:15:59 +0000 Subject: [PATCH 10/10] Fix some tests bugs. 
--- examples/sid_gr/model/gpt_model.py | 65 ++++++++++++++++------- examples/sid_gr/tests/test_model_smoke.py | 2 + 2 files changed, 48 insertions(+), 19 deletions(-) diff --git a/examples/sid_gr/model/gpt_model.py b/examples/sid_gr/model/gpt_model.py index bfeccb661..da9f7a4fc 100644 --- a/examples/sid_gr/model/gpt_model.py +++ b/examples/sid_gr/model/gpt_model.py @@ -35,6 +35,10 @@ from modules.gpt_loss_module import GPTSIDLossModule from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor +from .attention_mask import ( + build_jagged_causal_arbitrary_func, + padded_causal_mask_with_optional_bos, +) from .jagged_flash_attn_block import JaggedTransformerBlock @@ -530,32 +534,55 @@ def decoder_step( input_max_seqlen: int, attention_mask: Optional[torch.Tensor] = None, arbitrary_func: Optional[torch.Tensor] = None, + *, + default_mask_add_bos_to_history: Optional[bool] = None, ) -> torch.Tensor: """ - Input and Output are both jagged. This method only routes — the - caller is responsible for building the mask / arbitrary_func. + Jagged in/out. Routes Megatron-Core vs jagged Flash Attention. + + If both ``attention_mask`` and ``arbitrary_func`` are omitted, builds + the usual causal mask for the active backend (same behavior as before + the explicit-mask refactor). Pass either tensor to override. - * FA path: pass *arbitrary_func* (caller-built). - * mcore path: pass *attention_mask* (dense ``[B, 1, N, N]``). + * FA path: ``arbitrary_func`` (B=1 flattened arbitrary mask). + * mcore path: ``attention_mask`` (dense ``[B, 1, N, N]``). + + ``default_mask_add_bos_to_history`` controls optional-BOS layout for the + built mcore mask: ``None`` uses ``self.add_bos_to_history_for_training`` + (training-style); ``generate`` passes ``False``. 
""" + if attention_mask is None and arbitrary_func is None: + if self.decoder.use_jagged_flash_attn: + total_tokens = int(input_offsets[-1].item()) + arbitrary_func = build_jagged_causal_arbitrary_func( + input_offsets, total_tokens + ) + else: + add_bos = ( + self.add_bos_to_history_for_training + if default_mask_add_bos_to_history is None + else default_mask_add_bos_to_history + ) + attention_mask = padded_causal_mask_with_optional_bos( + input_offsets, + input_max_seqlen, + add_bos_to_history=add_bos, + bos_interval=self._num_hierarchies, + ) + if self.decoder.use_jagged_flash_attn: - assert ( - arbitrary_func is not None - ), "FA path requires arbitrary_func; caller should build it" + assert arbitrary_func is not None return self.decoder( hidden_states=input_hidden_states, arbitrary_func=arbitrary_func, ) - else: - assert ( - attention_mask is not None - ), "mcore path requires attention_mask; caller should build it" - return self.decoder( - hidden_states=input_hidden_states, - attention_mask=attention_mask, - offsets=input_offsets, - max_seqlen=input_max_seqlen, - ) + assert attention_mask is not None + return self.decoder( + hidden_states=input_hidden_states, + attention_mask=attention_mask, + offsets=input_offsets, + max_seqlen=input_max_seqlen, + ) def forward( self, @@ -575,7 +602,6 @@ def forward( ) history_offsets = batch.features[batch.history_feature_name].offsets() - # 2. decoder step — caller provides the mask / arbitrary_func jagged_output_hidden_states = self.decoder_step( input_hidden_states, input_offsets, @@ -719,13 +745,14 @@ def generate( dtype=input_offsets.dtype, ) - # 2. decoder step — caller provides attention_mask or arbitrary_func + # 2. 
decoder (mask built in decoder_step when not overridden) jagged_output_hidden_states = self.decoder_step( cated_hidden_states, cated_offsets, cated_max_seqlen, attention_mask=attention_mask, arbitrary_func=arbitrary_func, + default_mask_add_bos_to_history=False, ) # remove history[batchsize * topk_last_step * max(1,i), embedding_dim] _, candidate_hidden_states = triton_split_2D_jagged( diff --git a/examples/sid_gr/tests/test_model_smoke.py b/examples/sid_gr/tests/test_model_smoke.py index d9b7bee7a..8aeb294ce 100644 --- a/examples/sid_gr/tests/test_model_smoke.py +++ b/examples/sid_gr/tests/test_model_smoke.py @@ -222,6 +222,7 @@ def test_model_decoder_step( input_hidden_states.view(-1, hidden_size), input_offsets, input_max_seqlen, + default_mask_add_bos_to_history=False, ).view(batch_size, input_max_seqlen, -1) candidates_logits = output.view(batch_size, input_max_seqlen, -1)[ :, history_max_seqlen:, : @@ -264,6 +265,7 @@ def test_model_decoder_step( prefix_input_hidden_states.view(-1, hidden_size), prefix_input_offsets, prefix_input_max_seqlen, + default_mask_add_bos_to_history=False, ).view(batch_size, prefix_input_max_seqlen, -1) prefix_candidates_logits = prefix_output.view( batch_size, prefix_input_max_seqlen, -1