from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn


def maybe_contiguous(x):
    return x.contiguous() if x is not None and x.stride(-1) != 1 else x
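

# Quick illustrative sketch (not part of the library API, an assumption-free demo
# of the helper above): maybe_contiguous copies a tensor whose last-dim stride is
# not 1 (e.g. a transposed view) and returns already-contiguous tensors unchanged.
def _maybe_contiguous_demo():
    x = torch.randn(2, 3, 4).transpose(1, 2)  # stride(-1) == 4, not contiguous
    assert maybe_contiguous(x).stride(-1) == 1  # copied into contiguous storage
    y = torch.randn(2, 3, 4)  # already contiguous, stride(-1) == 1
    assert maybe_contiguous(y) is y  # returned unchanged, no copy made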


def sparse_attn_func(
    q,
    k,
    v,
    block_count,
    block_offset,
    column_count,
    column_index,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    *,
    return_softmax_lse=False,
    out=None,
):
    """Compute attention with vertical and slash sparsity patterns.

    Most arguments are the same as in the flash_attn_func interface, except for four extra ones:
    block_count and block_offset for the slash sparsity pattern, and
    column_count and column_index for the vertical sparsity pattern.
    For more details, please refer to Appendix C.4.2 of the paper https://arxiv.org/abs/2407.02490.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k: (batch_size, seqlen, nheads_k, headdim)
        v: (batch_size, seqlen, nheads_k, headdim)
        block_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
        block_offset: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_S)
        column_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
        column_index: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_V)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply a causal attention mask (e.g., for auto-regressive modeling).
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).

    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (i.e., the log of the softmax
            normalization factor).
    """
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)

    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    out, softmax_lse = torch.ops.sgl_kernel.fwd_sparse.default(
        q,
        k,
        v,
        block_count,
        block_offset,
        column_count,
        column_index,
        out,
        alibi_slopes,
        dropout_p,
        softmax_scale,
        causal,
        softcap,
        return_attn_probs and dropout_p > 0,
        None,
    )
    return (out, softmax_lse) if return_softmax_lse else out
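

# Usage sketch (illustrative only, not part of the kernel API): builds the
# simplest possible vertical/slash metadata and calls sparse_attn_func.
# Assumptions not guaranteed by this file: a CUDA device is available, the
# sgl_kernel extension providing the fwd_sparse op is installed, the kernel's
# query-block size BLOCK_M is 64, the metadata tensors are int32, and
# block_offset / column_index hold token-level indices. Real metadata comes
# from a sparsity-estimation step such as the one described in the paper
# referenced in the docstring above.
def _example_sparse_attn():
    B, S, H, D = 1, 256, 4, 128
    BLOCK_M, NNZ_S, NNZ_V = 64, 1, 4  # assumed widths, for illustration only
    n_blocks = (S + BLOCK_M - 1) // BLOCK_M  # cdiv(seqlen, BLOCK_M)

    q, k, v = [
        torch.randn(B, S, H, D, device="cuda", dtype=torch.float16) for _ in range(3)
    ]

    # Slash part: every query block attends to one KV block starting at token 0.
    block_count = torch.ones(B, H, n_blocks, device="cuda", dtype=torch.int32)
    block_offset = torch.zeros(B, H, n_blocks, NNZ_S, device="cuda", dtype=torch.int32)
    # Vertical part: disabled here (no extra sparse columns are selected).
    column_count = torch.zeros(B, H, n_blocks, device="cuda", dtype=torch.int32)
    column_index = torch.zeros(B, H, n_blocks, NNZ_V, device="cuda", dtype=torch.int32)

    out = sparse_attn_func(
        q, k, v, block_count, block_offset, column_count, column_index, causal=True
    )
    return out  # (B, S, H, D), same shape and dtype as q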


def sparse_attn_varlen_func(
    q,
    k,
    v,
    block_count,
    block_offset,
    column_count,
    column_index,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    *,
    return_softmax_lse=False,
    out=None,
):
    """Compute attention with vertical and slash sparsity patterns.

    Most arguments are the same as in the flash_attn_varlen_func interface, except for four extra ones:
    block_count and block_offset for the slash sparsity pattern, and
    column_count and column_index for the vertical sparsity pattern.
    For more details, please refer to Appendix C.4.2 of the paper https://arxiv.org/abs/2407.02490.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        block_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
        block_offset: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_S)
        column_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
        column_index: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_V)
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into kv.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply a causal attention mask (e.g., for auto-regressive modeling).
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
            testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).

    Return:
        out: (total_q, nheads, headdim), same layout as q.
        softmax_lse [optional, if return_softmax_lse=True]: (nheads, total_q_seqlen). The
            logsumexp of each row of the matrix QK^T * scaling (i.e., the log of the softmax
            normalization factor).
    """
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)

    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    out, softmax_lse = torch.ops.sgl_kernel.varlen_fwd_sparse.default(
        q,
        k,
        v,
        block_count,
        block_offset,
        column_count,
        column_index,
        out,
        cu_seqlens_q,
        cu_seqlens_k,
        None,
        alibi_slopes,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        False,
        causal,
        softcap,
        return_attn_probs and dropout_p > 0,
        None,
    )
    return (out, softmax_lse) if return_softmax_lse else out
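

# Usage sketch (illustrative only): packs two sequences of different lengths
# into the varlen layout and calls sparse_attn_varlen_func. The same caveats as
# in _example_sparse_attn apply; in addition, the per-block metadata here is
# assumed to be sized by cdiv(max_seqlen, BLOCK_M), which this file does not
# spell out.
def _example_sparse_attn_varlen():
    H, D = 4, 128
    BLOCK_M, NNZ_S, NNZ_V = 64, 1, 4  # assumed widths, for illustration only
    seqlens = [192, 256]  # two sequences packed into one varlen batch
    B, total, max_seqlen = len(seqlens), sum(seqlens), max(seqlens)
    n_blocks = (max_seqlen + BLOCK_M - 1) // BLOCK_M

    # Cumulative sequence lengths: [0, 192, 448], dtype int32 as documented.
    cu = [0]
    for s in seqlens:
        cu.append(cu[-1] + s)
    cu_seqlens = torch.tensor(cu, device="cuda", dtype=torch.int32)

    q, k, v = [
        torch.randn(total, H, D, device="cuda", dtype=torch.float16) for _ in range(3)
    ]

    # Same toy pattern as above: one slash block per query block, no verticals.
    block_count = torch.ones(B, H, n_blocks, device="cuda", dtype=torch.int32)
    block_offset = torch.zeros(B, H, n_blocks, NNZ_S, device="cuda", dtype=torch.int32)
    column_count = torch.zeros(B, H, n_blocks, device="cuda", dtype=torch.int32)
    column_index = torch.zeros(B, H, n_blocks, NNZ_V, device="cuda", dtype=torch.int32)

    out = sparse_attn_varlen_func(
        q,
        k,
        v,
        block_count,
        block_offset,
        column_count,
        column_index,
        cu_seqlens,
        cu_seqlens,
        max_seqlen,
        max_seqlen,
        causal=True,
    )
    return out  # (total_q, nheads, headdim)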