-
Notifications
You must be signed in to change notification settings - Fork 66
Jagged Arbitrary Masked Self Attention support #339
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
2551fdb
e957c1c
e74a13d
61399ef
74e07c8
528cf77
98e61a3
ad536dd
e3421db
a796b18
c598c13
058fb66
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -207,6 +207,120 @@ def padded_target_aware_causal_mask( | |||||||||||
| return ~valid_attn_mask | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
def dense_mask_to_arbitrary_func(
    valid_mask: torch.Tensor,
    seqlen: int,
    padding: int = 256,
) -> torch.Tensor:
    """
    Convert a dense bool attention mask to flash_attn's interval-based
    arbitrary_func tensor.

    For each query position q, the arbitrary_func encodes visible key
    positions as a union of half-open intervals:

        visible(q) = [0, F0) ∪ [F1, F2) ∪ [F3, F4) ∪ ...

    Slot 0 (F0) describes a base interval anchored at key 0; each later
    pair of slots holds an explicit (start, end) interval. Per the PR
    discussion this helper is used for testing only; a fused Triton
    kernel would be needed for production performance.

    Args:
        valid_mask: [B, N, N] or [B, 1, N, N] bool tensor (True = can attend).
        seqlen: sequence length N.
        padding: extra padding on last dim (FA convention, default 256).

    Returns:
        arbitrary_func: [B, 1, n_func, seqlen + padding] int32 tensor,
        where n_func = 2 * max_intervals + 1 (always odd).
    """
    if valid_mask.dim() == 4:
        valid_mask = valid_mask.squeeze(1)
    assert valid_mask.dim() == 3, f"Expected [B, N, N], got {valid_mask.shape}"

    B, N, _ = valid_mask.shape
    device = valid_mask.device

    # Detect interval starts via transitions: key k starts a True-run iff
    # mask[k] is True and mask[k-1] is False (or k == 0).
    shifted = torch.zeros_like(valid_mask)
    shifted[:, :, 1:] = valid_mask[:, :, :-1]
    starts = valid_mask & ~shifted  # start of each True run
    max_intervals = int(starts.sum(dim=-1).max().item())

    # One slot for F0 plus an explicit (start, end) pair per interval
    # covers every case, including a first interval that does not begin
    # at key 0 (then F0 stays 0, i.e. an empty base interval).
    # NOTE: the earlier draft computed n_func = max(2*max_intervals - 1, 1)
    # first and then overwrote it — that dead code is removed. 2*m + 1 is
    # always odd, so no parity fix-up is needed either.
    n_func = 2 * max_intervals + 1

    af = torch.zeros(B, 1, n_func, seqlen + padding, dtype=torch.int32, device=device)

    # Detect interval ends symmetrically: key k is the last True of a run
    # iff mask[k] is True and mask[k+1] is False (or k == N-1).
    ends_shifted = torch.zeros_like(valid_mask)
    ends_shifted[:, :, :-1] = valid_mask[:, :, 1:]
    ends = valid_mask & ~ends_shifted  # last True position in each run

    for b in range(B):
        for q in range(N):
            row = valid_mask[b, q]
            if not row.any():
                continue  # fully-masked query: all slots stay 0 (nothing visible)
            start_pos = starts[b, q].nonzero(as_tuple=False).squeeze(-1)
            # +1 converts "last True index" to a half-open interval end.
            end_pos = ends[b, q].nonzero(as_tuple=False).squeeze(-1) + 1

            # F0 encodes [0, F0). If the first interval starts at 0, use F0;
            # otherwise F0 stays 0 (empty base interval) and every interval
            # goes into an explicit (start, end) slot pair.
            extra_idx = 0
            for iv in range(len(start_pos)):
                s, e = start_pos[iv].item(), end_pos[iv].item()
                if iv == 0 and s == 0:
                    af[b, 0, 0, q] = e
                else:
                    af[b, 0, 2 * extra_idx + 1, q] = s
                    af[b, 0, 2 * extra_idx + 2, q] = e
                    extra_idx += 1

    return af
|
|
||||||||||||
|
|
||||||||||||
def build_jagged_causal_arbitrary_func(
    offsets: torch.Tensor,
    total_tokens: int,
    padding: int = 256,
) -> torch.Tensor:
    """
    Build arbitrary_func for flattened jagged causal attention (B=1).

    All batch sequences are concatenated into one flat sequence of length
    *total_tokens*. A query at global position *q* belonging to batch
    element *b* may attend to keys in ``[offsets[b], q+1)`` — causal
    within its own sequence and blind to every other sequence.

    Args:
        offsets: [B+1] cumulative sequence-length offsets.
        total_tokens: ``offsets[-1].item()`` — total number of tokens.
        padding: FA convention padding on the last dim (default 256).

    Returns:
        arbitrary_func: [1, 1, 3, total_tokens + padding] int32 tensor.
    """
    dev = offsets.device

    # Three function slots: F0 is left at 0 (empty base interval) and the
    # remaining pair holds the single explicit interval [F1, F2) per query.
    af = torch.zeros(
        1, 1, 3, total_tokens + padding, dtype=torch.int32, device=dev
    )

    token_idx = torch.arange(total_tokens, device=dev)
    # Map every flattened token back to its batch element: searchsorted on
    # the upper boundaries (offsets[1:]) with right=True puts a token that
    # sits exactly on a boundary into the following sequence.
    seq_of_token = torch.searchsorted(offsets[1:], token_idx, right=True)
    seq_starts = offsets[seq_of_token]

    # visible(q) = [0, 0) ∪ [seq_start, q + 1)
    af[0, 0, 1, :total_tokens] = seq_starts.to(torch.int32)
    af[0, 0, 2, :total_tokens] = (token_idx + 1).to(torch.int32)

    return af
|
|
||||||||||||
|
|
||||||||||||
| if __name__ == "__main__": | ||||||||||||
| history_seqlen = torch.tensor([4, 3]).cuda() | ||||||||||||
| max_history_seqlen = 6 | ||||||||||||
|
|
||||||||||||
Uh oh!
There was an error while loading. Please reload this page.