diff --git a/docker/Dockerfile b/docker/Dockerfile index 84391219e..0b78326b0 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -53,6 +53,10 @@ RUN pip install --no-deps tensordict orjson && \ cd torchrec && \ pip install --no-deps . +RUN git clone -b arbitrary_mask https://github.com/jiayus-nvidia/flash-attention.git flash-attention && \ + cd flash-attention && \ + pip install --no-deps -e . + # Install fbgemm_gpu_hstu (package: fbgemm_gpu_hstu, import: hstu) from submodule COPY third_party/FBGEMM /workspace/deps/fbgemm_hstu RUN cd /workspace/deps/fbgemm_hstu/fbgemm_gpu/experimental/hstu && \ diff --git a/examples/sid_gr/beam_search/beam_search.py b/examples/sid_gr/beam_search/beam_search.py index 72274ff5b..e22b4bf43 100644 --- a/examples/sid_gr/beam_search/beam_search.py +++ b/examples/sid_gr/beam_search/beam_search.py @@ -45,6 +45,7 @@ def __init__( self.history_topk_sids: List[torch.Tensor] = [] self.history_accumulate_topk_probs: List[torch.Tensor] = [] self.history_probs: List[torch.Tensor] = [] + self.reset() def propagate( diff --git a/examples/sid_gr/model/__init__.py b/examples/sid_gr/model/__init__.py index 46499783e..a96982609 100644 --- a/examples/sid_gr/model/__init__.py +++ b/examples/sid_gr/model/__init__.py @@ -18,6 +18,7 @@ def get_sid_gr_model( top_k_for_generation: int = 10, eval_metrics: Tuple[str, ...] 
= (), share_lm_head_across_hierarchies: bool = True, + use_jagged_flash_attn: bool = False, ) -> SIDGRModel: sid_gr_model = SIDGRModel( decoder_config=decoder_config, @@ -25,7 +26,6 @@ def get_sid_gr_model( codebook_sizes=codebook_sizes, num_hierarchies=num_hierarchies, transformer_decoder_layer_spec=get_gpt_decoder_block_spec( - # padding + arbitrary attention mask + Megatron-Core decoder_config, use_transformer_engine=False, arbitrary_attention_mask=True, @@ -35,6 +35,7 @@ def get_sid_gr_model( top_k_for_generation=top_k_for_generation, eval_metrics=eval_metrics, share_lm_head_across_hierarchies=share_lm_head_across_hierarchies, + use_jagged_flash_attn=use_jagged_flash_attn, ) return sid_gr_model diff --git a/examples/sid_gr/model/attention_mask.py b/examples/sid_gr/model/attention_mask.py index 63234f6df..17a63d60c 100644 --- a/examples/sid_gr/model/attention_mask.py +++ b/examples/sid_gr/model/attention_mask.py @@ -207,6 +207,120 @@ def padded_target_aware_causal_mask( return ~valid_attn_mask +def dense_mask_to_arbitrary_func( + valid_mask: torch.Tensor, + seqlen: int, + padding: int = 256, +) -> torch.Tensor: + """ + Convert a dense bool attention mask to flash_attn's interval-based + arbitrary_func tensor. + + For each query position q, the arbitrary_func encodes visible key + positions as a union of intervals: + visible(q) = [0, F0) ∪ [F1, F2) ∪ [F3, F4) ∪ ... + + Args: + valid_mask: [B, N, N] or [B, 1, N, N] bool tensor (True = can attend). + seqlen: sequence length N. + padding: extra padding on last dim (FA convention, default 256). + + Returns: + arbitrary_func: [B, 1, n_func, seqlen + padding] int32 tensor. 
+ """ + if valid_mask.dim() == 4: + valid_mask = valid_mask.squeeze(1) + assert valid_mask.dim() == 3, f"Expected [B, N, N], got {valid_mask.shape}" + + B, N, _ = valid_mask.shape + device = valid_mask.device + + # Detect interval boundaries via transitions + shifted = torch.zeros_like(valid_mask) + shifted[:, :, 1:] = valid_mask[:, :, :-1] + starts = valid_mask & ~shifted # start of each True run + max_intervals = int(starts.sum(dim=-1).max().item()) + n_func = max(2 * max_intervals - 1, 1) + if n_func % 2 == 0: + n_func += 1 + + # When first interval doesn't start at 0, it needs an extra slot. + # Recount: base interval [0, F0) is free only if first run starts at 0. + # Worst case: all intervals need explicit [F_start, F_end) pairs. + # n_func = 2*max_intervals + 1 covers all cases. + n_func = 2 * max_intervals + 1 + if n_func % 2 == 0: + n_func += 1 + + af = torch.zeros(B, 1, n_func, seqlen + padding, dtype=torch.int32, device=device) + + ends_shifted = torch.zeros_like(valid_mask) + ends_shifted[:, :, :-1] = valid_mask[:, :, 1:] + ends = valid_mask & ~ends_shifted # last True position in each run + + for b in range(B): + for q in range(N): + row = valid_mask[b, q] + if not row.any(): + continue + start_pos = starts[b, q].nonzero(as_tuple=False).squeeze(-1) + end_pos = ends[b, q].nonzero(as_tuple=False).squeeze(-1) + 1 + + # F0 encodes [0, F0). If first interval starts at 0, use F0. + # Otherwise F0 stays 0 (empty base interval) and all intervals + # go into the extra slots. + extra_idx = 0 + for iv in range(len(start_pos)): + s, e = start_pos[iv].item(), end_pos[iv].item() + if iv == 0 and s == 0: + af[b, 0, 0, q] = e + else: + af[b, 0, 2 * extra_idx + 1, q] = s + af[b, 0, 2 * extra_idx + 2, q] = e + extra_idx += 1 + + return af + + +def build_jagged_causal_arbitrary_func( + offsets: torch.Tensor, + total_tokens: int, + padding: int = 256, +) -> torch.Tensor: + """ + Build arbitrary_func for flattened jagged causal attention (B=1). 
+ + All batch sequences are concatenated into a single sequence of length + *total_tokens*. Each query at global position *q* in batch element *b* + can attend to keys in ``[offset[b], q+1)`` — standard causal within its + own sequence, invisible to other sequences. + + Args: + offsets: [B+1] cumulative sequence-length offsets. + total_tokens: ``offsets[-1].item()`` — total number of tokens. + padding: FA convention padding on the last dim (default 256). + + Returns: + arbitrary_func: [1, 1, 3, total_tokens + padding] int32 tensor. + """ + device = offsets.device + n_func = 3 # F0=0, single interval [F1, F2) + + af = torch.zeros( + 1, 1, n_func, total_tokens + padding, dtype=torch.int32, device=device + ) + + positions = torch.arange(total_tokens, device=device) + batch_ids = torch.searchsorted(offsets[1:], positions, right=True) + batch_starts = offsets[batch_ids] + + # visible(q) = [0, 0) ∪ [batch_start, q+1) + af[0, 0, 1, :total_tokens] = batch_starts.to(torch.int32) + af[0, 0, 2, :total_tokens] = (positions + 1).to(torch.int32) + + return af + + if __name__ == "__main__": history_seqlen = torch.tensor([4, 3]).cuda() max_history_seqlen = 6 diff --git a/examples/sid_gr/model/gpt_model.py b/examples/sid_gr/model/gpt_model.py index a59f555ba..da9f7a4fc 100644 --- a/examples/sid_gr/model/gpt_model.py +++ b/examples/sid_gr/model/gpt_model.py @@ -16,7 +16,7 @@ import torch from beam_search.beam_search import BeamSearch -from commons.datasets.gpt_sid_batch import GPTSIDBatch, to_packed_seq_params +from commons.datasets.gpt_sid_batch import GPTSIDBatch from commons.modules.embedding import ShardedEmbedding, ShardedEmbeddingConfig from commons.ops.cuda_ops.JaggedTensorOpFunction import jagged_2D_tensor_concat from commons.ops.length_to_offsets import length_to_complete_offsets @@ -27,7 +27,6 @@ from megatron.core.models.common.embeddings.relative_pos_embedding import ( RelativePositionEmbedding, ) -from megatron.core.packed_seq_params import PackedSeqParams from 
megatron.core.transformer import TransformerConfig from megatron.core.transformer.module import MegatronModule from megatron.core.transformer.spec_utils import ModuleSpec @@ -37,9 +36,10 @@ from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor from .attention_mask import ( + build_jagged_causal_arbitrary_func, padded_causal_mask_with_optional_bos, - padded_target_aware_causal_mask, ) +from .jagged_flash_attn_block import JaggedTransformerBlock def _padding_to_dense_and_transpose( @@ -89,7 +89,15 @@ def _transpose_dense_to_jagged( class SIDGRDecoder(MegatronModule): """ - Don't support PP currently. Does not inclu de embedding + Don't support PP currently. Does not include embedding. + + Supports two backend modes controlled by *use_jagged_flash_attn*: + + * **True** (default) — ``JaggedTransformerBlock``: flattens all batch + sequences into one (B=1) and uses jiayus's Flash Attention with + arbitrary_func mask encoding. Zero padding. + * **False** — Megatron-Core ``TransformerBlock``: pads jagged to dense, + uses ``DotProductAttention`` with a dense arbitrary attention mask. 
""" def __init__( @@ -101,15 +109,16 @@ def __init__( ] = "learned_absolute", relative_attention_num_buckets: int = 32, relative_attention_max_distance: int = 128, + use_jagged_flash_attn: bool = False, ): super().__init__(config=decoder_config) self.config: TransformerConfig = decoder_config self.transformer_decoder_layer_spec: ModuleSpec = transformer_decoder_layer_spec - # TODO, add position encoder self.model_type = ModelType.encoder_or_decoder self.position_embedding_type = position_embedding_type + self.use_jagged_flash_attn = use_jagged_flash_attn self.decoder_relative_pos_emb = RelativePositionEmbedding( bidirectional=False, init_method=self.config.init_method, @@ -117,36 +126,45 @@ def __init__( relative_attention_num_buckets=relative_attention_num_buckets, relative_attention_max_distance=relative_attention_max_distance, ) - self.decoder = TransformerBlock( - config=self.config, - spec=self.transformer_decoder_layer_spec, - ) + + if use_jagged_flash_attn: + self.decoder = JaggedTransformerBlock( + num_layers=self.config.num_layers, + hidden_size=self.config.hidden_size, + num_attention_heads=self.config.num_attention_heads, + ffn_hidden_size=self.config.ffn_hidden_size, + layernorm_epsilon=getattr(self.config, "layernorm_epsilon", 1e-5), + ) + else: + self.decoder = TransformerBlock( + config=self.config, + spec=self.transformer_decoder_layer_spec, + ) def forward( self, - hidden_states, - attention_mask: Optional[ - torch.Tensor - ] = None, # decoder attention mask, always causal + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, *, - packed_seq_params: Optional[PackedSeqParams] = None, + offsets: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + arbitrary_func: Optional[torch.Tensor] = None, **kwargs: Any, ) -> torch.Tensor: - attention_bias = None - # if self.position_embedding_type == 'relative': - # # attention bias is supported by cudnn, but not fa. 
- # # TODO@junzhang add jagged support once we have attention kernels - # query_seq_length = input_max_seqlen - # key_seq_length = query_seq_length - # attention_bias = self.decoder_relative_pos_emb(query_seq_length, key_seq_length) - output = self.decoder( - hidden_states=hidden_states, # query - attention_mask=attention_mask, # attention mask - packed_seq_params=packed_seq_params, # query and kv seqlens - attention_bias=attention_bias, - **kwargs, - ) - return output + if self.use_jagged_flash_attn: + return self.decoder( + hidden_states=hidden_states, + arbitrary_func=arbitrary_func, + ) + else: + # mcore path: expects dense [S, B, D] input + assert offsets is not None and max_seqlen is not None + padded = _padding_to_dense_and_transpose(hidden_states, offsets, max_seqlen) + output = self.decoder( + hidden_states=padded, + attention_mask=attention_mask, + ) + return _transpose_dense_to_jagged(output, offsets, max_seqlen) class SIDGRModel(MegatronModule): @@ -171,6 +189,7 @@ def __init__( top_k_for_generation: int = 10, # this is used for eval eval_metrics: Tuple[str, ...] = (), # this is used for eval share_lm_head_across_hierarchies: bool = True, + use_jagged_flash_attn: bool = False, ): super(SIDGRModel, self).__init__(config=decoder_config) assert ( @@ -199,6 +218,7 @@ def __init__( decoder_config, transformer_decoder_layer_spec, position_embedding_type="relative", + use_jagged_flash_attn=use_jagged_flash_attn, ) self.codebook_sizes = codebook_sizes assert codebook_embedding_config.vocab_size >= sum( @@ -513,66 +533,62 @@ def decoder_step( input_offsets: torch.Tensor, input_max_seqlen: int, attention_mask: Optional[torch.Tensor] = None, - padding_to_dense: bool = True, - add_bos_to_history: bool = False, + arbitrary_func: Optional[torch.Tensor] = None, + *, + default_mask_add_bos_to_history: Optional[bool] = None, ) -> torch.Tensor: """ - Input and Output are both jagged. - attention_mask is used only when padding_to_dense is True. 
- When attention mask is None, we will construct a causal attention mask if padding_to_dense is True. + Jagged in/out. Routes Megatron-Core vs jagged Flash Attention. + + If both ``attention_mask`` and ``arbitrary_func`` are omitted, builds + the usual causal mask for the active backend (same behavior as before + the explicit-mask refactor). Pass either tensor to override. - We now only support dense input. + * FA path: ``arbitrary_func`` (B=1 flattened arbitrary mask). + * mcore path: ``attention_mask`` (dense ``[B, 1, N, N]``). + + ``default_mask_add_bos_to_history`` controls optional-BOS layout for the + built mcore mask: ``None`` uses ``self.add_bos_to_history_for_training`` + (training-style); ``generate`` passes ``False``. """ - if add_bos_to_history: - assert ( - attention_mask is None - ), "attention mask should be None when adding bos to history" - # TODO, remove the padding. - input_offsets[-1].item() - if padding_to_dense: - decoder_input_hidden_states = _padding_to_dense_and_transpose( - input_hidden_states, - input_offsets, - input_max_seqlen, - ) - packed_seq_params = None - if attention_mask is None: + if attention_mask is None and arbitrary_func is None: + if self.decoder.use_jagged_flash_attn: + total_tokens = int(input_offsets[-1].item()) + arbitrary_func = build_jagged_causal_arbitrary_func( + input_offsets, total_tokens + ) + else: + add_bos = ( + self.add_bos_to_history_for_training + if default_mask_add_bos_to_history is None + else default_mask_add_bos_to_history + ) attention_mask = padded_causal_mask_with_optional_bos( input_offsets, input_max_seqlen, - add_bos_to_history=add_bos_to_history, + add_bos_to_history=add_bos, bos_interval=self._num_hierarchies, ) - else: - # THD still needs batch dimension - # we need to unsqueeze the hidden states to [T, 1, hidden_size] and unsqueeze back after decoder - assert input_hidden_states.dim() == 2, "input_hidden_states should be 2D" - decoder_input_hidden_states = input_hidden_states.unsqueeze(1) - 
attention_mask = None - packed_seq_params = to_packed_seq_params( - input_offsets, - input_max_seqlen, + + if self.decoder.use_jagged_flash_attn: + assert arbitrary_func is not None + return self.decoder( + hidden_states=input_hidden_states, + arbitrary_func=arbitrary_func, ) - decoder_output_hidden_states = self.decoder( - hidden_states=decoder_input_hidden_states, # input_hidden_states, + assert attention_mask is not None + return self.decoder( + hidden_states=input_hidden_states, attention_mask=attention_mask, - packed_seq_params=packed_seq_params, # we now enforce arbitrary attention mask + dense padding + offsets=input_offsets, + max_seqlen=input_max_seqlen, ) - if padding_to_dense: - output_hidden_states = _transpose_dense_to_jagged( - decoder_output_hidden_states, - input_offsets, - input_max_seqlen, - ) - else: - # remove batch dim if THD - output_hidden_states = decoder_output_hidden_states.squeeze(1) - return output_hidden_states - def forward( self, batch: GPTSIDBatch, + attention_mask: Optional[torch.Tensor] = None, + arbitrary_func: Optional[torch.Tensor] = None, ) -> torch.Tensor: # 1. prepare embeddings: embedding lookup + history, bos and candidate concat ( @@ -586,13 +602,12 @@ def forward( ) history_offsets = batch.features[batch.history_feature_name].offsets() - # 2. decoder step jagged_output_hidden_states = self.decoder_step( input_hidden_states, input_offsets, input_max_seqlen, - attention_mask=None, - add_bos_to_history=self.add_bos_to_history_for_training, + attention_mask=attention_mask, + arbitrary_func=arbitrary_func, ) # 3. 
postprocess: only keep the candidate hidden states candidate_hidden_states = self._postprocess_output( @@ -636,7 +651,12 @@ def forward( return merged_losses, merged_logits @torch.no_grad - def generate(self, batch: GPTSIDBatch) -> torch.Tensor: + def generate( + self, + batch: GPTSIDBatch, + attention_mask: Optional[torch.Tensor] = None, + arbitrary_func: Optional[torch.Tensor] = None, + ) -> torch.Tensor: """ Generate the output sids for the given batch. The generation will autogressively generate the output sids with a constrained fixed-width beam search strategy. Args: @@ -645,7 +665,6 @@ def generate(self, batch: GPTSIDBatch) -> torch.Tensor: torch.Tensor: The generated sids. """ - attention_mask: Optional[torch.Tensor] = None # 0. prepare history and bos embeddings. Note that we do not append bos to history. ( history_embeddings, @@ -726,21 +745,14 @@ def generate(self, batch: GPTSIDBatch) -> torch.Tensor: dtype=input_offsets.dtype, ) - # 2. prepare the attention mask - attention_mask = padded_target_aware_causal_mask( - torch.diff(input_offsets), - input_max_seqlen, - 0 if i == 0 else topk_prev_step, - candidate_length, - ) - # 3. we need a decoder step with the concatenated hidden states and offsets. Note that we do not add bos to history for generation. + # 2. 
decoder (mask built in decoder_step when not overridden) jagged_output_hidden_states = self.decoder_step( cated_hidden_states, cated_offsets, cated_max_seqlen, attention_mask=attention_mask, - padding_to_dense=True, - add_bos_to_history=False, + arbitrary_func=arbitrary_func, + default_mask_add_bos_to_history=False, ) # remove history[batchsize * topk_last_step * max(1,i), embedding_dim] _, candidate_hidden_states = triton_split_2D_jagged( diff --git a/examples/sid_gr/model/jagged_flash_attn_block.py b/examples/sid_gr/model/jagged_flash_attn_block.py new file mode 100644 index 000000000..2c569bd08 --- /dev/null +++ b/examples/sid_gr/model/jagged_flash_attn_block.py @@ -0,0 +1,408 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +JaggedFlashAttnBlock: a self-contained GPT Transformer block that uses +jiayus's Flash Attention with arbitrary_func mask encoding. + +This replaces Megatron-Core's TransformerBlock for inference with +Method A (Incremental Append) beam search. 
+ +Architecture per layer (standard pre-norm GPT): + Input → LayerNorm → QKV Projection → Flash Attention (arbitrary mask) + → Output Projection → Residual + → LayerNorm → FFN → Residual → Output + +Reference: examples/hstu/modules/native_hstu_layer.py +""" +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +# flash_attn imports are deferred to runtime (inside functions / __init__) +# so that the module can be imported even without flash_attn installed. + + +def build_block_sparsity( + arbitrary_func: torch.Tensor, + seqlen_q: int, + seqlen_k: int, + headdim: int, +) -> Tuple[Optional[object], Optional[object]]: + """ + Build forward (Q2K) and backward (K2Q) block sparsity indices from an + arbitrary_func tensor. + + Returns (linear_k, linear_q) — either can be None if the CUDA extension + is not available (falls back to dense attention). + """ + try: + import create_block_mask_cuda + from flash_attn.cute.block_sparsity import LinearBlockSparseTensorsTorch + from flash_attn.utils.tile_size import get_arch, get_tile_sizes_by_backend + except ImportError: + return None, None + + arch = get_arch() + fwd_q_block, fwd_kv_block = get_tile_sizes_by_backend( + backend="dsl", + pass_type="forward", + arch=arch, + headdim=headdim, + is_causal=False, + is_local=False, + is_arbitrary=True, + ) + bwd_q_block, bwd_kv_block = get_tile_sizes_by_backend( + backend="dsl", + pass_type="backward", + arch=arch, + headdim=headdim, + is_causal=False, + is_local=False, + is_arbitrary=True, + ) + + ( + k_cnt, + k_off, + k_idx, + k_fcnt, + k_foff, + k_fidx, + ) = create_block_mask_cuda.create_q2k_csr_sparse_from_func( + arbitrary_func, + seqlen_q, + seqlen_k, + Q_BLOCK_SIZE=fwd_q_block, + KV_BLOCK_SIZE=fwd_kv_block, + check_q_boundary=True, + ) + linear_k = LinearBlockSparseTensorsTorch( + mask_block_cnt=k_cnt, + mask_block_offset=k_off, + mask_block_idx=k_idx, + full_block_cnt=k_fcnt, + full_block_offset=k_foff, + 
full_block_idx=k_fidx, + ) + + ( + q_cnt, + q_off, + q_idx, + q_fcnt, + q_foff, + q_fidx, + ) = create_block_mask_cuda.create_k2q_csr_sparse_from_func( + arbitrary_func, + seqlen_q, + seqlen_k, + Q_BLOCK_SIZE=bwd_q_block, + KV_BLOCK_SIZE=bwd_kv_block, + ) + linear_q = LinearBlockSparseTensorsTorch( + mask_block_cnt=q_cnt, + mask_block_offset=q_off, + mask_block_idx=q_idx, + full_block_cnt=q_fcnt, + full_block_offset=q_foff, + full_block_idx=q_fidx, + ) + + return linear_k, linear_q + + +class JaggedGPTLayer(nn.Module): + """ + One Transformer layer with jagged Flash Attention. + + Pre-norm GPT structure: + x = x + Attn(LayerNorm(x)) + x = x + FFN(LayerNorm(x)) + + Q/K/V are produced by a single fused linear (same pattern as HSTU's + ``linear_uvqk``). Flash Attention is called with arbitrary_func for + tree-shaped beam search masks. + """ + + def __init__( + self, + hidden_size: int, + num_attention_heads: int, + ffn_hidden_size: int, + layernorm_epsilon: float = 1e-5, + hidden_dropout: float = 0.0, + activation: str = "gelu", + ): + super().__init__() + self.hidden_size = hidden_size + self.num_heads = num_attention_heads + self.head_dim = hidden_size // num_attention_heads + self.ffn_hidden_size = ffn_hidden_size + + # --- Attention sub-layers --- + self.input_layernorm = nn.LayerNorm(hidden_size, eps=layernorm_epsilon) + # Fused QKV projection: hidden_size → 3 * hidden_size + self.linear_qkv = nn.Linear(hidden_size, 3 * hidden_size, bias=False) + # Output projection after attention + self.linear_proj = nn.Linear(hidden_size, hidden_size, bias=False) + self.attn_dropout = nn.Dropout(hidden_dropout) + + # --- FFN sub-layers --- + self.pre_mlp_layernorm = nn.LayerNorm(hidden_size, eps=layernorm_epsilon) + self.mlp_fc1 = nn.Linear(hidden_size, ffn_hidden_size, bias=False) + self.mlp_fc2 = nn.Linear(ffn_hidden_size, hidden_size, bias=False) + self.mlp_dropout = nn.Dropout(hidden_dropout) + + if activation == "gelu": + self.activation_fn = F.gelu + elif activation == 
"silu": + self.activation_fn = F.silu + else: + self.activation_fn = F.gelu + + def forward( + self, + hidden_states: torch.Tensor, + arbitrary_func: Optional[torch.Tensor] = None, + linear_k: Optional[object] = None, + linear_q: Optional[object] = None, + ) -> torch.Tensor: + """ + Args: + hidden_states: [batch, seqlen, hidden_size] + arbitrary_func: [batch, 1, n_func, seqlen+256] int32 mask encoding. + linear_k: forward block sparsity (Q2K). + linear_q: backward block sparsity (K2Q). + + Returns: + hidden_states: [batch, seqlen, hidden_size] + """ + # ---- Attention block ---- + residual = hidden_states + x = self.input_layernorm(hidden_states) + + # QKV projection: [B, S, H] → [B, S, 3*H] + qkv = self.linear_qkv(x) + # Reshape to [B, S, 3, num_heads, head_dim] and unbind + B, S, _ = qkv.shape + qkv = qkv.view(B, S, 3, self.num_heads, self.head_dim) + q, k, v = qkv.unbind(dim=2) # each [B, S, num_heads, head_dim] + + # Flash Attention requires fp16/bf16 inputs + from flash_attn.cute.interface import flash_attn_func + + input_dtype = q.dtype + if q.dtype not in (torch.float16, torch.bfloat16): + q, k, v = q.bfloat16(), k.bfloat16(), v.bfloat16() + + if arbitrary_func is not None: + attn_out, _ = flash_attn_func( + q, + k, + v, + softmax_scale=self.head_dim ** (-0.5), + causal=False, + arbitrary=True, + linear_k_block_sparse_tensors=linear_k, + linear_q_block_sparse_tensors=linear_q, + aux_tensors=[arbitrary_func], + ) + else: + attn_out, _ = flash_attn_func( + q, + k, + v, + softmax_scale=self.head_dim ** (-0.5), + causal=True, + ) + + if attn_out.dtype != input_dtype: + attn_out = attn_out.to(input_dtype) + + # attn_out: [B, S, num_heads, head_dim] → [B, S, hidden_size] + attn_out = attn_out.reshape(B, S, self.hidden_size) + attn_out = self.linear_proj(attn_out) + attn_out = self.attn_dropout(attn_out) + hidden_states = residual + attn_out + + # ---- FFN block ---- + residual = hidden_states + x = self.pre_mlp_layernorm(hidden_states) + x = self.mlp_fc1(x) + x = 
self.activation_fn(x) + x = self.mlp_fc2(x) + x = self.mlp_dropout(x) + hidden_states = residual + x + + return hidden_states + + +class JaggedFlashAttnBlock(nn.Module): + """ + A stack of JaggedGPTLayers — the GPT decoder block using jiayus's + Flash Attention with arbitrary_func mask encoding. + + This module owns its own weights (not shared with Megatron-Core). + It is used in place of Megatron's TransformerBlock for inference + with Method A beam search. + + Usage:: + + block = JaggedFlashAttnBlock( + num_layers=4, + hidden_size=256, + num_attention_heads=4, + ffn_hidden_size=1024, + ) + # padded input: [B, S, D] + output = block(hidden_states, arbitrary_func=af, seqlen=S) + """ + + def __init__( + self, + num_layers: int, + hidden_size: int, + num_attention_heads: int, + ffn_hidden_size: int, + layernorm_epsilon: float = 1e-5, + hidden_dropout: float = 0.0, + activation: str = "gelu", + ): + super().__init__() + self.hidden_size = hidden_size + self.head_dim = hidden_size // num_attention_heads + self.layers = nn.ModuleList( + [ + JaggedGPTLayer( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + ffn_hidden_size=ffn_hidden_size, + layernorm_epsilon=layernorm_epsilon, + hidden_dropout=hidden_dropout, + activation=activation, + ) + for _ in range(num_layers) + ] + ) + self.final_layernorm = nn.LayerNorm(hidden_size, eps=layernorm_epsilon) + + def forward( + self, + hidden_states: torch.Tensor, + arbitrary_func: Optional[torch.Tensor] = None, + seqlen: Optional[int] = None, + ) -> torch.Tensor: + """ + Args: + hidden_states: [batch, seqlen, hidden_size] padded input. + arbitrary_func: [batch, 1, n_func, seqlen+256] int32 mask tensor. + If None, uses standard causal attention. + seqlen: sequence length (used for block sparsity construction). + If None, inferred from hidden_states.shape[1]. 
+ + Returns: + hidden_states: [batch, seqlen, hidden_size] + """ + if seqlen is None: + seqlen = hidden_states.shape[1] + + # Build block sparsity from arbitrary_func (once per forward, shared by all layers) + linear_k, linear_q = None, None + if arbitrary_func is not None: + linear_k, linear_q = build_block_sparsity( + arbitrary_func, seqlen, seqlen, self.head_dim + ) + + for layer in self.layers: + hidden_states = layer( + hidden_states, + arbitrary_func=arbitrary_func, + linear_k=linear_k, + linear_q=linear_q, + ) + + hidden_states = self.final_layernorm(hidden_states) + return hidden_states + + +class JaggedTransformerBlock(nn.Module): + """ + Wrapper that accepts jagged (variable-length) hidden states and a + pre-built arbitrary_func tensor in the flattened (B=1) coordinate space. + + All batch sequences are concatenated into a single sequence of length + *total_tokens* (no padding). The arbitrary_func encodes both the + block-diagonal batch isolation and the desired attention pattern + (causal, target-grouped, etc.). + + Internally: + 1. Reshape jagged [total_tokens, D] → [1, total_tokens, D] + 2. Forward through JaggedFlashAttnBlock (FA with arbitrary mask) + 3. Reshape [1, total_tokens, D] → [total_tokens, D] + + This is intended to replace Megatron-Core's TransformerBlock in + SIDGRDecoder. 
+ """ + + def __init__( + self, + num_layers: int, + hidden_size: int, + num_attention_heads: int, + ffn_hidden_size: int, + layernorm_epsilon: float = 1e-5, + hidden_dropout: float = 0.0, + activation: str = "gelu", + ): + super().__init__() + self.block = JaggedFlashAttnBlock( + num_layers=num_layers, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + ffn_hidden_size=ffn_hidden_size, + layernorm_epsilon=layernorm_epsilon, + hidden_dropout=hidden_dropout, + activation=activation, + ) + + def forward( + self, + hidden_states: torch.Tensor, + arbitrary_func: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + hidden_states: jagged [total_tokens, hidden_size]. + arbitrary_func: [1, 1, n_func, total_tokens + pad] int32 tensor + in flattened (B=1) coordinate space, encoding both batch + isolation and the attention pattern. + + Returns: + jagged output [total_tokens, hidden_size]. + """ + total_tokens = hidden_states.shape[0] + + # [total_tokens, D] → [1, total_tokens, D] + flat_input = hidden_states.unsqueeze(0) + + output = self.block( + flat_input, arbitrary_func=arbitrary_func, seqlen=total_tokens + ) + + # [1, total_tokens, D] → [total_tokens, D] + return output.squeeze(0) diff --git a/examples/sid_gr/model/mcore_model_specs.py b/examples/sid_gr/model/mcore_model_specs.py index d5a57392a..b4ad2ed02 100644 --- a/examples/sid_gr/model/mcore_model_specs.py +++ b/examples/sid_gr/model/mcore_model_specs.py @@ -375,6 +375,7 @@ def get_gpt_decoder_block_spec( ), "arbitrary attention mask is only supported with Megatron-Core modules" # Layer specs. 
+ # dense_layer_spec = ( _get_gpt_layer_with_transformer_engine_spec( num_experts=None, diff --git a/examples/sid_gr/tests/test_dense_mask_to_arbitrary_func.py b/examples/sid_gr/tests/test_dense_mask_to_arbitrary_func.py new file mode 100644 index 000000000..d32ca1882 --- /dev/null +++ b/examples/sid_gr/tests/test_dense_mask_to_arbitrary_func.py @@ -0,0 +1,265 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Tests for dense_mask_to_arbitrary_func(): verifies that converting a dense +[B,N,N] mask to arbitrary_func interval encoding preserves mask semantics. +""" +import os +import sys + +import pytest +import torch + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "model")) +from attention_mask import ( + build_jagged_causal_arbitrary_func, + dense_mask_to_arbitrary_func, + padded_target_aware_causal_mask, +) + +sys.path.pop(0) + + +def dense_mask_to_jagged_arbitrary_func( + valid_mask: torch.Tensor, + offsets: torch.Tensor, + total_tokens: int, + padding: int = 256, +) -> torch.Tensor: + """ + Test utility: convert per-batch dense mask [B, N, N] to a flattened + (B=1) arbitrary_func [1, 1, n_func, total_tokens + padding]. + + Maps each row from per-batch padded coordinates to global (flattened) + coordinates. 
Used in tests to build arbitrary_func from dense masks + so the FA path can be compared against the mcore/reference path. + """ + if valid_mask.dim() == 4: + valid_mask = valid_mask.squeeze(1) + assert valid_mask.dim() == 3, f"Expected [B, N, N], got {valid_mask.shape}" + + B, N, _ = valid_mask.shape + device = valid_mask.device + + shifted = torch.zeros_like(valid_mask) + shifted[:, :, 1:] = valid_mask[:, :, :-1] + starts = valid_mask & ~shifted + + ends_shifted = torch.zeros_like(valid_mask) + ends_shifted[:, :, :-1] = valid_mask[:, :, 1:] + ends = valid_mask & ~ends_shifted + + max_intervals = int(starts.sum(dim=-1).max().item()) + n_func = max(2 * max_intervals + 1, 3) + if n_func % 2 == 0: + n_func += 1 + + af = torch.zeros( + 1, 1, n_func, total_tokens + padding, dtype=torch.int32, device=device + ) + + for b in range(B): + batch_start = offsets[b].item() + batch_end = offsets[b + 1].item() + seq_len = batch_end - batch_start + + for local_q in range(seq_len): + global_q = batch_start + local_q + row = valid_mask[b, local_q, :seq_len] + if not row.any(): + continue + start_pos = starts[b, local_q, :seq_len].nonzero(as_tuple=False).squeeze(-1) + end_pos = ends[b, local_q, :seq_len].nonzero(as_tuple=False).squeeze(-1) + 1 + for iv in range(len(start_pos)): + s = start_pos[iv].item() + batch_start + e = end_pos[iv].item() + batch_start + af[0, 0, 2 * iv + 1, global_q] = s + af[0, 0, 2 * iv + 2, global_q] = e + + return af + + +def arbitrary_func_to_dense(af, seqlen_q, seqlen_k): + """Expand arbitrary_func back to dense [B, seqlen_q, seqlen_k] bool mask.""" + B, n_func = af.shape[0], af.shape[2] + mask = torch.zeros(B, seqlen_q, seqlen_k, dtype=torch.bool, device=af.device) + kv_idx = torch.arange(seqlen_k, device=af.device) + for b in range(B): + for q in range(seqlen_q): + f0 = af[b, 0, 0, q].item() + row_mask = kv_idx < f0 + for iv in range(n_func // 2): + f_start = af[b, 0, 2 * iv + 1, q].item() + f_end = af[b, 0, 2 * iv + 2, q].item() + row_mask = row_mask | 
((kv_idx >= f_start) & (kv_idx < f_end)) + mask[b, q] = row_mask + return mask + + +class TestDenseMaskToArbitraryFunc: + def test_causal_mask(self): + N, B = 16, 1 + valid = torch.tril(torch.ones(B, N, N, dtype=torch.bool, device="cuda")) + af = dense_mask_to_arbitrary_func(valid, N) + assert torch.equal(valid, arbitrary_func_to_dense(af, N, N)) + + def test_full_attention(self): + N, B = 8, 2 + valid = torch.ones(B, N, N, dtype=torch.bool, device="cuda") + af = dense_mask_to_arbitrary_func(valid, N) + assert torch.equal(valid, arbitrary_func_to_dense(af, N, N)) + + def test_empty_mask(self): + N, B = 8, 1 + valid = torch.zeros(B, N, N, dtype=torch.bool, device="cuda") + af = dense_mask_to_arbitrary_func(valid, N) + assert torch.equal(valid, arbitrary_func_to_dense(af, N, N)) + + def test_block_diagonal(self): + N, B = 8, 1 + valid = torch.zeros(B, N, N, dtype=torch.bool, device="cuda") + valid[0, :4, :4] = True + valid[0, 4:, 4:] = True + af = dense_mask_to_arbitrary_func(valid, N) + assert torch.equal(valid, arbitrary_func_to_dense(af, N, N)) + + @pytest.mark.parametrize("beam_width", [2, 3]) + @pytest.mark.parametrize("candidate_len", [1, 3]) + def test_target_aware_causal_mask(self, beam_width, candidate_len): + hist_lens = torch.tensor([6, 4], device="cuda") + inverted = padded_target_aware_causal_mask( + hist_lens, 6, beam_width, candidate_len + ) + valid = ~inverted + N = valid.shape[-1] + af = dense_mask_to_arbitrary_func(valid, N) + assert torch.equal(valid.squeeze(1), arbitrary_func_to_dense(af, N, N)) + + def test_mask_with_gaps(self): + N, B = 10, 1 + valid = torch.zeros(B, N, N, dtype=torch.bool, device="cuda") + valid[0, 5, 0:3] = True + valid[0, 5, 5:7] = True + valid[0, 5, 9] = True + af = dense_mask_to_arbitrary_func(valid, N) + assert torch.equal(valid, arbitrary_func_to_dense(af, N, N)) + + def test_4d_input(self): + N, B = 8, 1 + valid_4d = torch.tril(torch.ones(B, 1, N, N, dtype=torch.bool, device="cuda")) + af = 
dense_mask_to_arbitrary_func(valid_4d, N) + assert torch.equal(valid_4d.squeeze(1), arbitrary_func_to_dense(af, N, N)) + + def test_batch_independence(self): + N, B = 8, 2 + valid = torch.zeros(B, N, N, dtype=torch.bool, device="cuda") + valid[0] = torch.tril(torch.ones(N, N, dtype=torch.bool, device="cuda")) + valid[1] = torch.ones(N, N, dtype=torch.bool, device="cuda") + af = dense_mask_to_arbitrary_func(valid, N) + recon = arbitrary_func_to_dense(af, N, N) + assert torch.equal(valid[0], recon[0]) + assert torch.equal(valid[1], recon[1]) + + +class TestJaggedFlattenedArbitraryFunc: + """Tests for the B=1 flattened arbitrary_func builders.""" + + @staticmethod + def _build_expected_jagged_causal(offsets): + """Build the expected [1, total, total] causal block-diagonal mask.""" + total = offsets[-1].item() + device = offsets.device + expected = torch.zeros(total, total, dtype=torch.bool, device=device) + B = offsets.size(0) - 1 + for b in range(B): + s, e = offsets[b].item(), offsets[b + 1].item() + block = torch.tril( + torch.ones(e - s, e - s, dtype=torch.bool, device=device) + ) + expected[s:e, s:e] = block + return expected + + def test_jagged_causal_basic(self): + offsets = torch.tensor([0, 4, 7, 10], device="cuda") + total = 10 + af = build_jagged_causal_arbitrary_func(offsets, total) + recon = arbitrary_func_to_dense(af, total, total).squeeze(0) + expected = self._build_expected_jagged_causal(offsets) + assert torch.equal(expected, recon) + + def test_jagged_causal_single_batch(self): + offsets = torch.tensor([0, 6], device="cuda") + total = 6 + af = build_jagged_causal_arbitrary_func(offsets, total) + recon = arbitrary_func_to_dense(af, total, total).squeeze(0) + expected = torch.tril(torch.ones(6, 6, dtype=torch.bool, device="cuda")) + assert torch.equal(expected, recon) + + def test_jagged_causal_uneven_lengths(self): + offsets = torch.tensor([0, 2, 8, 9], device="cuda") + total = 9 + af = build_jagged_causal_arbitrary_func(offsets, total) + recon = 
arbitrary_func_to_dense(af, total, total).squeeze(0) + expected = self._build_expected_jagged_causal(offsets) + assert torch.equal(expected, recon) + + def test_dense_to_jagged_causal(self): + """dense_mask_to_jagged_arbitrary_func should match build_jagged_causal for causal masks.""" + offsets = torch.tensor([0, 3, 7], device="cuda") + B, total = 2, 7 + max_seqlen = 4 + per_batch = torch.zeros( + B, max_seqlen, max_seqlen, dtype=torch.bool, device="cuda" + ) + for b in range(B): + sl = (offsets[b + 1] - offsets[b]).item() + per_batch[b, :sl, :sl] = torch.tril( + torch.ones(sl, sl, dtype=torch.bool, device="cuda") + ) + af = dense_mask_to_jagged_arbitrary_func(per_batch, offsets, total) + recon = arbitrary_func_to_dense(af, total, total).squeeze(0) + expected = self._build_expected_jagged_causal(offsets) + assert torch.equal(expected, recon) + + @pytest.mark.parametrize("beam_width", [2, 3]) + @pytest.mark.parametrize("candidate_len", [1, 3]) + def test_dense_to_jagged_target_grouped(self, beam_width, candidate_len): + """Verify target-grouped masks survive the jagged conversion roundtrip.""" + B = 2 + hist_lens = torch.tensor([5, 3], device="cuda") + max_hist = 5 + inverted = padded_target_aware_causal_mask( + hist_lens, max_hist, beam_width, candidate_len + ) + valid = ~inverted # [B, 1, N, N] + valid.shape[-1] + total_per_batch = (hist_lens + beam_width * candidate_len).tolist() + offsets = torch.tensor( + [0] + [sum(total_per_batch[: i + 1]) for i in range(B)], + device="cuda", + ) + total = offsets[-1].item() + af = dense_mask_to_jagged_arbitrary_func(valid, offsets, total) + recon = arbitrary_func_to_dense(af, total, total).squeeze(0) + + # Build expected flattened mask from per-batch dense mask + expected = torch.zeros(total, total, dtype=torch.bool, device="cuda") + valid_3d = valid.squeeze(1) + for b in range(B): + s = offsets[b].item() + sl = total_per_batch[b] + expected[s : s + sl, s : s + sl] = valid_3d[b, :sl, :sl] + + assert torch.equal(expected, 
recon) diff --git a/examples/sid_gr/tests/test_jagged_flash_attn_block.py b/examples/sid_gr/tests/test_jagged_flash_attn_block.py new file mode 100644 index 000000000..d3979409b --- /dev/null +++ b/examples/sid_gr/tests/test_jagged_flash_attn_block.py @@ -0,0 +1,468 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Tests for JaggedFlashAttnBlock / JaggedGPTLayer. + +Test strategy (same pattern as HSTU's test_hstu_layer.py): + 1. Build a reference implementation using PyTorch's scaled_dot_product_attention + 2. Build JaggedGPTLayer with the same weights + 3. 
Run the same input, compare outputs + +Tests: + - Smoke: forward pass runs without error, output shape correct + - Causal: causal mask produces correct output vs PyTorch reference + - Backward: gradients flow correctly +""" + +import os +import sys + +import pytest +import torch +import torch.nn as nn +import torch.nn.functional as F + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "model")) +from attention_mask import build_jagged_causal_arbitrary_func +from jagged_flash_attn_block import ( + JaggedFlashAttnBlock, + JaggedGPTLayer, + JaggedTransformerBlock, +) + +sys.path.pop(0) + +try: + from flash_attn.cute.interface import flash_attn_func # noqa: F401 + + HAS_FLASH_ATTN = True +except ImportError: + HAS_FLASH_ATTN = False + + +# --------------------------------------------------------------------------- +# PyTorch reference: single GPT layer with standard attention +# --------------------------------------------------------------------------- +class ReferenceGPTLayer(nn.Module): + """Minimal GPT layer using PyTorch's SDPA for correctness comparison.""" + + def __init__(self, hidden_size, num_heads, ffn_hidden_size, eps=1e-5): + super().__init__() + self.hidden_size = hidden_size + self.num_heads = num_heads + self.head_dim = hidden_size // num_heads + + self.input_layernorm = nn.LayerNorm(hidden_size, eps=eps) + self.linear_qkv = nn.Linear(hidden_size, 3 * hidden_size, bias=False) + self.linear_proj = nn.Linear(hidden_size, hidden_size, bias=False) + self.pre_mlp_layernorm = nn.LayerNorm(hidden_size, eps=eps) + self.mlp_fc1 = nn.Linear(hidden_size, ffn_hidden_size, bias=False) + self.mlp_fc2 = nn.Linear(ffn_hidden_size, hidden_size, bias=False) + + def forward(self, hidden_states, attn_mask=None, is_causal=False): + residual = hidden_states + x = self.input_layernorm(hidden_states) + B, S, _ = x.shape + + qkv = self.linear_qkv(x).view(B, S, 3, self.num_heads, self.head_dim) + q, k, v = qkv.unbind(dim=2) + q = q.transpose(1, 2) # [B, H, S, D] 
+ k = k.transpose(1, 2) + v = v.transpose(1, 2) + + attn_out = F.scaled_dot_product_attention( + q, k, v, attn_mask=attn_mask, is_causal=is_causal + ) + attn_out = attn_out.transpose(1, 2).reshape(B, S, self.hidden_size) + attn_out = self.linear_proj(attn_out) + hidden_states = residual + attn_out + + residual = hidden_states + x = self.pre_mlp_layernorm(hidden_states) + x = self.mlp_fc1(x) + x = F.gelu(x) + x = self.mlp_fc2(x) + hidden_states = residual + x + + return hidden_states + + +def _copy_weights(src: nn.Module, dst: nn.Module): + """Copy all matching weights from src to dst.""" + dst.load_state_dict(src.state_dict(), strict=False) + src_dict = dict(src.named_parameters()) + for name, param in dst.named_parameters(): + if name in src_dict: + param.data.copy_(src_dict[name].data) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- +@pytest.mark.skipif(not HAS_FLASH_ATTN, reason="flash_attn not installed") +class TestJaggedGPTLayerSmoke: + """Basic smoke tests that don't require flash_attn.""" + + @pytest.mark.parametrize("batch_size", [1, 2]) + @pytest.mark.parametrize("seqlen", [16, 64]) + @pytest.mark.parametrize("hidden_size,num_heads", [(256, 4), (512, 8)]) + def test_forward_shape(self, batch_size, seqlen, hidden_size, num_heads): + """Output shape should match input shape.""" + ffn_size = hidden_size * 4 + layer = JaggedGPTLayer( + hidden_size=hidden_size, + num_attention_heads=num_heads, + ffn_hidden_size=ffn_size, + ).cuda() + + x = torch.randn(batch_size, seqlen, hidden_size, device="cuda") + out = layer(x) + assert out.shape == (batch_size, seqlen, hidden_size) + + def test_block_forward_shape(self): + """JaggedFlashAttnBlock stacks layers correctly.""" + block = ( + JaggedFlashAttnBlock( + num_layers=2, + hidden_size=256, + num_attention_heads=4, + ffn_hidden_size=1024, + ) + .cuda() + .bfloat16() + ) + + x = torch.randn(2, 32, 
256, device="cuda", dtype=torch.bfloat16) + out = block(x) + assert out.shape == (2, 32, 256) + + +@pytest.mark.skipif(not HAS_FLASH_ATTN, reason="flash_attn not installed") +class TestJaggedGPTLayerCorrectness: + """Compare JaggedGPTLayer (FA) against PyTorch reference.""" + + @pytest.mark.parametrize("hidden_size,num_heads", [(256, 4)]) + @pytest.mark.parametrize("seqlen", [16, 32]) + def test_causal_matches_reference(self, hidden_size, num_heads, seqlen): + """ + With causal=True (no arbitrary_func), JaggedGPTLayer should produce + the same output as the PyTorch reference (within bf16 precision). + """ + B = 2 + ffn_size = hidden_size * 4 + torch.manual_seed(42) + + ref_layer = ( + ReferenceGPTLayer(hidden_size, num_heads, ffn_size).cuda().bfloat16() + ) + test_layer = ( + JaggedGPTLayer( + hidden_size=hidden_size, + num_attention_heads=num_heads, + ffn_hidden_size=ffn_size, + hidden_dropout=0.0, + ) + .cuda() + .bfloat16() + ) + + test_layer.load_state_dict(ref_layer.state_dict()) + + x = torch.randn(B, seqlen, hidden_size, device="cuda", dtype=torch.bfloat16) + + with torch.no_grad(): + ref_out = ref_layer(x, is_causal=True) + test_out = test_layer(x, arbitrary_func=None) + + torch.testing.assert_close(test_out, ref_out, atol=5e-2, rtol=5e-2) + + @pytest.mark.parametrize("hidden_size,num_heads", [(256, 4)]) + def test_arbitrary_causal_matches_standard_causal(self, hidden_size, num_heads): + """ + An arbitrary_func encoding a causal mask should produce the same + result as the built-in causal=True path. 
+ """ + B, S = 1, 32 + ffn_size = hidden_size * 4 + torch.manual_seed(42) + + layer = ( + JaggedGPTLayer( + hidden_size=hidden_size, + num_attention_heads=num_heads, + ffn_hidden_size=ffn_size, + hidden_dropout=0.0, + ) + .cuda() + .bfloat16() + ) + + x = torch.randn(B, S, hidden_size, device="cuda", dtype=torch.bfloat16) + + # Build causal arbitrary_func: F0[i] = i+1 + n_func = 1 + af = torch.zeros(B, 1, n_func, S + 256, dtype=torch.int32, device="cuda") + for i in range(S): + af[:, :, 0, i] = i + 1 + + with torch.no_grad(): + out_causal = layer(x, arbitrary_func=None) + out_arb = layer(x, arbitrary_func=af) + + torch.testing.assert_close(out_arb, out_causal, atol=5e-2, rtol=5e-2) + + +@pytest.mark.skipif(not HAS_FLASH_ATTN, reason="flash_attn not installed") +class TestJaggedGPTLayerBackward: + """Verify gradients flow correctly.""" + + def test_backward_runs(self): + """Forward + backward should not error.""" + layer = ( + JaggedGPTLayer( + hidden_size=256, + num_attention_heads=4, + ffn_hidden_size=1024, + hidden_dropout=0.0, + ) + .cuda() + .bfloat16() + ) + + x = torch.randn( + 2, 16, 256, device="cuda", dtype=torch.bfloat16, requires_grad=True + ) + out = layer(x) + loss = out.sum() + loss.backward() + + assert x.grad is not None + assert x.grad.shape == x.shape + + def test_backward_gradient_correctness(self): + """ + Gradient of JaggedGPTLayer (causal) should match the PyTorch reference. 
+ """ + hidden_size, num_heads = 256, 4 + ffn_size = 1024 + B, S = 2, 16 + torch.manual_seed(42) + + ref_layer = ( + ReferenceGPTLayer(hidden_size, num_heads, ffn_size).cuda().bfloat16() + ) + test_layer = ( + JaggedGPTLayer( + hidden_size=hidden_size, + num_attention_heads=num_heads, + ffn_hidden_size=ffn_size, + hidden_dropout=0.0, + ) + .cuda() + .bfloat16() + ) + test_layer.load_state_dict(ref_layer.state_dict()) + + x_ref = torch.randn( + B, S, hidden_size, device="cuda", dtype=torch.bfloat16, requires_grad=True + ) + x_test = x_ref.detach().clone().requires_grad_(True) + + ref_out = ref_layer(x_ref, is_causal=True) + test_out = test_layer(x_test, arbitrary_func=None) + + dout = torch.randn_like(ref_out) + ref_out.backward(dout) + test_out.backward(dout) + + torch.testing.assert_close(x_test.grad, x_ref.grad, atol=5e-2, rtol=5e-2) + + +# --------------------------------------------------------------------------- +# JaggedTransformerBlock (B=1 flatten) vs padded reference +# --------------------------------------------------------------------------- +class ReferenceBlock(nn.Module): + """Stack of ReferenceGPTLayers + final layernorm for comparison.""" + + def __init__(self, num_layers, hidden_size, num_heads, ffn_hidden_size, eps=1e-5): + super().__init__() + self.layers = nn.ModuleList( + [ + ReferenceGPTLayer(hidden_size, num_heads, ffn_hidden_size, eps) + for _ in range(num_layers) + ] + ) + self.final_layernorm = nn.LayerNorm(hidden_size, eps=eps) + + def forward(self, x, attn_mask=None, is_causal=False): + for layer in self.layers: + x = layer(x, attn_mask=attn_mask, is_causal=is_causal) + return self.final_layernorm(x) + + +def _build_padded_causal_block_diagonal_mask(offsets, max_seqlen): + """Build [B, 1, S, S] causal + padding mask (True = can attend).""" + B = offsets.size(0) - 1 + device = offsets.device + mask = torch.zeros(B, 1, max_seqlen, max_seqlen, dtype=torch.bool, device=device) + for b in range(B): + sl = (offsets[b + 1] - offsets[b]).item() 
+ mask[b, 0, :sl, :sl] = torch.tril( + torch.ones(sl, sl, dtype=torch.bool, device=device) + ) + return mask + + +@pytest.mark.skipif(not HAS_FLASH_ATTN, reason="flash_attn not installed") +class TestJaggedTransformerBlockVsPadded: + """ + Compare the B=1 flattened FA path (JaggedTransformerBlock) against the + padded PyTorch reference to make sure the flatten + arbitrary_func + approach produces correct outputs. + """ + + @pytest.mark.parametrize("num_layers", [1, 2]) + def test_causal_flatten_vs_padded(self, num_layers): + """ + For jagged causal attention, the flattened (B=1) FA path should + produce the same per-sequence output as the padded reference. + """ + H, NH, FFN = 256, 4, 1024 + seq_lens = [5, 8, 3] + offsets = torch.tensor( + [0] + list(torch.cumsum(torch.tensor(seq_lens), 0).tolist()), + device="cuda", + ) + total_tokens = offsets[-1].item() + max_seqlen = max(seq_lens) + B = len(seq_lens) + torch.manual_seed(123) + + ref_block = ReferenceBlock(num_layers, H, NH, FFN).cuda().bfloat16() + test_block = ( + JaggedTransformerBlock( + num_layers=num_layers, + hidden_size=H, + num_attention_heads=NH, + ffn_hidden_size=FFN, + ) + .cuda() + .bfloat16() + ) + + # Copy weights: test_block.block.layers[i] ↔ ref_block.layers[i] + for i in range(num_layers): + test_block.block.layers[i].load_state_dict(ref_block.layers[i].state_dict()) + test_block.block.final_layernorm.load_state_dict( + ref_block.final_layernorm.state_dict() + ) + + # Build jagged input [total_tokens, H] + jagged_input = torch.randn(total_tokens, H, device="cuda", dtype=torch.bfloat16) + + # --- FA path: flatten to B=1, build arbitrary_func --- + arbitrary_func = build_jagged_causal_arbitrary_func(offsets, total_tokens) + with torch.no_grad(): + fa_output = test_block(jagged_input, arbitrary_func=arbitrary_func) + + # --- Reference path: pad each sequence, run per-batch --- + with torch.no_grad(): + for b in range(B): + sl = seq_lens[b] + s = offsets[b].item() + padded_in = torch.zeros( + 1, 
max_seqlen, H, device="cuda", dtype=torch.bfloat16 + ) + padded_in[0, :sl] = jagged_input[s : s + sl] + ref_out = ref_block(padded_in, is_causal=True) + fa_seq = fa_output[s : s + sl] + ref_seq = ref_out[0, :sl] + torch.testing.assert_close( + fa_seq, + ref_seq, + atol=5e-2, + rtol=5e-2, + msg=f"Mismatch in batch element {b}", + ) + + def test_flatten_output_shape(self): + """Output shape should be [total_tokens, hidden_size].""" + H, NH, FFN = 128, 2, 512 + offsets = torch.tensor([0, 4, 9], device="cuda") + total = 9 + + block = ( + JaggedTransformerBlock( + num_layers=1, + hidden_size=H, + num_attention_heads=NH, + ffn_hidden_size=FFN, + ) + .cuda() + .bfloat16() + ) + + x = torch.randn(total, H, device="cuda", dtype=torch.bfloat16) + af = build_jagged_causal_arbitrary_func(offsets, total) + + with torch.no_grad(): + out = block(x, arbitrary_func=af) + + assert out.shape == (total, H) + + def test_batch_isolation(self): + """ + Sequences in different batch elements must not affect each other. + Running [A, B] flattened should give the same per-sequence output + as running [A] and [B] separately. 
+ """ + H, NH, FFN = 256, 4, 1024 + torch.manual_seed(77) + + block = ( + JaggedTransformerBlock( + num_layers=1, + hidden_size=H, + num_attention_heads=NH, + ffn_hidden_size=FFN, + ) + .cuda() + .bfloat16() + ) + + seq_a = torch.randn(5, H, device="cuda", dtype=torch.bfloat16) + seq_b = torch.randn(7, H, device="cuda", dtype=torch.bfloat16) + + # Run combined [A, B] + combined = torch.cat([seq_a, seq_b], dim=0) + offsets_ab = torch.tensor([0, 5, 12], device="cuda") + af_ab = build_jagged_causal_arbitrary_func(offsets_ab, 12) + with torch.no_grad(): + out_ab = block(combined, arbitrary_func=af_ab) + + # Run A alone + offsets_a = torch.tensor([0, 5], device="cuda") + af_a = build_jagged_causal_arbitrary_func(offsets_a, 5) + with torch.no_grad(): + out_a = block(seq_a, arbitrary_func=af_a) + + # Run B alone + offsets_b = torch.tensor([0, 7], device="cuda") + af_b = build_jagged_causal_arbitrary_func(offsets_b, 7) + with torch.no_grad(): + out_b = block(seq_b, arbitrary_func=af_b) + + torch.testing.assert_close(out_ab[:5], out_a, atol=1e-3, rtol=1e-3) + torch.testing.assert_close(out_ab[5:], out_b, atol=1e-3, rtol=1e-3) diff --git a/examples/sid_gr/tests/test_model_smoke.py b/examples/sid_gr/tests/test_model_smoke.py index d9b7bee7a..8aeb294ce 100644 --- a/examples/sid_gr/tests/test_model_smoke.py +++ b/examples/sid_gr/tests/test_model_smoke.py @@ -222,6 +222,7 @@ def test_model_decoder_step( input_hidden_states.view(-1, hidden_size), input_offsets, input_max_seqlen, + default_mask_add_bos_to_history=False, ).view(batch_size, input_max_seqlen, -1) candidates_logits = output.view(batch_size, input_max_seqlen, -1)[ :, history_max_seqlen:, : @@ -264,6 +265,7 @@ def test_model_decoder_step( prefix_input_hidden_states.view(-1, hidden_size), prefix_input_offsets, prefix_input_max_seqlen, + default_mask_add_bos_to_history=False, ).view(batch_size, prefix_input_max_seqlen, -1) prefix_candidates_logits = prefix_output.view( batch_size, prefix_input_max_seqlen, -1 diff --git 
a/examples/sid_gr/training/README.md b/examples/sid_gr/training/README.md index 0261f1e1e..cb9639e4d 100644 --- a/examples/sid_gr/training/README.md +++ b/examples/sid_gr/training/README.md @@ -75,12 +75,20 @@ PYTHONPATH=${PYTHONPATH}:$(realpath ../) torchrun \ **Note:** Ensure your current working directory is `examples/sid_gr` before running the command. -## Known Limitations +## Attention Backend + +The decoder supports two attention backends, controlled by `use_jagged_flash_attn`: + +| Backend | Flag | Input Format | Mask Format | Dependency | +|---------|------|-------------|-------------|------------| +| Megatron-Core `TransformerBlock` | `False` | Padded dense `[S, B, D]` | Dense `[B, 1, N, N]` | megatron-core | +| `JaggedTransformerBlock` (FA) | `True` | Flattened `[1, total_tokens, D]` (zero padding) | `arbitrary_func` interval encoding | [jiayus/flash-attention (arbitrary_mask branch)](https://github.com/jiayus-nvidia/flash-attention/tree/arbitrary_mask) | -⚠️ **This implementation is under active development.** The current version has not been fully optimized for performance. Known limitations include: +The FA backend flattens all batch sequences into a single sequence (B=1) and encodes the attention pattern via `arbitrary_func`. Block sparsity skips masked regions automatically. The caller is responsible for building the `arbitrary_func` or `attention_mask` tensor. + +## Known Limitations -- **Attention mechanism**: Currently using padded local SDPA (Scaled Dot-Product Attention) implementation in Megatron-Core with explicit attention masks -- **Beam search**: The beam search used during evaluation does not yet support KV cache optimization -- **Performance**: The model performance has not reached optimal levels +**This implementation is under active development.** -We are actively working on addressing these limitations and improving overall efficiency. 
\ No newline at end of file
+- **Beam search**: Does not yet support KV cache optimization
+- **Performance**: Not fully optimized