Skip to content
4 changes: 4 additions & 0 deletions docker/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ RUN pip install --no-deps tensordict orjson && \
cd torchrec && \
pip install --no-deps .

RUN git clone -b arbitrary_mask https://github.com/jiayus-nvidia/flash-attention.git flash-attention && \
cd flash-attention && \
pip install --no-deps -e .
Comment thread
z52527 marked this conversation as resolved.

# Install fbgemm_gpu_hstu (package: fbgemm_gpu_hstu, import: hstu) from submodule
COPY third_party/FBGEMM /workspace/deps/fbgemm_hstu
RUN cd /workspace/deps/fbgemm_hstu/fbgemm_gpu/experimental/hstu && \
Expand Down
1 change: 1 addition & 0 deletions examples/sid_gr/beam_search/beam_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def __init__(
self.history_topk_sids: List[torch.Tensor] = []
self.history_accumulate_topk_probs: List[torch.Tensor] = []
self.history_probs: List[torch.Tensor] = []

self.reset()

def propagate(
Expand Down
3 changes: 2 additions & 1 deletion examples/sid_gr/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@ def get_sid_gr_model(
top_k_for_generation: int = 10,
eval_metrics: Tuple[str, ...] = (),
share_lm_head_across_hierarchies: bool = True,
use_jagged_flash_attn: bool = False,
) -> SIDGRModel:
sid_gr_model = SIDGRModel(
decoder_config=decoder_config,
codebook_embedding_config=codebook_embedding_config,
codebook_sizes=codebook_sizes,
num_hierarchies=num_hierarchies,
transformer_decoder_layer_spec=get_gpt_decoder_block_spec(
# padding + arbitrary attention mask + Megatron-Core
decoder_config,
use_transformer_engine=False,
arbitrary_attention_mask=True,
Expand All @@ -35,6 +35,7 @@ def get_sid_gr_model(
top_k_for_generation=top_k_for_generation,
eval_metrics=eval_metrics,
share_lm_head_across_hierarchies=share_lm_head_across_hierarchies,
use_jagged_flash_attn=use_jagged_flash_attn,
)

return sid_gr_model
114 changes: 114 additions & 0 deletions examples/sid_gr/model/attention_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,120 @@ def padded_target_aware_causal_mask(
return ~valid_attn_mask


def dense_mask_to_arbitrary_func(
    valid_mask: torch.Tensor,
    seqlen: int,
    padding: int = 256,
) -> torch.Tensor:
    """
    Convert a dense bool attention mask to flash_attn's interval-based
    arbitrary_func tensor.

    For each query position q, the arbitrary_func encodes visible key
    positions as a union of intervals:
        visible(q) = [0, F0) ∪ [F1, F2) ∪ [F3, F4) ∪ ...

    NOTE: this iterates with per-(batch, query) Python loops and is only
    meant for testing / reference use, not for a hot path.

    Args:
        valid_mask: [B, N, N] or [B, 1, N, N] bool tensor (True = can attend).
        seqlen: sequence length N.
        padding: extra padding on last dim (FA convention, default 256).

    Returns:
        arbitrary_func: [B, 1, n_func, seqlen + padding] int32 tensor, where
        n_func = 2 * max_intervals + 1 (always odd).
    """
    if valid_mask.dim() == 4:
        valid_mask = valid_mask.squeeze(1)
    assert valid_mask.dim() == 3, f"Expected [B, N, N], got {valid_mask.shape}"

    B, N, _ = valid_mask.shape
    device = valid_mask.device

    # A True run starts wherever a position is True and its left neighbor
    # (or the left sequence edge) is False.
    shifted = torch.zeros_like(valid_mask)
    shifted[:, :, 1:] = valid_mask[:, :, :-1]
    starts = valid_mask & ~shifted  # start of each True run
    max_intervals = int(starts.sum(dim=-1).max().item())

    # F0 encodes the base interval [0, F0), which is usable only when the
    # first run starts at 0.  Worst case every interval needs an explicit
    # [F_start, F_end) pair, so n_func = 2*max_intervals + 1 slots cover
    # all cases — and 2k + 1 is always odd, so no parity fix-up is needed.
    n_func = 2 * max_intervals + 1

    af = torch.zeros(B, 1, n_func, seqlen + padding, dtype=torch.int32, device=device)

    # A True run ends wherever a position is True and its right neighbor
    # (or the right sequence edge) is False.
    ends_shifted = torch.zeros_like(valid_mask)
    ends_shifted[:, :, :-1] = valid_mask[:, :, 1:]
    ends = valid_mask & ~ends_shifted  # last True position in each run

    for b in range(B):
        for q in range(N):
            row = valid_mask[b, q]
            if not row.any():
                # Query attends to nothing: leave all slots at 0.
                continue
            start_pos = starts[b, q].nonzero(as_tuple=False).squeeze(-1)
            # +1 converts "last True index" into an exclusive interval end.
            end_pos = ends[b, q].nonzero(as_tuple=False).squeeze(-1) + 1

            # F0 encodes [0, F0).  If the first interval starts at 0, use F0.
            # Otherwise F0 stays 0 (empty base interval) and all intervals
            # go into the extra [F_start, F_end) slots.
            extra_idx = 0
            for iv in range(len(start_pos)):
                s, e = start_pos[iv].item(), end_pos[iv].item()
                if iv == 0 and s == 0:
                    af[b, 0, 0, q] = e
                else:
                    af[b, 0, 2 * extra_idx + 1, q] = s
                    af[b, 0, 2 * extra_idx + 2, q] = e
                    extra_idx += 1

    return af


def build_jagged_causal_arbitrary_func(
    offsets: torch.Tensor,
    total_tokens: int,
    padding: int = 256,
) -> torch.Tensor:
    """
    Build arbitrary_func for flattened jagged causal attention (B=1).

    All batch sequences are concatenated into a single sequence of length
    *total_tokens*. Each query at global position *q* in batch element *b*
    can attend to keys in ``[offset[b], q+1)`` — standard causal within its
    own sequence, invisible to other sequences.

    Args:
        offsets: [B+1] cumulative sequence-length offsets.
        total_tokens: ``offsets[-1].item()`` — total number of tokens.
        padding: FA convention padding on the last dim (default 256).

    Returns:
        arbitrary_func: [1, 1, 3, total_tokens + padding] int32 tensor.
    """
    dev = offsets.device

    # Three function rows: F0 = 0 (empty base interval), and one explicit
    # interval pair F1/F2 encoding [sequence start, q+1).
    func = torch.zeros(
        (1, 1, 3, total_tokens + padding), dtype=torch.int32, device=dev
    )

    token_pos = torch.arange(total_tokens, device=dev)
    # Map each global token position to the batch element that owns it,
    # then look up where that element's sequence begins.
    owner = torch.searchsorted(offsets[1:], token_pos, right=True)
    seq_begin = offsets[owner]

    # visible(q) = [0, 0) ∪ [seq_begin, q+1)
    func[0, 0, 1, :total_tokens] = seq_begin.to(torch.int32)
    func[0, 0, 2, :total_tokens] = token_pos.to(torch.int32) + 1

    return func


if __name__ == "__main__":
history_seqlen = torch.tensor([4, 3]).cuda()
max_history_seqlen = 6
Expand Down
Loading