From de3ac6b3177edfe7534bbb8af504a321685d4ec0 Mon Sep 17 00:00:00 2001 From: "Zhang, Liangang" Date: Sun, 19 May 2024 23:48:06 -0700 Subject: [PATCH 01/19] Add pagedattention kernel for CPU --- torchao/csrc/cpu/paged_attention_kernel.cpp | 506 ++++++++++++++++++++ torchao/csrc/paged_attention.cpp | 11 + 2 files changed, 517 insertions(+) create mode 100644 torchao/csrc/cpu/paged_attention_kernel.cpp create mode 100644 torchao/csrc/paged_attention.cpp diff --git a/torchao/csrc/cpu/paged_attention_kernel.cpp b/torchao/csrc/cpu/paged_attention_kernel.cpp new file mode 100644 index 0000000000..d4e099ebf2 --- /dev/null +++ b/torchao/csrc/cpu/paged_attention_kernel.cpp @@ -0,0 +1,506 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#endif + +#include +#include +#include +#include +#include + +namespace torchao { + +namespace { + +template +void reduce_head( + const scalar_t* q_ptr_start, + const scalar_t* k_cache_start, + accum_t* attn_w_pos, + int64_t head_size) { + attn_w_pos[0] = 0; + for (long i = 0; i < head_size; i++) { + attn_w_pos[0] += q_ptr_start[i] * k_cache_start[i]; + } +} + +//BF16 +template <> +void reduce_head( + const at::BFloat16* q_ptr_start, + const at::BFloat16* k_cache_start, + float* attn_w_pos, + int64_t head_size) { + attn_w_pos[0] = 0; + using lpVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + auto vec_size = lpVec::size(); + auto vec_tmp_sum = fVec(0.0f); + for (long i = 0; i < vec_size * (head_size / vec_size); i += vec_size) { + auto tmpq = lpVec::loadu(q_ptr_start + i); + auto tmpk = lpVec::loadu(k_cache_start + i); + fVec tmpq1, tmpq2, tmpk1, tmpk2; + //convert to float + std::tie(tmpq1, tmpq2) = at::vec::convert_to_float(tmpq); + std::tie(tmpk1, tmpk2) = at::vec::convert_to_float(tmpk); + vec_tmp_sum = vec_tmp_sum + tmpq1 * tmpk1 + tmpq2 * tmpk2; + } + attn_w_pos[0] = at::vec::vec_reduce_all<>( + [](fVec& x, fVec& y) { + return x + y; + }, + vec_tmp_sum); +} + +template +inline void mul_attenion_weights_and_value_of_head( + const accum_t& attn_w, + const scalar_t* v_cache_start, + accum_t* attn_out_start, + int64_t head_size, + bool accumulated) { + for (auto hsi = 0; hsi < head_size; hsi++) { + if (accumulated) { + attn_out_start[hsi] += attn_w * (float)v_cache_start[hsi]; + } else { + attn_out_start[hsi] = attn_w * (float)v_cache_start[hsi]; + } + } +} + +template <> +inline void mul_attenion_weights_and_value_of_head( + const float& attn_w, + const at::BFloat16* v_cache_start, + float* attn_out_start, + int64_t head_size, + bool accumulated) { + using lpVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + auto lpVec_size = lpVec::size(); + auto fVec_size = fVec::size(); + auto vec_attn_w = fVec(attn_w); + auto vec_tmp_sum = fVec(0.0f); + long i = 0; + for (; i < lpVec_size *(head_size/lpVec_size) ; i += lpVec_size) { + auto tmpv = lpVec::loadu(v_cache_start + i); + fVec tmpv1, tmpv2; + //convert to float + std::tie(tmpv1, tmpv2) = at::vec::convert_to_float(tmpv); + auto tmp1 = tmpv1 * vec_attn_w; + auto tmp2 = tmpv2 * vec_attn_w; + if (accumulated) { + tmp1 = fVec::loadu(attn_out_start + i) + tmp1; + tmp1.store(attn_out_start + i); + tmp2 = fVec::loadu(attn_out_start + i + fVec_size) + tmp2; + tmp2.store(attn_out_start + i + fVec_size); + } else { + tmp1.store(attn_out_start + i); + tmp2.store(attn_out_start + i + fVec_size); + } + } + for (; i < head_size; i++) { + if (accumulated) { + attn_out_start[i] += attn_w * 
(float)v_cache_start[i]; + } else { + attn_out_start[i] = attn_w * (float)v_cache_start[i]; + } + } +} + +template +inline void fill_stub(scalar_t* data, scalar_t val, int64_t size) { + using Vec = at::vec::Vectorized; + Vec data_vec = Vec(val); + int64_t d = 0; + for (; d < size - (size % Vec::size()); d += Vec::size()) { + data_vec.store(data + d); + } +#if !defined(_MSC_VER) && !defined(COMPILING_FOR_MIN_SIZE) +#pragma unroll +#endif + for (; d < size; d++) { + data[d] = val; + } +} + +template +inline void _mul_div_add_softmax( + const scalar_t* a, + const scalar_t& scale, + const float* mask, + const int& size, + scalar_t* out) { + auto vec_size = at::vec::Vectorized::size(); + auto vec_scale = at::vec::Vectorized(scale); + auto tmp_max = -std::numeric_limits::infinity(); + auto vec_tmp_max = at::vec::Vectorized(tmp_max); + long i = 0; + // max(a * scale + mask) + for (; i < vec_size * (size / vec_size); i += vec_size) { + auto tmp0 = at::vec::Vectorized::loadu(a + i); + auto tmp1 = tmp0 * vec_scale; + if (mask != nullptr) { + auto tmp_mask = at::vec::Vectorized::loadu(mask + i); + tmp1 = tmp1 + tmp_mask; + } + vec_tmp_max = at::vec::maximum(vec_tmp_max, tmp1); + tmp1.store(out + i); + } + for (; i < size; i++) { + auto tmp0 = a[i]; + auto tmp1 = tmp0 * scale; + if (mask != nullptr) { + tmp1 = tmp1 + mask[i]; + } + tmp_max = std::max(tmp_max, tmp1); + out[i] = tmp1; + } + auto max = std::max( + tmp_max, + at::vec::vec_reduce_all( + [](at::vec::Vectorized& x, + at::vec::Vectorized& y) { + return at::vec::maximum(x, y); + }, + vec_tmp_max)); + // exp and sum + scalar_t sum = 0; + auto max_vec = at::vec::Vectorized(max); + i = 0; + for (; i < vec_size * (size / vec_size); i += vec_size) { + auto tmp0 = at::vec::Vectorized::loadu(out + i); + auto tmp1 = tmp0 - max_vec; + tmp1 = tmp1.exp_u20(); + sum += at::vec::vec_reduce_all( + [](at::vec::Vectorized& x, at::vec::Vectorized& y) { + return x + y; + }, + tmp1); + tmp1.store(out + i); + } + for (; i < size; i++) { + auto tmp0 = out[i]; + auto tmp1 = std::exp(tmp0 - max); + sum += tmp1; + out[i] = tmp1; + } + auto scale_vec = at::vec::Vectorized(1.0f / sum); + // normalize + i = 0; + for (; i < vec_size * (size / vec_size); i += vec_size) { + auto tmp0 = at::vec::Vectorized::loadu(out + i); + auto tmp1 = tmp0 * scale_vec; + tmp1.store(out + i); + } + for (; i < size; i++) { + out[i] = out[i] * (1.0f / sum); + } +} + +void reshape_attn_mask_to_4d( + at::Tensor& attn_mask, + int64_t batchSize, + int64_t num_head, + int64_t qSize, + int64_t kvSize) { + // Support mask shapes: + // 2d: ({Q_seq_len, 1} x {KV_seq_len, 1}) + // 4d: ({Batch, 1} x {Num_heads, 1} x {Q_seq_len, 1} x {KV_seq_len, 1}) + // Guaranteed in check_attn_mask_shape + int64_t attn_mask_size_0 = 1; + int64_t attn_mask_size_1 = 1; + if (attn_mask.dim() == 4) { + if (attn_mask.size(0) == batchSize) { + attn_mask_size_0 = batchSize; + } + if (attn_mask.size(1) == num_head) { + attn_mask_size_1 = num_head; + } + } + attn_mask = attn_mask + .view( + {attn_mask_size_0, + attn_mask_size_1, + attn_mask.size(-2), + attn_mask.size(-1)}) + .expand({attn_mask_size_0, attn_mask_size_1, qSize, kvSize}); +} + +/** + * Performs scale-dot-product for the next token based on cached key-value + * attention. + * + * This function computes the attention weights and applies the attention + * mechanism to obtain the final output. 
It takes in tensors representing the + * query, key cache, value cache, head mapping, scale, block tables, context + * lengths, block size, max context length, and optional alibi slopes. The + * output tensor is updated with the computed attention values. + * + * @param out Output tensor [num_seqs, 1, num_heads, head_size]. + * @param query Query tensor [num_seqs, 1, num_heads, head_size]. + * @param key_cache The pre-allocated buffer to store the key cache. The + * shape should be [num_blocks, block_size, num_heads, head_size]. + * @param value_cache The pre-allocated buffer to store the value cache. The + * shape should be [num_blocks, block_size, num_heads, head_size]. + * @param head_mapping Head mapping tensor [num_heads]. The mapping from the + * query head to the kv head to support GQA/MQA. The shape should be the number + * of query heads. + * @param scale Scaling factor for attention weights. In general, it is: + * float(1.0 / (head_size ** 0.5)). + * @param block_tables Block tables tensor [num_seqs, max_num_blocks_per_seq]. + * @param context_lens Context lengths tensor [num_seqs]. + * @param block_size The block size which means the number of token in every + * block. + * @param max_context_len Maximum context length. + * @param attn_mask Optional tensor of alibi slopes with the shape of + * (num_heads). + */ +template +void paged_attention_kernel( + at::Tensor& out, + at::Tensor& query, + at::Tensor& key_cache, + at::Tensor& value_cache, + at::Tensor& head_mapping, + const double scale, + at::Tensor& block_tables, + at::Tensor& context_lens, + int64_t block_size, + c10::optional attn_mask) { + using accum_t = at::opmath_type; + using Vec = at::vec::Vectorized; + const auto dtype = query.scalar_type(); + const auto accumulate_dtype = at::toOpMathType(dtype); + + auto num_seqs = query.size(0); + auto query_size = query.size(1); + auto num_heads = query.size(2); + auto head_size = query.size(3); + auto num_kv_heads = key_cache.size(2); + auto max_num_blocks_per_seq = block_tables.size(1); + auto max_context_len = context_lens.max().item(); + + bool has_attn_mask = attn_mask.has_value() && attn_mask.value().numel(); + if (has_attn_mask) { + attn_mask.value() = attn_mask.value().to(at::kFloat); + reshape_attn_mask_to_4d( + attn_mask.value(), num_seqs, num_heads, query_size, attn_mask.value().size(-1)); + } + + auto attn_weights = at::empty( + {num_seqs, num_heads, max_context_len}, + query.options().dtype(accumulate_dtype)); + + auto out_ptr = out.data_ptr(); + auto query_ptr = query.data_ptr(); + auto key_cache_ptr = key_cache.data_ptr(); + auto value_cache_ptr = value_cache.data_ptr(); + auto head_mapping_ptr = head_mapping.data_ptr(); + auto block_tables_ptr = block_tables.data_ptr(); + auto context_lens_ptr = context_lens.data_ptr(); + auto attn_mask_ptr = + attn_mask.has_value() ? attn_mask.value().data_ptr() : nullptr; + + auto attn_weights_ptr = attn_weights.data_ptr(); + auto kv_block_strideB = key_cache.stride(0); + auto q_stride = query.stride(0); + auto attn_weights_strideB = attn_weights.stride(0); + int64_t mStrideB = (has_attn_mask && attn_mask.value().size(0) > 1) + ? attn_mask.value().stride(0) + : 0; + int64_t mStrideH = (has_attn_mask && attn_mask.value().size(1) > 1) + ? attn_mask.value().stride(1) + : 0; + int64_t mStrideM = has_attn_mask ? 
attn_mask.value().stride(2) : 0; + +#pragma omp parallel for collapse(3) +for (auto token_id = 0; token_id < max_context_len; token_id++) { + for (auto seq_id = 0; seq_id < num_seqs; seq_id++) { + for (auto head_id = 0; head_id < num_heads; head_id++) { + auto context_len = context_lens_ptr[seq_id]; + if (token_id >= context_len) + continue; + auto attn_w_pos = attn_weights_ptr + seq_id * attn_weights_strideB + + head_id * max_context_len + token_id; + auto q_ptr_start = query_ptr + seq_id * q_stride + head_id * head_size; + auto block_id = block_tables_ptr + [seq_id * max_num_blocks_per_seq + token_id / block_size]; + auto block_offset = token_id % block_size; + auto k_cache_start = key_cache_ptr + block_id * kv_block_strideB + + block_offset * num_kv_heads * head_size + + head_mapping_ptr[head_id] * head_size; + reduce_head( + q_ptr_start, k_cache_start, attn_w_pos, head_size); + } + } + } + +#pragma omp parallel for collapse(2) + for (auto seq_id = 0; seq_id < num_seqs; seq_id++) { + for (auto head_id = 0; head_id < num_heads; head_id++) { + auto max_val = -10000.0f; + float sum = 0.0f; + auto qid = 0; + auto context_len = context_lens_ptr[seq_id]; + auto attn_w_start = attn_weights_ptr + seq_id * attn_weights_strideB + + head_id * max_context_len; + auto head_mask_start = has_attn_mask + ? attn_mask_ptr + mStrideB * seq_id + mStrideH * head_id + : nullptr; + _mul_div_add_softmax( + attn_w_start, scale, head_mask_start, context_len, attn_w_start); + } + } + // std::cout << "after softmax attn_weights\n"; + // std::cout << attn_weights << "\n"; + auto thread_numbers = omp_get_max_threads(); + auto private_attn_outs = at::empty( + {thread_numbers, num_seqs, num_heads, head_size}, accumulate_dtype); + auto private_attn_out_flag = + at::zeros({thread_numbers, num_seqs, num_heads}, at::kByte); + + auto flag_access = private_attn_out_flag.accessor(); + auto private_attn_out_ptr = private_attn_outs.data_ptr(); + auto private_attn_out_stride = private_attn_outs.stride(0); +// mul and accumulate +#pragma omp parallel for collapse(3) +for (auto token_id = 0; token_id < max_context_len; token_id++) { + for (auto seq_id = 0; seq_id < num_seqs; seq_id++) { + for (auto head_id = 0; head_id < num_heads; head_id++) { + auto context_len = context_lens_ptr[seq_id]; + auto thread_id = omp_get_thread_num(); + if (token_id >= context_len) + continue; + auto attn_w = attn_weights_ptr + [seq_id * attn_weights_strideB + head_id * max_context_len + + token_id]; + auto block_id = block_tables_ptr + [seq_id * max_num_blocks_per_seq + token_id / block_size]; + auto block_offset = token_id % block_size; + auto v_cache_start = value_cache_ptr + block_id * kv_block_strideB + + block_offset * num_kv_heads * head_size + + head_mapping_ptr[head_id] * head_size; + auto attn_out_start = private_attn_out_ptr + + thread_id * private_attn_out_stride + seq_id * q_stride + + head_id * head_size; + mul_attenion_weights_and_value_of_head( + attn_w, + v_cache_start, + attn_out_start, + head_size, + flag_access[thread_id][seq_id][head_id]); + if (flag_access[thread_id][seq_id][head_id] == 0) { + flag_access[thread_id][seq_id][head_id] = 1; + } + } // for token_id + } // for head_id + } // for seq_id + { + RECORD_FUNCTION( + "ipex::paged_attention::reduction_private_result", + c10::ArrayRef({})); +#pragma omp parallel for collapse(2) + for (auto seq_id = 0; seq_id < num_seqs; seq_id++) { + for (auto hi = 0; hi < num_heads; hi++) { + auto thr0_head_start = + private_attn_out_ptr + (seq_id * num_heads + hi) * head_size; + if 
(flag_access[0][seq_id][hi] == 0) { + fill_stub(thr0_head_start, 0.0f, head_size); + } + for (auto thread_id = 1; thread_id < thread_numbers; thread_id++) { + if (flag_access[thread_id][seq_id][hi] == 0) { + continue; + } + auto attn_out_head_stride = thread_id * private_attn_out_stride + + (seq_id * num_heads + hi) * head_size; + auto private_attn_out_start = + private_attn_out_ptr + attn_out_head_stride; + at::vec::map2( + [](Vec a, Vec b) { return a + b; }, + thr0_head_start, + private_attn_out_start, + thr0_head_start, + head_size); + } + auto out_start = out_ptr + (seq_id * num_heads + hi) * head_size; + at::vec::map( + [](Vec a) { return a; }, out_start, thr0_head_start, head_size); + } + } + } + +} // paged_attention_kernel + +void paged_attention_kernel_impl( + at::Tensor& out, // [num_seqs, 1, num_heads, head_size] + at::Tensor& query, // [num_seqs, 1, num_heads, head_size] + at::Tensor& key_cache, // [num_blocks, block_size, num_heads, head_size] + at::Tensor& value_cache, // [num_blocks, block_size, num_heads, head_size] + at::Tensor& head_mapping, // [num_heads] + const double scale, + at::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] + at::Tensor& context_lens, // [num_seqs] + int64_t block_size, + c10::optional attn_mask) { + TORCH_CHECK( + query.size(1) == 1, + "Paged attention: only seqlen 1 is supported for query"); + TORCH_CHECK( + query.scalar_type() == key_cache.scalar_type() && + query.scalar_type() == value_cache.scalar_type(), + "Paged attention: Q/K/V should have the same data type"); + TORCH_CHECK( + !attn_mask.has_value() || + query.scalar_type() == attn_mask.value().scalar_type() || + attn_mask.value().scalar_type() != at::ScalarType::Bool, + "Paged attention: Mask should have the same data type as Q/K/V and should not be Bool"); + TORCH_CHECK( + query.dim() == 4 && key_cache.dim() == 4 && value_cache.dim() == 4, + "Paged attention: Accept only 4 dims inputs shape of {B, H, T, K}"); + TORCH_CHECK( + (query.stride(-1) == 1) && (key_cache.stride(-1) == 1) && + (value_cache.stride(-1) == 1) && + (!attn_mask.has_value() || attn_mask.value().stride(-1) == 1), + "Paged attention: Q/KV cache/Mask should be continuous on the last dim"); + + AT_DISPATCH_FLOATING_TYPES_AND2( + at::kBFloat16, + at::kHalf, + query.scalar_type(), + "paged_attention", + [&] { + paged_attention_kernel( + out, + query, + key_cache, + value_cache, + head_mapping, + scale, + block_tables, + context_lens, + block_size, + attn_mask); + }); +} + + +} // namespace +TORCH_LIBRARY_IMPL(torchao, CUDA, m) { + m.impl("torchao::paged_attention", &paged_attention_kernel_impl); +} + +} // namespace torchao \ No newline at end of file diff --git a/torchao/csrc/paged_attention.cpp b/torchao/csrc/paged_attention.cpp new file mode 100644 index 0000000000..987c3a55fa --- /dev/null +++ b/torchao/csrc/paged_attention.cpp @@ -0,0 +1,11 @@ +#include +#include +#include + +TORCH_LIBRARY_FRAGMENT(torchao, m) { + m.impl_abstract_pystub("torchao.ops"); + m.def( + "paged_attention(Tensor (a!)out, Tensor (a!)query, Tensor (a!)key_cache, Tensor (a!)value_cache,\ + Tensor(a!) head_mapping, float scale, Tensor(a!) block_tables, Tensor(a!) context_lens, int block_size,\ + Tensor? 
attn_mask)-> ()"); +} \ No newline at end of file From 2b843e637d2e0a3c43e5385bc3c2581c1caeb3e6 Mon Sep 17 00:00:00 2001 From: "Zhang, Liangang" Date: Mon, 20 May 2024 00:05:22 -0700 Subject: [PATCH 02/19] Add ut --- test/kernel/test_paged_attention.py | 221 ++++++++++++++++++++++++++++ 1 file changed, 221 insertions(+) create mode 100644 test/kernel/test_paged_attention.py diff --git a/test/kernel/test_paged_attention.py b/test/kernel/test_paged_attention.py new file mode 100644 index 0000000000..ac5a6d239e --- /dev/null +++ b/test/kernel/test_paged_attention.py @@ -0,0 +1,221 @@ +import torch +import unittest +import random +from typing import List, Optional, Tuple +from itertools import product +import torchao + + +class PagedAttentionTest(unittest.TestCase): + def create_kv_caches( + self, + num_blocks: int, + block_size: int, + num_layer: int, + num_head: int, + head_size: int, + dtype: torch.dtype, + seed: int, + ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + torch.random.manual_seed(seed) + torch.manual_seed(seed) + + scale = head_size**-0.5 + key_cache_shape = (num_blocks, block_size, num_head, head_size) + key_caches = [] + for _ in range(num_layer): + key_cache = torch.empty(size=key_cache_shape, dtype=dtype) + key_cache.uniform_(-scale, scale) + key_caches.append(key_cache) + + value_cache_shape = (num_blocks, block_size, num_head, head_size) + value_caches = [] + for _ in range(num_layer): + value_cache = torch.empty(size=value_cache_shape, dtype=dtype) + value_cache.uniform_(-scale, scale) + value_caches.append(value_cache) + return key_caches, value_caches + + def ref_masked_attention( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + scale: float, + attn_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + attn_weights = torch.einsum("qhd,khd->hqk", query, key).float() + attn_weights = attn_weights * scale + if attn_mask is not None: + attn_weights = attn_weights + attn_mask.float() + attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) + out = torch.einsum("hqk,khd->qhd", attn_weights, value) + return out + + def ref_single_query_cached_kv_attention( + self, + output: torch.Tensor, + query: torch.Tensor, + num_queries_per_kv: int, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + scale: float, + attn_mask: Optional[torch.Tensor], + ) -> None: + num_query_heads = query.shape[1] + num_kv_head = value_cache.shape[2] + head_size = value_cache.shape[3] + block_size = value_cache.shape[1] + num_seqs = query.shape[0] + + block_tables = block_tables.cpu().tolist() + context_lens = context_lens.cpu().tolist() + for i in range(num_seqs): + q = query[i].unsqueeze(0) + block_table = block_tables[i] + context_len = int(context_lens[i]) + + keys = [] + values = [] + for j in range(context_len): + block_number = int(block_table[j // block_size]) + block_offset = j % block_size + + k = key_cache[block_number, block_offset, :, :] + k = k.reshape(num_kv_head, head_size) + keys.append(k) + + v = value_cache[block_number, block_offset, :, :] + values.append(v) + keys = torch.stack(keys, dim=0) + values = torch.stack(values, dim=0) + if num_queries_per_kv > 1: + # Handle MQA and GQA + keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1) + values = torch.repeat_interleave(values, num_queries_per_kv, dim=1) + # out = self.ref_masked_attention(q, keys, values, scale, attn_mask[i]) + out = self.ref_masked_attention(q, keys, values, scale, None) + out = 
out.view(num_query_heads, head_size) + output[i].copy_(out, non_blocking=True) + + def _test_paged_attention_func( + self, + num_seqs: int, + num_head: Tuple[int, int], + head_size: int, + num_blocks: int, + block_size: int, + dtype: torch.dtype, + seed: int, + ) -> None: + random.seed(seed) + torch.random.manual_seed(seed) + torch.manual_seed(seed) + max_seq_len = 512 + scale = float(1.0 / (head_size**0.5)) + num_query_heads, num_kv_head = num_head + query = torch.empty( + num_seqs, num_query_heads, head_size, dtype=dtype, device="cpu" + ) + query.uniform_(-scale, scale) + assert num_query_heads % num_kv_head == 0 + num_queries_per_kv = num_query_heads // num_kv_head + head_mapping = torch.repeat_interleave( + torch.arange(num_kv_head, dtype=torch.int32, device="cpu"), + num_queries_per_kv, + ) + attn_mask = None + context_lens = [random.randint(1, max_seq_len) for _ in range(num_seqs)] + context_lens[-1] = max_seq_len + max_context_len = max_seq_len # max(context_lens) + attn_mask = torch.zeros(num_seqs, 1, 1, max_context_len, dtype=dtype) + for i in range(num_seqs): + attn_mask[i, :, :, context_lens[i] :].fill_(-10000.0) + paded_context_lens = torch.tensor( + [max_context_len for _ in range(num_seqs)] + ).to(torch.int32) + context_lens = torch.tensor(context_lens, dtype=torch.int, device="cpu") + + # Create the block tables.NUM_PREFILL_SEQS + max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size + block_tables = [] + for _ in range(num_seqs): + block_table = [ + random.randint(0, num_blocks - 1) for _ in range(max_num_blocks_per_seq) + ] + block_tables.append(block_table) + block_tables = torch.tensor(block_tables, dtype=torch.int, device="cpu") + + # Create the KV caches. + key_caches, value_caches = self.create_kv_caches( + num_blocks, block_size, 1, num_kv_head, head_size, dtype, seed + ) + key_cache, value_cache = key_caches[0], value_caches[0] + + output = torch.empty_like(query) + torch.ops.torchao.paged_attention( + output, + query.unsqueeze(1), + key_cache, + value_cache, + head_mapping, + scale, + block_tables, + paded_context_lens, + block_size, + attn_mask, + ) + + # Run the reference implementation. 
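+        # The reference below gathers K/V token by token through the block table
+        # (block_number = block_table[j // block_size], block_offset = j % block_size),
+        # expands the KV heads with repeat_interleave for GQA/MQA, and applies a plain
+        # softmax over the gathered context, so it checks the same block indexing and
+        # head mapping that the fused torchao.paged_attention call above relies on.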
+ ref_output = torch.empty_like(query) + self.ref_single_query_cached_kv_attention( + ref_output, + query, + num_queries_per_kv, + key_cache, + value_cache, + block_tables, + context_lens, + scale, + attn_mask, + ) + assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-3) + + def test_paged_attention(self): + num_blocks = 128 + dtypes = [torch.bfloat16, torch.float, torch.float16] + num_gen_seqs = [2] # Arbitrary values for testing + num_heads = [(40, 40), (64, 16)] # Arbitrary values for testing + head_sizes = [64, 256] + block_sizes = [16, 32] + seeds = [0] + for ( + num_seqs, + num_head, + head_size, + block_size, + dtype, + seed, + ) in product( + num_gen_seqs, + num_heads, + head_sizes, + block_sizes, + dtypes, + seeds, + ): + self._test_paged_attention_func( + num_seqs, + num_head, + head_size, + num_blocks, + block_size, + dtype, + seed, + ) + + +if __name__ == "__main__": + test = unittest.main() \ No newline at end of file From 63a826a60be6dc9a7f3275e4085c620585891b7b Mon Sep 17 00:00:00 2001 From: "Zhang, Liangang" Date: Mon, 20 May 2024 00:44:48 -0700 Subject: [PATCH 03/19] Add UT --- test/kernel/test_fused_kernels.py | 111 ---------------- test/test_ops.py | 213 ++++++++++++++++++++++++++++++ 2 files changed, 213 insertions(+), 111 deletions(-) delete mode 100644 test/kernel/test_fused_kernels.py diff --git a/test/kernel/test_fused_kernels.py b/test/kernel/test_fused_kernels.py deleted file mode 100644 index b43abead45..0000000000 --- a/test/kernel/test_fused_kernels.py +++ /dev/null @@ -1,111 +0,0 @@ -import itertools - -import pytest - -# Skip entire test if triton is not available, otherwise CI failure -try: - import triton -except ImportError: - pytest.skip("triton is not installed", allow_module_level=True) - -import torch -from galore_test_utils import get_kernel, make_copy, make_data - -torch.manual_seed(0) -MAX_DIFF_no_tf32 = 1e-5 -MAX_DIFF_tf32 = 1e-3 - - -def run_test(kernel, exp_avg, exp_avg2, grad, proj_matrix, params, allow_tf32): - # Copy to use for first run -- needed because of autotuning and inplace ops - ( - exp_avg_autotune_copy, - exp_avg2_autotune_copy, - grad_autotune_copy, - proj_matrix_autotune_copy, - params_autotune_copy, - ) = make_copy(exp_avg, exp_avg2, grad, proj_matrix, params) - - # Copy to use for second run to check accuracy - ( - exp_avg_test_copy, - exp_avg2_test_copy, - grad_test_copy, - proj_matrix_test_copy, - params_test_copy, - ) = make_copy(exp_avg, exp_avg2, grad, proj_matrix, params) - - print( - f"Running with {grad.shape[0]} x {grad.shape[1]} grad (param) shape, GaLore orthogonal matrix {list(proj_matrix.shape)}, dtype {grad.dtype} and allow_tf32 {allow_tf32}\n" - f"Kernel: {kernel}", - flush=True, - ) - - ref_op = get_kernel("ref") - test_op = get_kernel(kernel) - - # Reference run - ref_out = ref_op( - grad, - proj_matrix, - exp_avg, - exp_avg2, - params, - ) - - # Autotune - _ = test_op( - grad_autotune_copy, - proj_matrix_autotune_copy, - exp_avg_autotune_copy, - exp_avg2_autotune_copy, - params_autotune_copy, - store=False, - allow_tf32=allow_tf32, - ) - - # Accuracy run - test_out = test_op( - grad_test_copy, - proj_matrix_test_copy, - exp_avg_test_copy, - exp_avg2_test_copy, - params_test_copy, - store=True, - allow_tf32=allow_tf32, - ) - print("Accuracy:") - - output_names = [ - "adam state - running grad mean", - "adam state - running grad var", - "params (after update)", - ] - MAX_DIFF = MAX_DIFF_tf32 if allow_tf32 else MAX_DIFF_no_tf32 - for name, ref, tt in zip(output_names, ref_out, test_out): - max_diff = (ref 
- tt).abs().max() - print(f"-> {name}:\n Max err: {max_diff:.6f}") - assert max_diff < MAX_DIFF - - -KERNELS = ["hybrid"] # "fused"] -DTYPES = [torch.float32] # torch.float16 -ROW_DIMS = [4096] -COL_DIMS = [4096] # , 11008] -RANKS = [128] -ALLOW_TF32 = [False] # , True] - -TEST_CONFIGS = list( - itertools.product(KERNELS, DTYPES, ROW_DIMS, COL_DIMS, RANKS, ALLOW_TF32) -) - -# TEST_CONFIGS = TEST_CONFIGS[0:1] - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU") -@pytest.mark.parametrize("kernel, dtype, M, N, rank, allow_tf32", TEST_CONFIGS) -def test_galore_fused_kernels(kernel, dtype, M, N, rank, allow_tf32): - torch.backends.cuda.matmul.allow_tf32 = allow_tf32 - - exp_avg, exp_avg2, grad, proj_matrix, params = make_data(M, N, rank, dtype) - run_test(kernel, exp_avg, exp_avg2, grad, proj_matrix, params, allow_tf32) diff --git a/test/test_ops.py b/test/test_ops.py index e260e86f0f..5fb28e9b30 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1,10 +1,14 @@ import torch from torch.testing._internal.common_utils import TestCase, IS_FBCODE from torch.testing._internal.optests import opcheck +from typing import List, Optional, Tuple import torchao from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4 import unittest +import random from parameterized import parameterized +from itertools import product + # torch.testing._internal.optests.generate_tests.OpCheckError: opcheck(op, ...): @@ -134,7 +138,216 @@ def test_fp6_matmul_correctness(self, BS, OC, IC, splitK): error = (results_fp6 - results_fp16).abs() relative_error = error / results_fp16.abs() assert relative_error.mean() < 1e-2 + + def _create_kv_caches( + self, + num_blocks: int, + block_size: int, + num_layer: int, + num_head: int, + head_size: int, + dtype: torch.dtype, + seed: int, + ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + torch.random.manual_seed(seed) + torch.manual_seed(seed) + + scale = head_size**-0.5 + key_cache_shape = (num_blocks, block_size, num_head, head_size) + key_caches = [] + for _ in range(num_layer): + key_cache = torch.empty(size=key_cache_shape, dtype=dtype) + key_cache.uniform_(-scale, scale) + key_caches.append(key_cache) + + value_cache_shape = (num_blocks, block_size, num_head, head_size) + value_caches = [] + for _ in range(num_layer): + value_cache = torch.empty(size=value_cache_shape, dtype=dtype) + value_cache.uniform_(-scale, scale) + value_caches.append(value_cache) + return key_caches, value_caches + + def _ref_masked_attention( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + scale: float, + attn_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + attn_weights = torch.einsum("qhd,khd->hqk", query, key).float() + attn_weights = attn_weights * scale + if attn_mask is not None: + attn_weights = attn_weights + attn_mask.float() + attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) + out = torch.einsum("hqk,khd->qhd", attn_weights, value) + return out + + def _ref_paged_attention( + self, + output: torch.Tensor, + query: torch.Tensor, + num_queries_per_kv: int, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + scale: float, + ) -> None: + num_query_heads = query.shape[1] + num_kv_head = value_cache.shape[2] + head_size = value_cache.shape[3] + block_size = value_cache.shape[1] + num_seqs = query.shape[0] + + block_tables = block_tables.cpu().tolist() + context_lens = context_lens.cpu().tolist() + for i in range(num_seqs): + q = 
query[i].unsqueeze(0) + block_table = block_tables[i] + context_len = int(context_lens[i]) + + keys = [] + values = [] + for j in range(context_len): + block_number = int(block_table[j // block_size]) + block_offset = j % block_size + + k = key_cache[block_number, block_offset, :, :] + k = k.reshape(num_kv_head, head_size) + keys.append(k) + + v = value_cache[block_number, block_offset, :, :] + values.append(v) + keys = torch.stack(keys, dim=0) + values = torch.stack(values, dim=0) + if num_queries_per_kv > 1: + # Handle MQA and GQA + keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1) + values = torch.repeat_interleave(values, num_queries_per_kv, dim=1) + # out = self._ref_masked_attention(q, keys, values, scale, attn_mask[i]) + out = self._ref_masked_attention(q, keys, values, scale, None) + out = out.view(num_query_heads, head_size) + output[i].copy_(out, non_blocking=True) + + def _test_paged_attention_func( + self, + num_seqs: int, + num_head: Tuple[int, int], + head_size: int, + num_blocks: int, + block_size: int, + dtype: torch.dtype, + seed: int, + ) -> None: + random.seed(seed) + torch.random.manual_seed(seed) + torch.manual_seed(seed) + max_seq_len = 512 + scale = float(1.0 / (head_size**0.5)) + num_query_heads, num_kv_head = num_head + query = torch.empty( + num_seqs, num_query_heads, head_size, dtype=dtype, device="cpu" + ) + query.uniform_(-scale, scale) + assert num_query_heads % num_kv_head == 0 + num_queries_per_kv = num_query_heads // num_kv_head + head_mapping = torch.repeat_interleave( + torch.arange(num_kv_head, dtype=torch.int32, device="cpu"), + num_queries_per_kv, + ) + attn_mask = None + context_lens = [random.randint(1, max_seq_len) for _ in range(num_seqs)] + context_lens[-1] = max_seq_len + max_context_len = max_seq_len # max(context_lens) + attn_mask = torch.zeros(num_seqs, 1, 1, max_context_len, dtype=dtype) + for i in range(num_seqs): + attn_mask[i, :, :, context_lens[i] :].fill_(-10000.0) + paded_context_lens = torch.tensor( + [max_context_len for _ in range(num_seqs)] + ).to(torch.int32) + context_lens = torch.tensor(context_lens, dtype=torch.int, device="cpu") + + # Create the block tables.NUM_PREFILL_SEQS + max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size + block_tables = [] + for _ in range(num_seqs): + block_table = [ + random.randint(0, num_blocks - 1) for _ in range(max_num_blocks_per_seq) + ] + block_tables.append(block_table) + block_tables = torch.tensor(block_tables, dtype=torch.int, device="cpu") + + # Create the KV caches. + key_caches, value_caches = self._create_kv_caches( + num_blocks, block_size, 1, num_kv_head, head_size, dtype, seed + ) + key_cache, value_cache = key_caches[0], value_caches[0] + + output = torch.empty_like(query) + torch.ops.torchao.paged_attention( + output, + query.unsqueeze(1), + key_cache, + value_cache, + head_mapping, + scale, + block_tables, + paded_context_lens, + block_size, + attn_mask, + ) + + # Run the reference implementation. 
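+        # Note: the kernel above is given paded_context_lens (all equal to
+        # max_context_len) plus an additive -10000.0 mask on the padded positions,
+        # while the reference below truncates each sequence at its true context
+        # length, so both attend to the same tokens up to numerical tolerance.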
+ ref_output = torch.empty_like(query) + self._ref_paged_attention( + ref_output, + query, + num_queries_per_kv, + key_cache, + value_cache, + block_tables, + context_lens, + scale, + ) + assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-3) + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + def test_paged_attention(self): + num_blocks = 128 + dtypes = [torch.bfloat16, torch.float, torch.float16] + num_gen_seqs = [2] # Arbitrary values for testing + num_heads = [(40, 40), (64, 16)] # Arbitrary values for testing + head_sizes = [64, 256] + block_sizes = [16, 32] + seeds = [0] + for ( + num_seqs, + num_head, + head_size, + block_size, + dtype, + seed, + ) in product( + num_gen_seqs, + num_heads, + head_sizes, + block_sizes, + dtypes, + seeds, + ): + self._test_paged_attention_func( + num_seqs, + num_head, + head_size, + num_blocks, + block_size, + dtype, + seed, + ) + if __name__ == "__main__": unittest.main() From ab1785967ed1afd0730c9f68eadd03e390830468 Mon Sep 17 00:00:00 2001 From: "Zhang, Liangang" Date: Mon, 20 May 2024 00:56:10 -0700 Subject: [PATCH 04/19] Refine code --- torchao/csrc/cpu/paged_attention_kernel.cpp | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/torchao/csrc/cpu/paged_attention_kernel.cpp b/torchao/csrc/cpu/paged_attention_kernel.cpp index d4e099ebf2..94454e0d2f 100644 --- a/torchao/csrc/cpu/paged_attention_kernel.cpp +++ b/torchao/csrc/cpu/paged_attention_kernel.cpp @@ -365,8 +365,6 @@ for (auto token_id = 0; token_id < max_context_len; token_id++) { attn_w_start, scale, head_mask_start, context_len, attn_w_start); } } - // std::cout << "after softmax attn_weights\n"; - // std::cout << attn_weights << "\n"; auto thread_numbers = omp_get_max_threads(); auto private_attn_outs = at::empty( {thread_numbers, num_seqs, num_heads, head_size}, accumulate_dtype); @@ -499,7 +497,7 @@ void paged_attention_kernel_impl( } // namespace -TORCH_LIBRARY_IMPL(torchao, CUDA, m) { +TORCH_LIBRARY_IMPL(torchao, CPU, m) { m.impl("torchao::paged_attention", &paged_attention_kernel_impl); } From 784f5032ac047e1e25b3d6f9e75c87d58386c95c Mon Sep 17 00:00:00 2001 From: "Zhang, Liangang" Date: Mon, 20 May 2024 01:13:29 -0700 Subject: [PATCH 05/19] Add PagedAttention KV Cache manager --- torchao/kv_cache.py | 406 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 406 insertions(+) create mode 100644 torchao/kv_cache.py diff --git a/torchao/kv_cache.py b/torchao/kv_cache.py new file mode 100644 index 0000000000..69254f2770 --- /dev/null +++ b/torchao/kv_cache.py @@ -0,0 +1,406 @@ +import torch +import torch.nn as nn +import functools +from typing import List, Tuple, Optional, Dict, Any +import copy + +HANDLED_FUNCTIONS = {} + + +class PagedTensor(object): + def __init__( + self, + cache: torch.Tensor, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + ): + self.block_tables = block_tables + self.cache = cache + self.context_lens = context_lens + + def __repr__(self): + return f"PagedTensor({self.cache.shape})" + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + if func not in HANDLED_FUNCTIONS or not all( + issubclass(t, (torch.Tensor, PagedTensor)) for t in types + ): + return NotImplementedError( + "{} is not supported by PagedTensor".format(func) + ) + return HANDLED_FUNCTIONS[func](*args, **kwargs) + + +def implements(torch_function): + """Register a torch function override for PagedTensor""" + + def decorator(func): + functools.update_wrapper(func, 
torch_function) + HANDLED_FUNCTIONS[torch_function] = func + return func + + return decorator + + +@implements(torch.nn.functional.scaled_dot_product_attention) +def scaled_dot_product_attention( + input, key_tensor, value_tensor, attn_mask=None, scale=None +): + query = input + key_cache = key_tensor.cache + value_cache = value_tensor.cache + num_kv_head = key_cache.size(2) + num_queries_per_kv = query.size(2) // num_kv_head + block_size = key_cache.size(1) + block_tables = key_tensor.block_tables + head_mapping = torch.repeat_interleave( + torch.arange(num_kv_head, dtype=torch.int32, device="cpu"), num_queries_per_kv + ) + context_lens = key_tensor.context_lens + output = torch.empty_like(query) + torch.ops.torchao.paged_attention( + output, + query, + key_cache, + value_cache, + head_mapping, + scale, + block_tables, + context_lens, + block_size, + attn_mask, + ) + return output + + +class PagedAttentionCache(object): + def __init__( + self, + num_blocks: int, + block_size: int, + num_key_value_heads: int, + head_dim: int, + num_layers: int, + device="cpu", + dtype=None, + ) -> None: + super().__init__() + + # model info + self.head_dim = head_dim + self.num_key_value_heads = num_key_value_heads + self.num_layers = num_layers + + # Cache tensor info + self.dtype = dtype if dtype is not None else torch.float32 + self.device = device + self.num_blocks = num_blocks + self.block_size = block_size + + cache_shape = ( + self.num_blocks, + self.block_size, + self.num_key_value_heads, + self.head_dim, + ) + + # KV caches for each layer + self.key_caches = [ + torch.zeros(cache_shape, dtype=self.dtype, device=device) + for _ in range(num_layers) + ] + self.value_caches = [ + torch.zeros(cache_shape, dtype=self.dtype, device=device) + for _ in range(num_layers) + ] + + self.seen_tokens = ( + 0 # Used in `generate` to keep tally of how many tokens the cache has seen + ) + + # paged cache runtime information + self.free_blocks = list(range(num_blocks)) # free blocks + self.block_ref_count = [ + 0 + ] * self.num_blocks # init the reference count for each physical block + self.block_tables = dict() # mapping logical block to physical blocks for each sequence + + # The follow two states are shared accross layer but only for the current decode step. Need to update for every decode step. + self.batch2seq = None # mapping batch index to {seq_id0, seq_id1, ...} to enable prompt sharing. + self.slots_mapping = None # mapping logical slots to physical slots. + + def _copy_on_write(self, src_block_idx: int, dst_block_idx: int): + """ + Copy the content of src_block_idx to dst_block_idx. + + Args: + src_block_idx (int): The index of the source block. + dst_block_idx (int): The index of the destination block. + """ + for layer_idx in range(self.num_layers): + self.key_caches[layer_idx][dst_block_idx] = self.key_caches[layer_idx][ + src_block_idx + ].clone() + self.value_caches[layer_idx][dst_block_idx] = self.value_caches[layer_idx][ + src_block_idx + ].clone() + + def allocate(self, batch_size: int, key_len: int) -> None: + """ + Allocate physical slots for a every sequence with key_len tokens in this batcch. + + Args: + - batch_size (int): The batch size of the sequence. + - key_len (int): The length of the key. 
+ + Returns: + - None + """ + self.slots_mapping = [] + past_context_len = self.seen_tokens + if self.batch2seq is None: + self.set_batch2seq_for_prompt_sharing(batch_size, 1) + for i in range(batch_size): + seq_idx = self.batch2seq[i][0] + # Scenario 1: New seqence: allocate blocks for this sequence + if seq_idx not in self.block_tables: + needed_blocks = (key_len + self.block_size - 1) // self.block_size + if needed_blocks > len(self.free_blocks): + raise AssertionError( + f"No space in KV cache to store new token state. needed_blocks: {needed_blocks}, free_blocks: {self.free_blocks}" + ) + blocks = self.free_blocks[:needed_blocks] + self.free_blocks = self.free_blocks[needed_blocks:] + self.block_tables[seq_idx] = blocks + for block_idx in blocks: + self.block_ref_count[block_idx] += 1 + # Senario 2: Partial processed sequence: find free slots in the allocated blocks or allocate new blocks + else: + seq_len = key_len + past_context_len + target_blocks = (seq_len + self.block_size - 1) // self.block_size + new_blocks = target_blocks - len(self.block_tables[seq_idx]) + + if new_blocks > len(self.free_blocks): + raise AssertionError( + f"PagedAttentionCache: No enough free blocks to allocate for sequence {seq_idx}." + ) + + if new_blocks > 0: # allocate new blocks + candidate_blocks = self.free_blocks[:new_blocks] + self.block_tables[seq_idx].extend(self.free_blocks[:new_blocks]) + self.free_blocks = self.free_blocks[new_blocks:] + for block_idx in candidate_blocks: + self.block_ref_count[block_idx] += 1 + else: + last_block = self.block_tables[seq_idx][-1] + # sharing the last block with other sequences, need to allocate a new block and copy the last block + if self.block_ref_count[last_block] > 1: + if len(self.free_blocks) == 0: + raise AssertionError( + f"PagedAttentionCache: No enough free blocks to allocate for sequence {seq_idx}." + ) + new_block = self.free_blocks.pop() + self.block_tables[seq_idx][-1] = new_block + self.block_ref_count[new_block] += 1 + self.block_ref_count[last_block] -= 1 + self._copy_on_write(last_block, new_block) + + slots = [] + # the slots for this sequence + for j in range(key_len): + token_id = j + past_context_len + block_idx = token_id // self.block_size + block_offset = token_id % self.block_size + slots.append( + self.block_tables[seq_idx][block_idx] * self.block_size + + block_offset + ) + self.slots_mapping.append(slots) + self.slots_mapping = torch.tensor( + self.slots_mapping, dtype=torch.long, device=self.device + ) + # step 2): fork new sequences to enable prompt sharing + for batch_idx in range(batch_size): + seq_ids = self.batch2seq[batch_idx] + # fork the blocks allocated for the first sequence to other sequences in the batch + for seq_id in seq_ids[1:]: + self._fork(seq_ids[0], seq_id) + + def _free(self, seq_idx: int): + """ + Frees the blocks allocated for the given sequence index. + + Args: + seq_idx (int): The index of the sequence whose blocks are to be freed. + + Raises: + AssertionError: If the given sequence index is not present in the block tables. + """ + + if seq_idx not in self.block_tables: + raise AssertionError( + f"PagedAttentionCache: Sequence index {seq_idx} is not present in the block tables." + ) + + for block_idx in self.block_tables[seq_idx]: + self.block_ref_count[block_idx] -= 1 + if self.block_ref_count[block_idx] == 0: + self.free_blocks.append(block_idx) + + def _fork(self, seq_idx: int, new_seq_idx: int): + """ + Forks the blocks allocated for seq_idx to new_seq_idx. 
+ + Args: + seq_idx (int): The index of the sequence to be forked. + new_seq_idx (int): The index of the new sequence. + + Raises: + AssertionError: If seq_idx is not in block_tables or if new_seq_idx is already in block_tables. + """ + if seq_idx not in self.block_tables: + raise AssertionError( + f"PagedAttentionCache: Sequence index {seq_idx} is not present in the block tables." + ) + + self.block_tables[new_seq_idx] = copy.deepcopy(self.block_tables[seq_idx]) + for block_idx in self.block_tables[seq_idx]: + self.block_ref_count[block_idx] += 1 + + def set_batch2seq_for_prompt_sharing(self, batch_size: int, beam_size: int): + """ + Set the batch2seq mapping for prompt sharing. + + Args: + batch_size (int): The batch size. + beam_size (int): The beam size. + """ + self.batch2seq = {} + for i in range(batch_size): + self.batch2seq[i] = [i * beam_size + j for j in range(beam_size)] + + def _reshape_and_cache( + self, + slot_mapping: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + ): + """ + Reshapes and caches the key and value states based on the given slot mapping. + + Args: + slot_mapping (List[List[int]]): A list of lists representing the slot mapping. + key_states (torch.Tensor): The key states tensor. + value_states (torch.Tensor): The value states tensor. + layer_idx (int): The index of the layer. + + Returns: + None + """ + if key_states.shape[-2] != self.num_key_value_heads: + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + key_cache = self.key_caches[layer_idx].view( + -1, self.num_key_value_heads, self.head_dim + ) + value_cache = self.value_caches[layer_idx].view( + -1, self.num_key_value_heads, self.head_dim + ) + key_cache[slot_mapping] = key_states + value_cache[slot_mapping] = value_states + + def get_seq_length(self, layer_idx: int = 0) -> int: + return self.seen_tokens + + def get_max_length(self) -> Optional[int]: + """Returns the maximum sequence length of the cached states. PagedAttentionCache does not have a maximum length.""" + RuntimeError("PagedAttentionCache does not have a maximum sequence length.") + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Update the cache with key and value states for a specific layer. + + Args: + key_states (torch.Tensor): The new key states tensor of shape [batch, head, seq, dim]. + value_states (torch.Tensor): The new value states tensor of shape [batch, head, seq, dim]. + layer_idx (int): The index of the layer. + cache_kwargs (Dict[str, Any]): Additional arguments for the cache subclass. + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing the updated key states and value states tensors(entire context token states). + + Raises: + AssertionError: If the batch size is inconsistent with the existing cache. 
+ """ + batch_size = key_states.shape[0] # [batch, head, seq, dim] + cur_len = key_states.shape[-2] + + # # slots info for key/value are same for every layer and allocate should be called before model.forward() to reduce the allocation overhead + # AssertionError( + # self.slots_mapping is not None, + # "PagedAttentionCache: Please first call allocate() of this object to get target positions in paged cache before the model.forward().", + # ) + #cache key_states & value_states + self._reshape_and_cache(self.slots_mapping, key_states, value_states, layer_idx) + + if layer_idx == self.num_layers - 1: + self.seen_tokens += cur_len + self.slot_mapping = None + + if ( + self.seen_tokens == 0 + or self.seen_tokens == cur_len + and layer_idx == self.num_layers - 1 + ): # first token + return key_states, value_states + else: # Next token + if layer_idx == self.num_layers - 1: + # last layer already updated self.seen_tokens + context_lens = torch.tensor( + [self.seen_tokens for _ in range(batch_size)], + dtype=torch.int32, + ) + else: + context_lens = torch.tensor( + [self.seen_tokens + cur_len for _ in range(batch_size)], + dtype=torch.int32, + ) + block_tables_t = [] + for seq_idx in range(batch_size): + block_tables_t.append(self.block_tables[seq_idx]) + block_tables_t = torch.tensor( + block_tables_t, dtype=torch.int32, device=self.device + ) + return PagedTensor( + self.key_caches[layer_idx], block_tables_t, context_lens + ), PagedTensor(self.value_caches[layer_idx], block_tables_t, context_lens) + + def reorder_cache(self, beam_idx: torch.Tensor) -> None: + """ + Reorder the cache according to the beam index. The beam index is a tensor of shape (batch_size,) + and the sequence id can be get from the self.batch2seq. + """ + freed_seqs = [] + new_block_tables = self.block_tables.copy() + num_beams = beam_idx.numel() // len(self.batch2seq) + for batch_idx, target_batch_idx in enumerate(beam_idx.tolist()): + target_seq_id = self.batch2seq[target_batch_idx // num_beams][0] + seq_id = self.batch2seq[batch_idx // num_beams][0] + freed_seqs.append(seq_id) + new_block_tables[seq_id] = [] + for block in self.block_tables[target_seq_id]: + self.block_ref_count[block] += 1 + new_block_tables[seq_id].append(block) + for seq_idx in freed_seqs: + self._free(seq_idx) + self.block_tables = new_block_tables + self.batch2seq = None \ No newline at end of file From 57c4faafc32f5222fd6460bea3314f214b47fbcd Mon Sep 17 00:00:00 2001 From: "Zhang, Liangang" Date: Tue, 4 Jun 2024 21:27:14 -0700 Subject: [PATCH 06/19] Enable flash decodeing for paged attention. 
--- test/kernel/test_paged_attention.py | 576 +++++++++++++------- torchao/csrc/cpu/paged_attention_kernel.cpp | 483 +++++++++------- 2 files changed, 685 insertions(+), 374 deletions(-) diff --git a/test/kernel/test_paged_attention.py b/test/kernel/test_paged_attention.py index ac5a6d239e..6e68ad5516 100644 --- a/test/kernel/test_paged_attention.py +++ b/test/kernel/test_paged_attention.py @@ -4,218 +4,418 @@ from typing import List, Optional, Tuple from itertools import product import torchao +from torchao.kv_cache import PagedAttentionCache, PagedTensor +class NiaveCache: + def __init__(self): + self.past_key = None + self.past_value = None -class PagedAttentionTest(unittest.TestCase): - def create_kv_caches( - self, - num_blocks: int, - block_size: int, - num_layer: int, - num_head: int, - head_size: int, - dtype: torch.dtype, - seed: int, - ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: - torch.random.manual_seed(seed) - torch.manual_seed(seed) - - scale = head_size**-0.5 - key_cache_shape = (num_blocks, block_size, num_head, head_size) - key_caches = [] - for _ in range(num_layer): - key_cache = torch.empty(size=key_cache_shape, dtype=dtype) - key_cache.uniform_(-scale, scale) - key_caches.append(key_cache) - - value_cache_shape = (num_blocks, block_size, num_head, head_size) - value_caches = [] - for _ in range(num_layer): - value_cache = torch.empty(size=value_cache_shape, dtype=dtype) - value_cache.uniform_(-scale, scale) - value_caches.append(value_cache) - return key_caches, value_caches - - def ref_masked_attention( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - scale: float, - attn_mask: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - attn_weights = torch.einsum("qhd,khd->hqk", query, key).float() - attn_weights = attn_weights * scale - if attn_mask is not None: - attn_weights = attn_weights + attn_mask.float() - attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) - out = torch.einsum("hqk,khd->qhd", attn_weights, value) - return out - - def ref_single_query_cached_kv_attention( - self, - output: torch.Tensor, - query: torch.Tensor, - num_queries_per_kv: int, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - block_tables: torch.Tensor, - context_lens: torch.Tensor, - scale: float, - attn_mask: Optional[torch.Tensor], - ) -> None: - num_query_heads = query.shape[1] - num_kv_head = value_cache.shape[2] - head_size = value_cache.shape[3] - block_size = value_cache.shape[1] - num_seqs = query.shape[0] - - block_tables = block_tables.cpu().tolist() - context_lens = context_lens.cpu().tolist() - for i in range(num_seqs): - q = query[i].unsqueeze(0) - block_table = block_tables[i] - context_len = int(context_lens[i]) - - keys = [] - values = [] - for j in range(context_len): - block_number = int(block_table[j // block_size]) - block_offset = j % block_size - - k = key_cache[block_number, block_offset, :, :] - k = k.reshape(num_kv_head, head_size) - keys.append(k) - - v = value_cache[block_number, block_offset, :, :] - values.append(v) - keys = torch.stack(keys, dim=0) - values = torch.stack(values, dim=0) - if num_queries_per_kv > 1: - # Handle MQA and GQA - keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1) - values = torch.repeat_interleave(values, num_queries_per_kv, dim=1) - # out = self.ref_masked_attention(q, keys, values, scale, attn_mask[i]) - out = self.ref_masked_attention(q, keys, values, scale, None) - out = out.view(num_query_heads, head_size) - output[i].copy_(out, non_blocking=True) - - def 
_test_paged_attention_func( - self, - num_seqs: int, - num_head: Tuple[int, int], - head_size: int, - num_blocks: int, - block_size: int, - dtype: torch.dtype, - seed: int, - ) -> None: - random.seed(seed) - torch.random.manual_seed(seed) - torch.manual_seed(seed) - max_seq_len = 512 - scale = float(1.0 / (head_size**0.5)) - num_query_heads, num_kv_head = num_head - query = torch.empty( - num_seqs, num_query_heads, head_size, dtype=dtype, device="cpu" + def expand_cache(self, beam_size): + self.past_key = self.past_key.repeat_interleave(beam_size, dim=0) + self.past_value = self.past_value.repeat_interleave(beam_size, dim=0) + + def update(self, key, value, layer_idx=0): + if self.past_key is None: + self.past_key = key + self.past_value = value + else: + self.past_key = torch.cat((self.past_key, key), dim=2) + self.past_value = torch.cat((self.past_value, value), dim=2) + return self.past_key, self.past_value + + def reorder_cache(self, beam_idx): + self.past_key = self.past_key.index_select(0, beam_idx) + self.past_value = self.past_value.index_select(0, beam_idx) + + +class MHAModule(torch.nn.Module): + def __init__(self, head_dim, num_heads, num_kv_heads): + super(MHAModule, self).__init__() + self.head_dim = head_dim + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.scale = head_dim**-0.5 + self.q = torch.nn.Linear( + self.num_heads * self.head_dim, self.num_heads * self.head_dim ) - query.uniform_(-scale, scale) - assert num_query_heads % num_kv_head == 0 - num_queries_per_kv = num_query_heads // num_kv_head - head_mapping = torch.repeat_interleave( - torch.arange(num_kv_head, dtype=torch.int32, device="cpu"), - num_queries_per_kv, + self.k = torch.nn.Linear( + self.num_heads * self.head_dim, self.num_kv_heads * self.head_dim ) - attn_mask = None - context_lens = [random.randint(1, max_seq_len) for _ in range(num_seqs)] - context_lens[-1] = max_seq_len - max_context_len = max_seq_len # max(context_lens) - attn_mask = torch.zeros(num_seqs, 1, 1, max_context_len, dtype=dtype) - for i in range(num_seqs): - attn_mask[i, :, :, context_lens[i] :].fill_(-10000.0) - paded_context_lens = torch.tensor( - [max_context_len for _ in range(num_seqs)] - ).to(torch.int32) - context_lens = torch.tensor(context_lens, dtype=torch.int, device="cpu") - - # Create the block tables.NUM_PREFILL_SEQS - max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size - block_tables = [] - for _ in range(num_seqs): - block_table = [ - random.randint(0, num_blocks - 1) for _ in range(max_num_blocks_per_seq) - ] - block_tables.append(block_table) - block_tables = torch.tensor(block_tables, dtype=torch.int, device="cpu") - - # Create the KV caches. 
- key_caches, value_caches = self.create_kv_caches( - num_blocks, block_size, 1, num_kv_head, head_size, dtype, seed + self.v = torch.nn.Linear( + self.num_heads * self.head_dim, self.num_kv_heads * self.head_dim + ) + + def forward(self, inputs, kv_cache): + query = self.q(inputs) + key = self.k(inputs) + value = self.v(inputs) + batch_size = inputs.size(0) + query = query.view(batch_size, -1, self.num_heads, self.head_dim).transpose( + 1, 2 ) - key_cache, value_cache = key_caches[0], value_caches[0] - - output = torch.empty_like(query) - torch.ops.torchao.paged_attention( - output, - query.unsqueeze(1), - key_cache, - value_cache, - head_mapping, - scale, - block_tables, - paded_context_lens, + key = key.view(batch_size, -1, self.num_kv_heads, self.head_dim).transpose(1, 2) + value = value.view(batch_size, -1, self.num_kv_heads, self.head_dim).transpose( + 1, 2 + ) + updated_key, updated_value = kv_cache.update(key, value, 0) + if isinstance(updated_key, torch.Tensor): + updated_key = updated_key.repeat_interleave( + self.num_heads // self.num_kv_heads, dim=1 + ) + updated_value = updated_value.repeat_interleave( + self.num_heads // self.num_kv_heads, dim=1 + ) + output = torch.nn.functional.scaled_dot_product_attention( + query, updated_key, updated_value, scale=self.scale + ) + return output + + +class PagedAttentionCachePagedTensorTest(unittest.TestCase): + def _test_paged_attention_cache( + self, + num_blocks, + block_size, + num_query_heads, + num_key_value_heads, + head_dim, + device, + dtype, + batch_size, + beam_size, + ): + num_layers = 1 + prompt_len = 32 + mha_model = MHAModule(head_dim, num_query_heads, num_key_value_heads).to( + device=device, dtype=dtype + ) + naive_cache = NiaveCache() + pagedcache = PagedAttentionCache( + num_blocks, block_size, - attn_mask, + num_key_value_heads, + head_dim, + num_layers, + device, + dtype, ) + # enable prompt sharing for the first token, fork + pagedcache.set_batch2seq_for_prompt_sharing(batch_size, beam_size) + pagedcache.allocate(batch_size, prompt_len) + prompt_inputs = torch.randn( + batch_size, + prompt_len, + num_query_heads * head_dim, + device=device, + dtype=dtype, + ) + paged_output = mha_model(prompt_inputs, pagedcache) + naive_output = mha_model(prompt_inputs, naive_cache) + torch.allclose(paged_output, naive_output) + + beam_idx = torch.arange( + 0, batch_size * beam_size, beam_size, device=device, dtype=torch.int64 + ).repeat_interleave(beam_size) + naive_cache.expand_cache(beam_size) + naive_cache.reorder_cache(beam_idx) + pagedcache.reorder_cache(beam_idx) - # Run the reference implementation. 
- ref_output = torch.empty_like(query) - self.ref_single_query_cached_kv_attention( - ref_output, - query, - num_queries_per_kv, - key_cache, - value_cache, - block_tables, - context_lens, - scale, - attn_mask, + # Next token + pagedcache.allocate(batch_size * beam_size, 1) + next_inputs = torch.randn( + batch_size * beam_size, + 1, + num_query_heads * head_dim, + device=device, + dtype=dtype, ) - assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-3) - def test_paged_attention(self): + paged_output = mha_model(next_inputs, pagedcache) + naive_output = mha_model(next_inputs, naive_cache) + torch.allclose(paged_output, naive_output, atol=1e-3, rtol=1e-3) + + for i in range(batch_size): + beam_idx[i * beam_size : (i + 1) * beam_size] = torch.randint( + i * beam_size, + (i + 1) * beam_size, + (1, beam_size), + device=device, + dtype=torch.int64, + ) + naive_cache.reorder_cache(beam_idx) + pagedcache.reorder_cache(beam_idx) + + # Next token + pagedcache.allocate(batch_size * beam_size, 1) + prompt_inputs = torch.randn( + batch_size * beam_size, + 1, + num_query_heads * head_dim, + device=device, + dtype=dtype, + ) + paged_output = mha_model(prompt_inputs, pagedcache) + naive_output = mha_model(prompt_inputs, naive_cache) + torch.allclose(paged_output, naive_output, atol=1e-3, rtol=1e-3) + + def test_paged_attention_kv_cache(self): + # num_blocks, block_size, num_query_heads, num_key_value_heads, head_dim, device, dtype, batch_size, beam_size num_blocks = 128 - dtypes = [torch.bfloat16, torch.float, torch.float16] - num_gen_seqs = [2] # Arbitrary values for testing - num_heads = [(40, 40), (64, 16)] # Arbitrary values for testing - head_sizes = [64, 256] block_sizes = [16, 32] - seeds = [0] + num_query_heads = [40] + num_key_value_heads = [40, 10, 1] + head_dim = [64, 128] + device = ["cpu"] + dtypes = [torch.float, torch.float16, torch.bfloat16] + batch_size = [1, 8] + beam_size = [1, 4] for ( - num_seqs, - num_head, - head_size, block_size, + num_query_head, + num_key_value_head, + head_dim, + device, dtype, - seed, + batch_size, + beam_size, ) in product( - num_gen_seqs, - num_heads, - head_sizes, block_sizes, + num_query_heads, + num_key_value_heads, + head_dim, + device, dtypes, - seeds, + batch_size, + beam_size, ): - self._test_paged_attention_func( - num_seqs, - num_head, - head_size, + self._test_paged_attention_cache( num_blocks, block_size, + num_query_head, + num_key_value_head, + head_dim, + device, dtype, - seed, + batch_size, + beam_size, ) +# class PagedAttentionTest(unittest.TestCase): +# def create_kv_caches( +# self, +# num_blocks: int, +# block_size: int, +# num_layer: int, +# num_head: int, +# head_size: int, +# dtype: torch.dtype, +# seed: int, +# ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: +# torch.random.manual_seed(seed) +# torch.manual_seed(seed) + +# scale = head_size**-0.5 +# key_cache_shape = (num_blocks, num_head, block_size, head_size) +# key_caches = [] +# for _ in range(num_layer): +# key_cache = torch.empty(size=key_cache_shape, dtype=dtype) +# key_cache.uniform_(-scale, scale) +# key_caches.append(key_cache) + +# value_cache_shape = (num_blocks, num_head, block_size, head_size) +# value_caches = [] +# for _ in range(num_layer): +# value_cache = torch.empty(size=value_cache_shape, dtype=dtype) +# value_cache.uniform_(-scale, scale) +# value_caches.append(value_cache) +# return key_caches, value_caches + +# def ref_masked_attention( +# self, +# query: torch.Tensor, +# key: torch.Tensor, +# value: torch.Tensor, +# scale: float, +# attn_mask: 
Optional[torch.Tensor] = None, +# ) -> torch.Tensor: +# attn_weights = torch.einsum("qhd,khd->hqk", query, key).float() +# attn_weights = attn_weights * scale +# if attn_mask is not None: +# attn_weights = attn_weights + attn_mask.float() +# attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) +# out = torch.einsum("hqk,khd->qhd", attn_weights, value) +# return out + +# def ref_single_query_cached_kv_attention( +# self, +# output: torch.Tensor, +# query: torch.Tensor, +# num_queries_per_kv: int, +# key_cache: torch.Tensor, +# value_cache: torch.Tensor, +# block_tables: torch.Tensor, +# context_lens: torch.Tensor, +# scale: float, +# attn_mask: Optional[torch.Tensor], +# ) -> None: +# num_query_heads = query.shape[1] +# num_kv_head = value_cache.shape[1] +# head_size = value_cache.shape[3] +# block_size = value_cache.shape[2] +# num_seqs = query.shape[0] + +# block_tables = block_tables.cpu().tolist() +# context_lens = context_lens.cpu().tolist() +# for i in range(num_seqs): +# q = query[i].unsqueeze(0) +# block_table = block_tables[i] +# context_len = int(context_lens[i]) + +# keys = [] +# values = [] +# for j in range(context_len): +# key = torch.empty( +# num_kv_head, head_size, dtype=query.dtype, device="cpu" +# ) +# value = torch.empty( +# num_kv_head, head_size, dtype=query.dtype, device="cpu" +# ) +# for k in range(num_kv_head): +# block_number = int(block_table[j // block_size]) +# block_offset = j % block_size +# key[k, :] = key_cache[block_number, k, block_offset, :] +# value[k, :] = value_cache[block_number, k, block_offset, :] +# keys.append(key) +# values.append(value) +# keys = torch.stack(keys, dim=0) +# values = torch.stack(values, dim=0) +# if num_queries_per_kv > 1: +# # Handle MQA and GQA +# keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1) +# values = torch.repeat_interleave(values, num_queries_per_kv, dim=1) +# # out = self.ref_masked_attention(q, keys, values, scale, attn_mask[i]) +# out = self.ref_masked_attention(q, keys, values, scale, None) +# out = out.view(num_query_heads, head_size) +# output[i].copy_(out, non_blocking=True) + +# def _test_paged_attention_func( +# self, +# num_seqs: int, +# num_head: Tuple[int, int], +# head_size: int, +# num_blocks: int, +# block_size: int, +# dtype: torch.dtype, +# seed: int, +# ) -> None: +# random.seed(seed) +# torch.random.manual_seed(seed) +# torch.manual_seed(seed) +# max_seq_len = 512 +# scale = float(1.0 / (head_size**0.5)) +# num_query_heads, num_kv_head = num_head +# query = torch.empty( +# num_seqs, num_query_heads, head_size, dtype=dtype, device="cpu" +# ) +# query.uniform_(-scale, scale) +# assert num_query_heads % num_kv_head == 0 +# num_queries_per_kv = num_query_heads // num_kv_head +# head_mapping = torch.repeat_interleave( +# torch.arange(num_kv_head, dtype=torch.int32, device="cpu"), +# num_queries_per_kv, +# ) +# attn_mask = None +# context_lens = [random.randint(1, max_seq_len) for _ in range(num_seqs)] +# context_lens[-1] = max_seq_len +# max_context_len = max_seq_len # max(context_lens) +# attn_mask = torch.zeros(num_seqs, 1, 1, max_context_len, dtype=dtype) +# for i in range(num_seqs): +# attn_mask[i, :, :, context_lens[i] :].fill_(-10000.0) +# paded_context_lens = torch.tensor( +# [max_context_len for _ in range(num_seqs)] +# ).to(torch.int32) +# context_lens = torch.tensor(context_lens, dtype=torch.int, device="cpu") + +# # Create the block tables.NUM_PREFILL_SEQS +# max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size +# block_tables = [] +# for _ in 
range(num_seqs): +# block_table = [ +# random.randint(0, num_blocks - 1) for _ in range(max_num_blocks_per_seq) +# ] +# block_tables.append(block_table) +# block_tables = torch.tensor(block_tables, dtype=torch.int, device="cpu") + +# # Create the KV caches. +# key_caches, value_caches = self.create_kv_caches( +# num_blocks, block_size, 1, num_kv_head, head_size, dtype, seed +# ) +# key_cache, value_cache = key_caches[0], value_caches[0] +# output = torch.empty_like(query.unsqueeze(2)) +# torch.ops.torchao.paged_attention( +# output, +# query.unsqueeze(2), +# key_cache, +# value_cache, +# head_mapping, +# scale, +# block_tables, +# paded_context_lens, +# block_size, +# attn_mask, +# ) +# output = output.squeeze(2) +# #Run the reference implementation. +# ref_output = torch.empty_like(query) +# self.ref_single_query_cached_kv_attention( +# ref_output, +# query, +# num_queries_per_kv, +# key_cache, +# value_cache, +# block_tables, +# context_lens, +# scale, +# attn_mask, +# ) +# assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-3) + +# def test_paged_attention(self): +# num_blocks = 128 +# dtypes = [torch.bfloat16, torch.float, torch.float16] +# num_gen_seqs = [2] # Arbitrary values for testing +# num_heads = [(40, 40), (64, 16)] # Arbitrary values for testing +# head_sizes = [64, 256] +# block_sizes = [16, 32] +# seeds = [0] +# for ( +# num_seqs, +# num_head, +# head_size, +# block_size, +# dtype, +# seed, +# ) in product( +# num_gen_seqs, +# num_heads, +# head_sizes, +# block_sizes, +# dtypes, +# seeds, +# ): +# pass +# self._test_paged_attention_func( +# num_seqs, +# num_head, +# head_size, +# num_blocks, +# block_size, +# dtype, +# seed, +# ) + + if __name__ == "__main__": test = unittest.main() \ No newline at end of file diff --git a/torchao/csrc/cpu/paged_attention_kernel.cpp b/torchao/csrc/cpu/paged_attention_kernel.cpp index 94454e0d2f..5aa7686b66 100644 --- a/torchao/csrc/cpu/paged_attention_kernel.cpp +++ b/torchao/csrc/cpu/paged_attention_kernel.cpp @@ -21,6 +21,8 @@ #include #include +#define PARTITION_SIZE 256 + namespace torchao { namespace { @@ -121,94 +123,97 @@ inline void mul_attenion_weights_and_value_of_head( } } -template -inline void fill_stub(scalar_t* data, scalar_t val, int64_t size) { - using Vec = at::vec::Vectorized; - Vec data_vec = Vec(val); - int64_t d = 0; - for (; d < size - (size % Vec::size()); d += Vec::size()) { - data_vec.store(data + d); +// out = val * a + b +template +inline void _scale_attn_mask_fusion_kernel( + T1* a, + float* b, + const int& size, + T2* out, + float val) { + const auto vec_size = at::vec::Vectorized::size(); + const auto vec_scale = at::vec::Vectorized(val); + int64_t i = 0; + for (; i < size - (size % vec_size); i += vec_size) { + auto a_v = at::vec::Vectorized::loadu(a + i); + auto b_v = at::vec::Vectorized::loadu(b + i); + auto res = a_v * vec_scale + b_v; + res.store(out + i); } -#if !defined(_MSC_VER) && !defined(COMPILING_FOR_MIN_SIZE) -#pragma unroll -#endif - for (; d < size; d++) { - data[d] = val; + for (; i < size; i++) { + auto tmp0 = a[i]; + auto tmp1 = b[i]; + out[i] = tmp0 * val + tmp1; + } +} + +// 1) out = exp(a - val) +// 2) val = sum(out) +template +inline void _exp_reduce_sum_fusion_kernel( + T1* a, + const int& size, + T2* out, + T1& val) { + auto vec_size = at::vec::Vectorized::size(); + auto vec_max = at::vec::Vectorized(val); + T1 tmp_sum = 0; + auto vec_tmp_sum = at::vec::Vectorized(tmp_sum); + long i = 0; + for (; i < vec_size * (size / vec_size); i += vec_size) { + auto tmp0 = 
at::vec::Vectorized::loadu(a + i); + auto tmp1 = tmp0 - vec_max; + auto tmp2 = tmp1.exp_u20(); + vec_tmp_sum += tmp2; + tmp2.store(out + i); } + tmp_sum = at::vec::vec_reduce_all( + [](at::vec::Vectorized& x, at::vec::Vectorized& y) { + return x + y; + }, + vec_tmp_sum); + for (; i < size; i++) { + auto tmp0 = a[i]; + auto tmp1 = tmp0 - val; + auto tmp2 = exp(tmp1); + tmp_sum += tmp2; + out[i] = tmp2; + } + val = tmp_sum; } +// 1) out = a * scale +// 2) max = max(out) template -inline void _mul_div_add_softmax( - const scalar_t* a, +inline void _mul_reduce_max_fusion_kernel( + scalar_t* a, const scalar_t& scale, - const float* mask, const int& size, - scalar_t* out) { + scalar_t* out, + scalar_t& max) { auto vec_size = at::vec::Vectorized::size(); auto vec_scale = at::vec::Vectorized(scale); - auto tmp_max = -std::numeric_limits::infinity(); + scalar_t tmp_max = -std::numeric_limits::infinity(); auto vec_tmp_max = at::vec::Vectorized(tmp_max); long i = 0; - // max(a * scale + mask) for (; i < vec_size * (size / vec_size); i += vec_size) { auto tmp0 = at::vec::Vectorized::loadu(a + i); auto tmp1 = tmp0 * vec_scale; - if (mask != nullptr) { - auto tmp_mask = at::vec::Vectorized::loadu(mask + i); - tmp1 = tmp1 + tmp_mask; - } vec_tmp_max = at::vec::maximum(vec_tmp_max, tmp1); tmp1.store(out + i); } + tmp_max = at::vec::vec_reduce_all( + [](at::vec::Vectorized& x, at::vec::Vectorized& y) { + return at::vec::maximum(x, y); + }, + vec_tmp_max); for (; i < size; i++) { auto tmp0 = a[i]; auto tmp1 = tmp0 * scale; - if (mask != nullptr) { - tmp1 = tmp1 + mask[i]; - } tmp_max = std::max(tmp_max, tmp1); out[i] = tmp1; } - auto max = std::max( - tmp_max, - at::vec::vec_reduce_all( - [](at::vec::Vectorized& x, - at::vec::Vectorized& y) { - return at::vec::maximum(x, y); - }, - vec_tmp_max)); - // exp and sum - scalar_t sum = 0; - auto max_vec = at::vec::Vectorized(max); - i = 0; - for (; i < vec_size * (size / vec_size); i += vec_size) { - auto tmp0 = at::vec::Vectorized::loadu(out + i); - auto tmp1 = tmp0 - max_vec; - tmp1 = tmp1.exp_u20(); - sum += at::vec::vec_reduce_all( - [](at::vec::Vectorized& x, at::vec::Vectorized& y) { - return x + y; - }, - tmp1); - tmp1.store(out + i); - } - for (; i < size; i++) { - auto tmp0 = out[i]; - auto tmp1 = std::exp(tmp0 - max); - sum += tmp1; - out[i] = tmp1; - } - auto scale_vec = at::vec::Vectorized(1.0f / sum); - // normalize - i = 0; - for (; i < vec_size * (size / vec_size); i += vec_size) { - auto tmp0 = at::vec::Vectorized::loadu(out + i); - auto tmp1 = tmp0 * scale_vec; - tmp1.store(out + i); - } - for (; i < size; i++) { - out[i] = out[i] * (1.0f / sum); - } + max = tmp_max; } void reshape_attn_mask_to_4d( @@ -247,7 +252,7 @@ void reshape_attn_mask_to_4d( * This function computes the attention weights and applies the attention * mechanism to obtain the final output. It takes in tensors representing the * query, key cache, value cache, head mapping, scale, block tables, context - * lengths, block size, max context length, and optional alibi slopes. The + * lengths, block size The * output tensor is updated with the computed attention values. * * @param out Output tensor [num_seqs, 1, num_heads, head_size]. @@ -266,8 +271,7 @@ void reshape_attn_mask_to_4d( * @param block_size The block size which means the number of token in every * block. * @param max_context_len Maximum context length. - * @param attn_mask Optional tensor of alibi slopes with the shape of - * (num_heads). 
+ * @param attn_mask Optional tensor of attention_mask */ template void paged_attention_kernel( @@ -281,44 +285,52 @@ void paged_attention_kernel( at::Tensor& context_lens, int64_t block_size, c10::optional attn_mask) { + using accum_t = at::opmath_type; using Vec = at::vec::Vectorized; const auto dtype = query.scalar_type(); const auto accumulate_dtype = at::toOpMathType(dtype); - + auto max_context_len = context_lens.max().item(); auto num_seqs = query.size(0); - auto query_size = query.size(1); - auto num_heads = query.size(2); + auto q_len = query.size(2); + auto num_heads = query.size(1); auto head_size = query.size(3); - auto num_kv_heads = key_cache.size(2); + auto num_kv_heads = key_cache.size(1); auto max_num_blocks_per_seq = block_tables.size(1); - auto max_context_len = context_lens.max().item(); - + auto kv_head_group_size = num_heads / num_kv_heads; bool has_attn_mask = attn_mask.has_value() && attn_mask.value().numel(); if (has_attn_mask) { attn_mask.value() = attn_mask.value().to(at::kFloat); reshape_attn_mask_to_4d( - attn_mask.value(), num_seqs, num_heads, query_size, attn_mask.value().size(-1)); + attn_mask.value(), + num_seqs, + num_heads, + q_len, + attn_mask.value().size(-1)); } - auto attn_weights = at::empty( - {num_seqs, num_heads, max_context_len}, - query.options().dtype(accumulate_dtype)); - auto out_ptr = out.data_ptr(); auto query_ptr = query.data_ptr(); auto key_cache_ptr = key_cache.data_ptr(); auto value_cache_ptr = value_cache.data_ptr(); - auto head_mapping_ptr = head_mapping.data_ptr(); auto block_tables_ptr = block_tables.data_ptr(); auto context_lens_ptr = context_lens.data_ptr(); + + auto kv_block_strideN = key_cache.stride(0); + auto kv_block_strideP = key_cache.stride(2); + auto kv_block_strideH = key_cache.stride(1); + + auto out_strideN = out.stride(0); + auto out_strideH = out.stride(1); + auto out_strideS = out.stride(2); + + auto q_strideN = query.stride(0); + auto q_strideH = query.stride(1); + auto q_strideS = query.stride(2); + auto attn_mask_ptr = attn_mask.has_value() ? attn_mask.value().data_ptr() : nullptr; - auto attn_weights_ptr = attn_weights.data_ptr(); - auto kv_block_strideB = key_cache.stride(0); - auto q_stride = query.stride(0); - auto attn_weights_strideB = attn_weights.stride(0); int64_t mStrideB = (has_attn_mask && attn_mask.value().size(0) > 1) ? attn_mask.value().stride(0) : 0; @@ -327,120 +339,219 @@ void paged_attention_kernel( : 0; int64_t mStrideM = has_attn_mask ? 
attn_mask.value().stride(2) : 0; -#pragma omp parallel for collapse(3) -for (auto token_id = 0; token_id < max_context_len; token_id++) { - for (auto seq_id = 0; seq_id < num_seqs; seq_id++) { - for (auto head_id = 0; head_id < num_heads; head_id++) { - auto context_len = context_lens_ptr[seq_id]; - if (token_id >= context_len) - continue; - auto attn_w_pos = attn_weights_ptr + seq_id * attn_weights_strideB + - head_id * max_context_len + token_id; - auto q_ptr_start = query_ptr + seq_id * q_stride + head_id * head_size; - auto block_id = block_tables_ptr - [seq_id * max_num_blocks_per_seq + token_id / block_size]; - auto block_offset = token_id % block_size; - auto k_cache_start = key_cache_ptr + block_id * kv_block_strideB + - block_offset * num_kv_heads * head_size + - head_mapping_ptr[head_id] * head_size; - reduce_head( - q_ptr_start, k_cache_start, attn_w_pos, head_size); + auto max_num_partitions = + (max_context_len + PARTITION_SIZE - 1) / PARTITION_SIZE; + + auto max_logits = at::empty( + {num_seqs, num_heads, max_num_partitions + 1}, + query.options().dtype(accumulate_dtype)); + + auto exp_sum = at::empty( + {num_seqs, num_heads, max_num_partitions + 1}, + query.options().dtype(accumulate_dtype)); + + auto tmp_out = at::empty( + {num_seqs, num_heads, max_num_partitions, head_size}, + query.options().dtype(accumulate_dtype)); + + auto tmp_out_ptr = tmp_out.data_ptr(); + auto max_logits_ptr = max_logits.data_ptr(); + auto exp_sum_ptr = exp_sum.data_ptr(); + + auto max_logits_strideN = max_logits.stride(0); + auto max_logits_strideH = max_logits.stride(1); + auto exp_sum_strideN = exp_sum.stride(0); + auto exp_sum_strideH = exp_sum.stride(1); + auto tmp_out_strideN = tmp_out.stride(0); + auto tmp_out_strideH = tmp_out.stride(1); + auto tmp_out_strideS = tmp_out.stride(2); + #pragma omp parallel for collapse(3) schedule(static, 1) + for (auto partition_id = 0; partition_id < max_num_partitions; + partition_id++) { + for (auto head_id = 0; head_id < num_heads; head_id++) { + for (auto seq_id = 0; seq_id < num_seqs; seq_id++) { + auto context_len = context_lens_ptr[seq_id]; + auto partition_start = partition_id * PARTITION_SIZE; + if (partition_start >= context_len) + continue; + auto partition_end = + std::min(partition_start + PARTITION_SIZE, context_len); + auto token_num = partition_end - partition_start; + auto block_num = (token_num + block_size - 1) / block_size; + auto logical_block_start = partition_start / block_size; + auto logical_block_end = logical_block_start + block_num; + auto need_update = block_num > 1; + auto kv_head_id = head_id / kv_head_group_size; + auto q_ptr_start = + query_ptr + seq_id * q_strideN + head_id * q_strideH; + auto max_logits_offset = seq_id * max_logits_strideN + + head_id * max_logits_strideH + partition_id; + auto exp_sum_offset = seq_id * exp_sum_strideN + + head_id * exp_sum_strideH + partition_id; + auto tmp_out_start = tmp_out_ptr + seq_id * tmp_out_strideN + + head_id * tmp_out_strideH + partition_id * tmp_out_strideS; + accum_t logits[PARTITION_SIZE] __attribute__((aligned(64))) = {0}; + auto logits_position = 0; + // 1)calculate the matmul(query, key) for this partition + for (auto logical_block_id = logical_block_start; + logical_block_id < logical_block_end; + logical_block_id++) { + auto physical_block_id = block_tables_ptr + [seq_id * max_num_blocks_per_seq + logical_block_id]; + auto tokens_in_block = std::min( + block_size, context_len - logical_block_id * block_size); + auto token_start = logical_block_id * block_size; + auto 
token_end = token_start + tokens_in_block; + for (auto token_id = token_start; token_id < token_end; + token_id++) { + auto block_offset = token_id - token_start; + auto k_cache_start = key_cache_ptr + + physical_block_id * kv_block_strideN + + block_offset * kv_block_strideP + + kv_head_id * kv_block_strideH; + reduce_head( + q_ptr_start, + k_cache_start, + &(logits[logits_position]), + head_size); + logits_position++; + } + } + // 2) calculate the max and exp_sum for this partition + auto partition_max = -std::numeric_limits::infinity(); + if (has_attn_mask) { + _scale_attn_mask_fusion_kernel( + logits, + attn_mask_ptr + seq_id * mStrideB + head_id * mStrideH + + partition_start, + token_num, + logits, + scale); + partition_max = at::vec::reduce_all( + [](Vec& x, Vec& y) { return at::vec::maximum(x, y); }, + logits, + token_num); + } else { + _mul_reduce_max_fusion_kernel( + logits, scale, token_num, logits, partition_max); + } + max_logits_ptr[max_logits_offset] = partition_max; + _exp_reduce_sum_fusion_kernel( + logits, token_num, logits, partition_max); + exp_sum_ptr[exp_sum_offset] = partition_max; + + // 3) calculate the matmul(exp(logits-partition_max), value) for this + // partition, need to divide the global exp_sum in the final result. + logits_position = 0; + for (auto logical_block_id = logical_block_start; + logical_block_id < logical_block_end; + logical_block_id++) { + auto physical_block_id = block_tables_ptr + [seq_id * max_num_blocks_per_seq + logical_block_id]; + auto tokens_in_block = std::min( + block_size, context_len - logical_block_id * block_size); + auto token_start = logical_block_id * block_size; + auto token_end = token_start + tokens_in_block; + for (auto token_id = token_start; token_id < token_end; + token_id++) { + auto block_offset = token_id - token_start; + auto v_cache_start = value_cache_ptr + + physical_block_id * kv_block_strideN + + block_offset * kv_block_strideP + + kv_head_id * kv_block_strideH; + auto accumulated = logits_position > 0; + mul_attenion_weights_and_value_of_head( + logits[logits_position], + v_cache_start, + tmp_out_start, + head_size, + accumulated); + logits_position++; + } + } + } } - } } +// calculate the final output #pragma omp parallel for collapse(2) for (auto seq_id = 0; seq_id < num_seqs; seq_id++) { for (auto head_id = 0; head_id < num_heads; head_id++) { - auto max_val = -10000.0f; - float sum = 0.0f; - auto qid = 0; + auto global_max = -std::numeric_limits::infinity(); + auto global_exp_sum = 0.0; auto context_len = context_lens_ptr[seq_id]; - auto attn_w_start = attn_weights_ptr + seq_id * attn_weights_strideB + - head_id * max_context_len; - auto head_mask_start = has_attn_mask - ? 
attn_mask_ptr + mStrideB * seq_id + mStrideH * head_id - : nullptr; - _mul_div_add_softmax( - attn_w_start, scale, head_mask_start, context_len, attn_w_start); - } - } - auto thread_numbers = omp_get_max_threads(); - auto private_attn_outs = at::empty( - {thread_numbers, num_seqs, num_heads, head_size}, accumulate_dtype); - auto private_attn_out_flag = - at::zeros({thread_numbers, num_seqs, num_heads}, at::kByte); - - auto flag_access = private_attn_out_flag.accessor(); - auto private_attn_out_ptr = private_attn_outs.data_ptr(); - auto private_attn_out_stride = private_attn_outs.stride(0); -// mul and accumulate -#pragma omp parallel for collapse(3) -for (auto token_id = 0; token_id < max_context_len; token_id++) { - for (auto seq_id = 0; seq_id < num_seqs; seq_id++) { - for (auto head_id = 0; head_id < num_heads; head_id++) { - auto context_len = context_lens_ptr[seq_id]; - auto thread_id = omp_get_thread_num(); - if (token_id >= context_len) - continue; - auto attn_w = attn_weights_ptr - [seq_id * attn_weights_strideB + head_id * max_context_len + - token_id]; - auto block_id = block_tables_ptr - [seq_id * max_num_blocks_per_seq + token_id / block_size]; - auto block_offset = token_id % block_size; - auto v_cache_start = value_cache_ptr + block_id * kv_block_strideB + - block_offset * num_kv_heads * head_size + - head_mapping_ptr[head_id] * head_size; - auto attn_out_start = private_attn_out_ptr + - thread_id * private_attn_out_stride + seq_id * q_stride + - head_id * head_size; - mul_attenion_weights_and_value_of_head( - attn_w, - v_cache_start, - attn_out_start, - head_size, - flag_access[thread_id][seq_id][head_id]); - if (flag_access[thread_id][seq_id][head_id] == 0) { - flag_access[thread_id][seq_id][head_id] = 1; - } - } // for token_id - } // for head_id - } // for seq_id - { - RECORD_FUNCTION( - "ipex::paged_attention::reduction_private_result", - c10::ArrayRef({})); -#pragma omp parallel for collapse(2) - for (auto seq_id = 0; seq_id < num_seqs; seq_id++) { - for (auto hi = 0; hi < num_heads; hi++) { - auto thr0_head_start = - private_attn_out_ptr + (seq_id * num_heads + hi) * head_size; - if (flag_access[0][seq_id][hi] == 0) { - fill_stub(thr0_head_start, 0.0f, head_size); - } - for (auto thread_id = 1; thread_id < thread_numbers; thread_id++) { - if (flag_access[thread_id][seq_id][hi] == 0) { - continue; - } - auto attn_out_head_stride = thread_id * private_attn_out_stride + - (seq_id * num_heads + hi) * head_size; - auto private_attn_out_start = - private_attn_out_ptr + attn_out_head_stride; + auto partition_num = (context_len + PARTITION_SIZE - 1) / PARTITION_SIZE; + // calculate the global max and exp_sum for this head + for (auto partition_id = 0; partition_id < max_num_partitions; + partition_id++) { + if (partition_id >= partition_num) + break; + auto max_logit = max_logits_ptr + [seq_id * max_logits_strideN + head_id * max_logits_strideH + + partition_id]; + global_max = std::max(global_max, max_logit); + } + // update the partition 0 result with the global max + auto partition0_out_start = + tmp_out_ptr + seq_id * tmp_out_strideN + head_id * tmp_out_strideH; + auto max_logit0 = max_logits_ptr + [seq_id * max_logits_strideN + head_id * max_logits_strideH]; + float exp_val = expf(max_logit0 - global_max); + global_exp_sum += + exp_sum_ptr[seq_id * exp_sum_strideN + head_id * exp_sum_strideH] * + exp_val; + at::vec::Vectorized exp_val_vec0(exp_val); + at::vec::map( + [&](auto a) { return a * exp_val_vec0; }, + partition0_out_start, + partition0_out_start, + head_size); + 
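+      // Flash-decoding style merge of the per-partition softmax statistics.
+      // With m_p = max_logits[p], s_p = exp_sum[p] = sum_j exp(logit_j - m_p)
+      // and o_p = tmp_out[p] = sum_j exp(logit_j - m_p) * v_j, the result is
+      //   M = max_p m_p
+      //   S = sum_p exp(m_p - M) * s_p
+      //   out = (sum_p exp(m_p - M) * o_p) / S
+      // Partition 0 was just rescaled by exp(m_0 - M) and now acts as the
+      // accumulator; the loop below folds in the remaining partitions, and the
+      // final step divides by S (global_exp_sum).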
+ // accumulate the partition 1 to partition n result into partition 0 + if (partition_num > 1) { + for (auto partition_id = 1; partition_id < partition_num; + partition_id++) { + if (partition_id * PARTITION_SIZE >= context_len) + break; + auto tmp_out_start = tmp_out_ptr + seq_id * tmp_out_strideN + + head_id * tmp_out_strideH + partition_id * tmp_out_strideS; + auto max_logit = max_logits_ptr + [seq_id * max_logits_strideN + head_id * max_logits_strideH + + partition_id]; + auto exp_sum = exp_sum_ptr + [seq_id * exp_sum_strideN + head_id * exp_sum_strideH + + partition_id]; + exp_val = expf(max_logit - global_max); + global_exp_sum += exp_sum * exp_val; + at::vec::Vectorized exp_val_vec(exp_val); at::vec::map2( - [](Vec a, Vec b) { return a + b; }, - thr0_head_start, - private_attn_out_start, - thr0_head_start, + [&](auto a, auto b) { return a + exp_val_vec * b; }, + partition0_out_start, + partition0_out_start, + tmp_out_start, head_size); } - auto out_start = out_ptr + (seq_id * num_heads + hi) * head_size; - at::vec::map( - [](Vec a) { return a; }, out_start, thr0_head_start, head_size); } + + // copy the partition 0 result into attn_outs + auto attn_out_start = + out_ptr + seq_id * out_strideN + head_id * out_strideH; + float inverse_global_sum = 1.0 / (global_exp_sum + 1e-8); + at::vec::Vectorized inverse_global_sum_vec(inverse_global_sum); + // rescale the partition 0 result with global exp_sum + at::vec::map( + [&](auto a) { return a * inverse_global_sum_vec; }, + partition0_out_start, + partition0_out_start, + head_size); + // copy the partition 0 result into attn_outs + at::vec::map( + [&](auto a) { return a; }, + attn_out_start, + partition0_out_start, + head_size); } } - } // paged_attention_kernel void paged_attention_kernel_impl( @@ -455,7 +566,7 @@ void paged_attention_kernel_impl( int64_t block_size, c10::optional attn_mask) { TORCH_CHECK( - query.size(1) == 1, + query.size(2) == 1, "Paged attention: only seqlen 1 is supported for query"); TORCH_CHECK( query.scalar_type() == key_cache.scalar_type() && From d744dff1a5359858307422a4f461b9b44637385e Mon Sep 17 00:00:00 2001 From: "Zhang, Liangang" Date: Tue, 4 Jun 2024 21:30:18 -0700 Subject: [PATCH 07/19] Update kv cache manager --- torchao/__init__.py | 3 +++ torchao/kv_cache.py | 51 ++++++++++++++++++++++++++------------------- 2 files changed, 33 insertions(+), 21 deletions(-) diff --git a/torchao/__init__.py b/torchao/__init__.py index c982e09a0c..3f1db541dc 100644 --- a/torchao/__init__.py +++ b/torchao/__init__.py @@ -14,9 +14,12 @@ from . import _C from . 
import ops +from torchao.kv_cache import PagedAttentionCache, PagedTensor __all__ = [ "dtypes", "apply_dynamic_quant", "apply_weight_only_int8_quant", "autoquant", + "PagedAttentionCache", + "PagedTensor" ] diff --git a/torchao/kv_cache.py b/torchao/kv_cache.py index 69254f2770..8a1257f0bd 100644 --- a/torchao/kv_cache.py +++ b/torchao/kv_cache.py @@ -52,9 +52,9 @@ def scaled_dot_product_attention( query = input key_cache = key_tensor.cache value_cache = value_tensor.cache - num_kv_head = key_cache.size(2) - num_queries_per_kv = query.size(2) // num_kv_head - block_size = key_cache.size(1) + num_kv_head = key_cache.size(1) + num_queries_per_kv = query.size(1) // num_kv_head + block_size = key_cache.size(2) block_tables = key_tensor.block_tables head_mapping = torch.repeat_interleave( torch.arange(num_kv_head, dtype=torch.int32, device="cpu"), num_queries_per_kv @@ -102,8 +102,8 @@ def __init__( cache_shape = ( self.num_blocks, - self.block_size, self.num_key_value_heads, + self.block_size, self.head_dim, ) @@ -126,7 +126,9 @@ def __init__( self.block_ref_count = [ 0 ] * self.num_blocks # init the reference count for each physical block - self.block_tables = dict() # mapping logical block to physical blocks for each sequence + self.block_tables = ( + dict() + ) # mapping logical block to physical blocks for each sequence # The follow two states are shared accross layer but only for the current decode step. Need to update for every decode step. self.batch2seq = None # mapping batch index to {seq_id0, seq_id1, ...} to enable prompt sharing. @@ -163,7 +165,7 @@ def allocate(self, batch_size: int, key_len: int) -> None: past_context_len = self.seen_tokens if self.batch2seq is None: self.set_batch2seq_for_prompt_sharing(batch_size, 1) - for i in range(batch_size): + for i in range(batch_size): seq_idx = self.batch2seq[i][0] # Scenario 1: New seqence: allocate blocks for this sequence if seq_idx not in self.block_tables: @@ -207,7 +209,7 @@ def allocate(self, batch_size: int, key_len: int) -> None: self.block_ref_count[new_block] += 1 self.block_ref_count[last_block] -= 1 self._copy_on_write(last_block, new_block) - + slots = [] # the slots for this sequence for j in range(key_len): @@ -301,17 +303,24 @@ def _reshape_and_cache( Returns: None """ - if key_states.shape[-2] != self.num_key_value_heads: - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - key_cache = self.key_caches[layer_idx].view( - -1, self.num_key_value_heads, self.head_dim - ) - value_cache = self.value_caches[layer_idx].view( - -1, self.num_key_value_heads, self.head_dim - ) - key_cache[slot_mapping] = key_states - value_cache[slot_mapping] = value_states + slot_mapping = slot_mapping.to(torch.int) + block_indicies = torch.div(slot_mapping, self.block_size, rounding_mode="floor") + block_indicies = block_indicies.cpu().tolist() + block_offsets = slot_mapping % self.block_size + block_offsets = block_offsets.cpu().tolist() + batch_size = key_states.size(0) + seq_len = key_states.size(2) + for i in range(batch_size): + for seq_idx in range(seq_len): + block_idx = block_indicies[i][seq_idx] + block_offset = block_offsets[i][seq_idx] + for head_idx in range(self.num_key_value_heads): + self.key_caches[layer_idx][block_idx, head_idx, block_offset, :] = ( + key_states[i, head_idx, seq_idx, :] + ) + self.value_caches[layer_idx][ + block_idx, head_idx, block_offset, : + ] = value_states[i, head_idx, seq_idx, :] def get_seq_length(self, layer_idx: int = 0) -> int: return self.seen_tokens @@ -349,8 
+358,8 @@ def update( # self.slots_mapping is not None, # "PagedAttentionCache: Please first call allocate() of this object to get target positions in paged cache before the model.forward().", # ) - #cache key_states & value_states - self._reshape_and_cache(self.slots_mapping, key_states, value_states, layer_idx) + # cache key_states & value_states + self._reshape_and_cache(self.slots_mapping, key_states, value_states, layer_idx) if layer_idx == self.num_layers - 1: self.seen_tokens += cur_len @@ -403,4 +412,4 @@ def reorder_cache(self, beam_idx: torch.Tensor) -> None: for seq_idx in freed_seqs: self._free(seq_idx) self.block_tables = new_block_tables - self.batch2seq = None \ No newline at end of file + self.batch2seq = None From 3ee352fb3c89ebee1895119ead3c4537799dddcb Mon Sep 17 00:00:00 2001 From: "Zhang, Liangang" Date: Tue, 4 Jun 2024 21:32:50 -0700 Subject: [PATCH 08/19] clang-format and recover test/kernel/test_fused_kernels.py --- test/kernel/test_fused_kernels.py | 111 +++++ torchao/csrc/cpu/paged_attention_kernel.cpp | 448 ++++++++------------ 2 files changed, 295 insertions(+), 264 deletions(-) create mode 100644 test/kernel/test_fused_kernels.py diff --git a/test/kernel/test_fused_kernels.py b/test/kernel/test_fused_kernels.py new file mode 100644 index 0000000000..b43abead45 --- /dev/null +++ b/test/kernel/test_fused_kernels.py @@ -0,0 +1,111 @@ +import itertools + +import pytest + +# Skip entire test if triton is not available, otherwise CI failure +try: + import triton +except ImportError: + pytest.skip("triton is not installed", allow_module_level=True) + +import torch +from galore_test_utils import get_kernel, make_copy, make_data + +torch.manual_seed(0) +MAX_DIFF_no_tf32 = 1e-5 +MAX_DIFF_tf32 = 1e-3 + + +def run_test(kernel, exp_avg, exp_avg2, grad, proj_matrix, params, allow_tf32): + # Copy to use for first run -- needed because of autotuning and inplace ops + ( + exp_avg_autotune_copy, + exp_avg2_autotune_copy, + grad_autotune_copy, + proj_matrix_autotune_copy, + params_autotune_copy, + ) = make_copy(exp_avg, exp_avg2, grad, proj_matrix, params) + + # Copy to use for second run to check accuracy + ( + exp_avg_test_copy, + exp_avg2_test_copy, + grad_test_copy, + proj_matrix_test_copy, + params_test_copy, + ) = make_copy(exp_avg, exp_avg2, grad, proj_matrix, params) + + print( + f"Running with {grad.shape[0]} x {grad.shape[1]} grad (param) shape, GaLore orthogonal matrix {list(proj_matrix.shape)}, dtype {grad.dtype} and allow_tf32 {allow_tf32}\n" + f"Kernel: {kernel}", + flush=True, + ) + + ref_op = get_kernel("ref") + test_op = get_kernel(kernel) + + # Reference run + ref_out = ref_op( + grad, + proj_matrix, + exp_avg, + exp_avg2, + params, + ) + + # Autotune + _ = test_op( + grad_autotune_copy, + proj_matrix_autotune_copy, + exp_avg_autotune_copy, + exp_avg2_autotune_copy, + params_autotune_copy, + store=False, + allow_tf32=allow_tf32, + ) + + # Accuracy run + test_out = test_op( + grad_test_copy, + proj_matrix_test_copy, + exp_avg_test_copy, + exp_avg2_test_copy, + params_test_copy, + store=True, + allow_tf32=allow_tf32, + ) + print("Accuracy:") + + output_names = [ + "adam state - running grad mean", + "adam state - running grad var", + "params (after update)", + ] + MAX_DIFF = MAX_DIFF_tf32 if allow_tf32 else MAX_DIFF_no_tf32 + for name, ref, tt in zip(output_names, ref_out, test_out): + max_diff = (ref - tt).abs().max() + print(f"-> {name}:\n Max err: {max_diff:.6f}") + assert max_diff < MAX_DIFF + + +KERNELS = ["hybrid"] # "fused"] +DTYPES = [torch.float32] # 
torch.float16 +ROW_DIMS = [4096] +COL_DIMS = [4096] # , 11008] +RANKS = [128] +ALLOW_TF32 = [False] # , True] + +TEST_CONFIGS = list( + itertools.product(KERNELS, DTYPES, ROW_DIMS, COL_DIMS, RANKS, ALLOW_TF32) +) + +# TEST_CONFIGS = TEST_CONFIGS[0:1] + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU") +@pytest.mark.parametrize("kernel, dtype, M, N, rank, allow_tf32", TEST_CONFIGS) +def test_galore_fused_kernels(kernel, dtype, M, N, rank, allow_tf32): + torch.backends.cuda.matmul.allow_tf32 = allow_tf32 + + exp_avg, exp_avg2, grad, proj_matrix, params = make_data(M, N, rank, dtype) + run_test(kernel, exp_avg, exp_avg2, grad, proj_matrix, params, allow_tf32) diff --git a/torchao/csrc/cpu/paged_attention_kernel.cpp b/torchao/csrc/cpu/paged_attention_kernel.cpp index 5aa7686b66..93e6c4af91 100644 --- a/torchao/csrc/cpu/paged_attention_kernel.cpp +++ b/torchao/csrc/cpu/paged_attention_kernel.cpp @@ -1,4 +1,5 @@ #include +#include #include #include #include @@ -7,7 +8,6 @@ #include #include #include -#include #ifndef AT_PER_OPERATOR_HEADERS #include @@ -16,10 +16,10 @@ #endif #include -#include -#include #include #include +#include +#include #define PARTITION_SIZE 256 @@ -28,24 +28,19 @@ namespace torchao { namespace { template -void reduce_head( - const scalar_t* q_ptr_start, - const scalar_t* k_cache_start, - accum_t* attn_w_pos, - int64_t head_size) { - attn_w_pos[0] = 0; +void reduce_head(const scalar_t *q_ptr_start, const scalar_t *k_cache_start, + accum_t *attn_w_pos, int64_t head_size) { + attn_w_pos[0] = 0; for (long i = 0; i < head_size; i++) { attn_w_pos[0] += q_ptr_start[i] * k_cache_start[i]; } } -//BF16 +// BF16 template <> -void reduce_head( - const at::BFloat16* q_ptr_start, - const at::BFloat16* k_cache_start, - float* attn_w_pos, - int64_t head_size) { +void reduce_head(const at::BFloat16 *q_ptr_start, + const at::BFloat16 *k_cache_start, + float *attn_w_pos, int64_t head_size) { attn_w_pos[0] = 0; using lpVec = at::vec::Vectorized; using fVec = at::vec::Vectorized; @@ -55,25 +50,19 @@ void reduce_head( auto tmpq = lpVec::loadu(q_ptr_start + i); auto tmpk = lpVec::loadu(k_cache_start + i); fVec tmpq1, tmpq2, tmpk1, tmpk2; - //convert to float + // convert to float std::tie(tmpq1, tmpq2) = at::vec::convert_to_float(tmpq); std::tie(tmpk1, tmpk2) = at::vec::convert_to_float(tmpk); vec_tmp_sum = vec_tmp_sum + tmpq1 * tmpk1 + tmpq2 * tmpk2; } attn_w_pos[0] = at::vec::vec_reduce_all<>( - [](fVec& x, fVec& y) { - return x + y; - }, - vec_tmp_sum); + [](fVec &x, fVec &y) { return x + y; }, vec_tmp_sum); } template inline void mul_attenion_weights_and_value_of_head( - const accum_t& attn_w, - const scalar_t* v_cache_start, - accum_t* attn_out_start, - int64_t head_size, - bool accumulated) { + const accum_t &attn_w, const scalar_t *v_cache_start, + accum_t *attn_out_start, int64_t head_size, bool accumulated) { for (auto hsi = 0; hsi < head_size; hsi++) { if (accumulated) { attn_out_start[hsi] += attn_w * (float)v_cache_start[hsi]; @@ -85,11 +74,8 @@ inline void mul_attenion_weights_and_value_of_head( template <> inline void mul_attenion_weights_and_value_of_head( - const float& attn_w, - const at::BFloat16* v_cache_start, - float* attn_out_start, - int64_t head_size, - bool accumulated) { + const float &attn_w, const at::BFloat16 *v_cache_start, + float *attn_out_start, int64_t head_size, bool accumulated) { using lpVec = at::vec::Vectorized; using fVec = at::vec::Vectorized; auto lpVec_size = lpVec::size(); @@ -97,10 +83,10 @@ inline void 
mul_attenion_weights_and_value_of_head( auto vec_attn_w = fVec(attn_w); auto vec_tmp_sum = fVec(0.0f); long i = 0; - for (; i < lpVec_size *(head_size/lpVec_size) ; i += lpVec_size) { + for (; i < lpVec_size * (head_size / lpVec_size); i += lpVec_size) { auto tmpv = lpVec::loadu(v_cache_start + i); fVec tmpv1, tmpv2; - //convert to float + // convert to float std::tie(tmpv1, tmpv2) = at::vec::convert_to_float(tmpv); auto tmp1 = tmpv1 * vec_attn_w; auto tmp2 = tmpv2 * vec_attn_w; @@ -125,12 +111,8 @@ inline void mul_attenion_weights_and_value_of_head( // out = val * a + b template -inline void _scale_attn_mask_fusion_kernel( - T1* a, - float* b, - const int& size, - T2* out, - float val) { +inline void _scale_attn_mask_fusion_kernel(T1 *a, float *b, const int &size, + T2 *out, float val) { const auto vec_size = at::vec::Vectorized::size(); const auto vec_scale = at::vec::Vectorized(val); int64_t i = 0; @@ -150,11 +132,8 @@ inline void _scale_attn_mask_fusion_kernel( // 1) out = exp(a - val) // 2) val = sum(out) template -inline void _exp_reduce_sum_fusion_kernel( - T1* a, - const int& size, - T2* out, - T1& val) { +inline void _exp_reduce_sum_fusion_kernel(T1 *a, const int &size, T2 *out, + T1 &val) { auto vec_size = at::vec::Vectorized::size(); auto vec_max = at::vec::Vectorized(val); T1 tmp_sum = 0; @@ -168,7 +147,7 @@ inline void _exp_reduce_sum_fusion_kernel( tmp2.store(out + i); } tmp_sum = at::vec::vec_reduce_all( - [](at::vec::Vectorized& x, at::vec::Vectorized& y) { + [](at::vec::Vectorized &x, at::vec::Vectorized &y) { return x + y; }, vec_tmp_sum); @@ -185,12 +164,9 @@ inline void _exp_reduce_sum_fusion_kernel( // 1) out = a * scale // 2) max = max(out) template -inline void _mul_reduce_max_fusion_kernel( - scalar_t* a, - const scalar_t& scale, - const int& size, - scalar_t* out, - scalar_t& max) { +inline void _mul_reduce_max_fusion_kernel(scalar_t *a, const scalar_t &scale, + const int &size, scalar_t *out, + scalar_t &max) { auto vec_size = at::vec::Vectorized::size(); auto vec_scale = at::vec::Vectorized(scale); scalar_t tmp_max = -std::numeric_limits::infinity(); @@ -203,7 +179,7 @@ inline void _mul_reduce_max_fusion_kernel( tmp1.store(out + i); } tmp_max = at::vec::vec_reduce_all( - [](at::vec::Vectorized& x, at::vec::Vectorized& y) { + [](at::vec::Vectorized &x, at::vec::Vectorized &y) { return at::vec::maximum(x, y); }, vec_tmp_max); @@ -216,12 +192,8 @@ inline void _mul_reduce_max_fusion_kernel( max = tmp_max; } -void reshape_attn_mask_to_4d( - at::Tensor& attn_mask, - int64_t batchSize, - int64_t num_head, - int64_t qSize, - int64_t kvSize) { +void reshape_attn_mask_to_4d(at::Tensor &attn_mask, int64_t batchSize, + int64_t num_head, int64_t qSize, int64_t kvSize) { // Support mask shapes: // 2d: ({Q_seq_len, 1} x {KV_seq_len, 1}) // 4d: ({Batch, 1} x {Num_heads, 1} x {Q_seq_len, 1} x {KV_seq_len, 1}) @@ -237,11 +209,8 @@ void reshape_attn_mask_to_4d( } } attn_mask = attn_mask - .view( - {attn_mask_size_0, - attn_mask_size_1, - attn_mask.size(-2), - attn_mask.size(-1)}) + .view({attn_mask_size_0, attn_mask_size_1, attn_mask.size(-2), + attn_mask.size(-1)}) .expand({attn_mask_size_0, attn_mask_size_1, qSize, kvSize}); } @@ -274,18 +243,13 @@ void reshape_attn_mask_to_4d( * @param attn_mask Optional tensor of attention_mask */ template -void paged_attention_kernel( - at::Tensor& out, - at::Tensor& query, - at::Tensor& key_cache, - at::Tensor& value_cache, - at::Tensor& head_mapping, - const double scale, - at::Tensor& block_tables, - at::Tensor& context_lens, - int64_t 
block_size, - c10::optional attn_mask) { - +void paged_attention_kernel(at::Tensor &out, at::Tensor &query, + at::Tensor &key_cache, at::Tensor &value_cache, + at::Tensor &head_mapping, const double scale, + at::Tensor &block_tables, at::Tensor &context_lens, + int64_t block_size, + c10::optional attn_mask) { + using accum_t = at::opmath_type; using Vec = at::vec::Vectorized; const auto dtype = query.scalar_type(); @@ -301,12 +265,8 @@ void paged_attention_kernel( bool has_attn_mask = attn_mask.has_value() && attn_mask.value().numel(); if (has_attn_mask) { attn_mask.value() = attn_mask.value().to(at::kFloat); - reshape_attn_mask_to_4d( - attn_mask.value(), - num_seqs, - num_heads, - q_len, - attn_mask.value().size(-1)); + reshape_attn_mask_to_4d(attn_mask.value(), num_seqs, num_heads, q_len, + attn_mask.value().size(-1)); } auto out_ptr = out.data_ptr(); @@ -332,27 +292,24 @@ void paged_attention_kernel( attn_mask.has_value() ? attn_mask.value().data_ptr() : nullptr; int64_t mStrideB = (has_attn_mask && attn_mask.value().size(0) > 1) - ? attn_mask.value().stride(0) - : 0; + ? attn_mask.value().stride(0) + : 0; int64_t mStrideH = (has_attn_mask && attn_mask.value().size(1) > 1) - ? attn_mask.value().stride(1) - : 0; + ? attn_mask.value().stride(1) + : 0; int64_t mStrideM = has_attn_mask ? attn_mask.value().stride(2) : 0; auto max_num_partitions = (max_context_len + PARTITION_SIZE - 1) / PARTITION_SIZE; - auto max_logits = at::empty( - {num_seqs, num_heads, max_num_partitions + 1}, - query.options().dtype(accumulate_dtype)); + auto max_logits = at::empty({num_seqs, num_heads, max_num_partitions + 1}, + query.options().dtype(accumulate_dtype)); - auto exp_sum = at::empty( - {num_seqs, num_heads, max_num_partitions + 1}, - query.options().dtype(accumulate_dtype)); + auto exp_sum = at::empty({num_seqs, num_heads, max_num_partitions + 1}, + query.options().dtype(accumulate_dtype)); - auto tmp_out = at::empty( - {num_seqs, num_heads, max_num_partitions, head_size}, - query.options().dtype(accumulate_dtype)); + auto tmp_out = at::empty({num_seqs, num_heads, max_num_partitions, head_size}, + query.options().dtype(accumulate_dtype)); auto tmp_out_ptr = tmp_out.data_ptr(); auto max_logits_ptr = max_logits.data_ptr(); @@ -364,113 +321,101 @@ void paged_attention_kernel( auto exp_sum_strideH = exp_sum.stride(1); auto tmp_out_strideN = tmp_out.stride(0); auto tmp_out_strideH = tmp_out.stride(1); - auto tmp_out_strideS = tmp_out.stride(2); - #pragma omp parallel for collapse(3) schedule(static, 1) - for (auto partition_id = 0; partition_id < max_num_partitions; - partition_id++) { - for (auto head_id = 0; head_id < num_heads; head_id++) { - for (auto seq_id = 0; seq_id < num_seqs; seq_id++) { - auto context_len = context_lens_ptr[seq_id]; - auto partition_start = partition_id * PARTITION_SIZE; - if (partition_start >= context_len) - continue; - auto partition_end = - std::min(partition_start + PARTITION_SIZE, context_len); - auto token_num = partition_end - partition_start; - auto block_num = (token_num + block_size - 1) / block_size; - auto logical_block_start = partition_start / block_size; - auto logical_block_end = logical_block_start + block_num; - auto need_update = block_num > 1; - auto kv_head_id = head_id / kv_head_group_size; - auto q_ptr_start = - query_ptr + seq_id * q_strideN + head_id * q_strideH; - auto max_logits_offset = seq_id * max_logits_strideN + - head_id * max_logits_strideH + partition_id; - auto exp_sum_offset = seq_id * exp_sum_strideN + - head_id * exp_sum_strideH + 
partition_id; - auto tmp_out_start = tmp_out_ptr + seq_id * tmp_out_strideN + - head_id * tmp_out_strideH + partition_id * tmp_out_strideS; - accum_t logits[PARTITION_SIZE] __attribute__((aligned(64))) = {0}; - auto logits_position = 0; - // 1)calculate the matmul(query, key) for this partition - for (auto logical_block_id = logical_block_start; - logical_block_id < logical_block_end; - logical_block_id++) { - auto physical_block_id = block_tables_ptr - [seq_id * max_num_blocks_per_seq + logical_block_id]; - auto tokens_in_block = std::min( - block_size, context_len - logical_block_id * block_size); - auto token_start = logical_block_id * block_size; - auto token_end = token_start + tokens_in_block; - for (auto token_id = token_start; token_id < token_end; - token_id++) { - auto block_offset = token_id - token_start; - auto k_cache_start = key_cache_ptr + - physical_block_id * kv_block_strideN + - block_offset * kv_block_strideP + - kv_head_id * kv_block_strideH; - reduce_head( - q_ptr_start, - k_cache_start, - &(logits[logits_position]), - head_size); - logits_position++; - } - } - // 2) calculate the max and exp_sum for this partition - auto partition_max = -std::numeric_limits::infinity(); - if (has_attn_mask) { - _scale_attn_mask_fusion_kernel( - logits, - attn_mask_ptr + seq_id * mStrideB + head_id * mStrideH + - partition_start, - token_num, - logits, - scale); - partition_max = at::vec::reduce_all( - [](Vec& x, Vec& y) { return at::vec::maximum(x, y); }, - logits, - token_num); - } else { - _mul_reduce_max_fusion_kernel( - logits, scale, token_num, logits, partition_max); + auto tmp_out_strideS = tmp_out.stride(2); +#pragma omp parallel for collapse(3) schedule(static, 1) + for (auto partition_id = 0; partition_id < max_num_partitions; + partition_id++) { + for (auto head_id = 0; head_id < num_heads; head_id++) { + for (auto seq_id = 0; seq_id < num_seqs; seq_id++) { + auto context_len = context_lens_ptr[seq_id]; + auto partition_start = partition_id * PARTITION_SIZE; + if (partition_start >= context_len) + continue; + auto partition_end = + std::min(partition_start + PARTITION_SIZE, context_len); + auto token_num = partition_end - partition_start; + auto block_num = (token_num + block_size - 1) / block_size; + auto logical_block_start = partition_start / block_size; + auto logical_block_end = logical_block_start + block_num; + auto need_update = block_num > 1; + auto kv_head_id = head_id / kv_head_group_size; + auto q_ptr_start = query_ptr + seq_id * q_strideN + head_id * q_strideH; + auto max_logits_offset = seq_id * max_logits_strideN + + head_id * max_logits_strideH + partition_id; + auto exp_sum_offset = + seq_id * exp_sum_strideN + head_id * exp_sum_strideH + partition_id; + auto tmp_out_start = tmp_out_ptr + seq_id * tmp_out_strideN + + head_id * tmp_out_strideH + + partition_id * tmp_out_strideS; + accum_t logits[PARTITION_SIZE] __attribute__((aligned(64))) = {0}; + auto logits_position = 0; + // 1)calculate the matmul(query, key) for this partition + for (auto logical_block_id = logical_block_start; + logical_block_id < logical_block_end; logical_block_id++) { + auto physical_block_id = + block_tables_ptr[seq_id * max_num_blocks_per_seq + + logical_block_id]; + auto tokens_in_block = + std::min(block_size, context_len - logical_block_id * block_size); + auto token_start = logical_block_id * block_size; + auto token_end = token_start + tokens_in_block; + for (auto token_id = token_start; token_id < token_end; token_id++) { + auto block_offset = token_id - token_start; + auto 
k_cache_start = + key_cache_ptr + physical_block_id * kv_block_strideN + + block_offset * kv_block_strideP + kv_head_id * kv_block_strideH; + reduce_head(q_ptr_start, k_cache_start, + &(logits[logits_position]), + head_size); + logits_position++; } - max_logits_ptr[max_logits_offset] = partition_max; - _exp_reduce_sum_fusion_kernel( - logits, token_num, logits, partition_max); - exp_sum_ptr[exp_sum_offset] = partition_max; - - // 3) calculate the matmul(exp(logits-partition_max), value) for this - // partition, need to divide the global exp_sum in the final result. - logits_position = 0; - for (auto logical_block_id = logical_block_start; - logical_block_id < logical_block_end; - logical_block_id++) { - auto physical_block_id = block_tables_ptr - [seq_id * max_num_blocks_per_seq + logical_block_id]; - auto tokens_in_block = std::min( - block_size, context_len - logical_block_id * block_size); - auto token_start = logical_block_id * block_size; - auto token_end = token_start + tokens_in_block; - for (auto token_id = token_start; token_id < token_end; - token_id++) { - auto block_offset = token_id - token_start; - auto v_cache_start = value_cache_ptr + - physical_block_id * kv_block_strideN + - block_offset * kv_block_strideP + - kv_head_id * kv_block_strideH; - auto accumulated = logits_position > 0; - mul_attenion_weights_and_value_of_head( - logits[logits_position], - v_cache_start, - tmp_out_start, - head_size, - accumulated); - logits_position++; - } + } + // 2) calculate the max and exp_sum for this partition + auto partition_max = -std::numeric_limits::infinity(); + if (has_attn_mask) { + _scale_attn_mask_fusion_kernel( + logits, + attn_mask_ptr + seq_id * mStrideB + head_id * mStrideH + + partition_start, + token_num, logits, scale); + partition_max = at::vec::reduce_all( + [](Vec &x, Vec &y) { return at::vec::maximum(x, y); }, logits, + token_num); + } else { + _mul_reduce_max_fusion_kernel(logits, scale, token_num, + logits, partition_max); + } + max_logits_ptr[max_logits_offset] = partition_max; + _exp_reduce_sum_fusion_kernel(logits, token_num, + logits, partition_max); + exp_sum_ptr[exp_sum_offset] = partition_max; + + // 3) calculate the matmul(exp(logits-partition_max), value) for this + // partition, need to divide the global exp_sum in the final result. 
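+        // Note: _exp_reduce_sum_fusion_kernel overwrites `partition_max` with
+        // the partition's sum of exp(logit - max), so the value stored in
+        // exp_sum_ptr above is that sum rather than the max. Step 3 only
+        // accumulates exp(logit - partition_max) * V into tmp_out; dividing by
+        // the global exp_sum happens in the cross-partition reduction below.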
+ logits_position = 0; + for (auto logical_block_id = logical_block_start; + logical_block_id < logical_block_end; logical_block_id++) { + auto physical_block_id = + block_tables_ptr[seq_id * max_num_blocks_per_seq + + logical_block_id]; + auto tokens_in_block = + std::min(block_size, context_len - logical_block_id * block_size); + auto token_start = logical_block_id * block_size; + auto token_end = token_start + tokens_in_block; + for (auto token_id = token_start; token_id < token_end; token_id++) { + auto block_offset = token_id - token_start; + auto v_cache_start = + value_cache_ptr + physical_block_id * kv_block_strideN + + block_offset * kv_block_strideP + kv_head_id * kv_block_strideH; + auto accumulated = logits_position > 0; + mul_attenion_weights_and_value_of_head( + logits[logits_position], v_cache_start, tmp_out_start, + head_size, accumulated); + logits_position++; } } } + } } // calculate the final output @@ -486,26 +431,24 @@ void paged_attention_kernel( partition_id++) { if (partition_id >= partition_num) break; - auto max_logit = max_logits_ptr - [seq_id * max_logits_strideN + head_id * max_logits_strideH + - partition_id]; + auto max_logit = + max_logits_ptr[seq_id * max_logits_strideN + + head_id * max_logits_strideH + partition_id]; global_max = std::max(global_max, max_logit); } // update the partition 0 result with the global max auto partition0_out_start = tmp_out_ptr + seq_id * tmp_out_strideN + head_id * tmp_out_strideH; - auto max_logit0 = max_logits_ptr - [seq_id * max_logits_strideN + head_id * max_logits_strideH]; + auto max_logit0 = max_logits_ptr[seq_id * max_logits_strideN + + head_id * max_logits_strideH]; float exp_val = expf(max_logit0 - global_max); global_exp_sum += exp_sum_ptr[seq_id * exp_sum_strideN + head_id * exp_sum_strideH] * exp_val; at::vec::Vectorized exp_val_vec0(exp_val); - at::vec::map( - [&](auto a) { return a * exp_val_vec0; }, - partition0_out_start, - partition0_out_start, - head_size); + at::vec::map([&](auto a) { return a * exp_val_vec0; }, + partition0_out_start, partition0_out_start, + head_size); // accumulate the partition 1 to partition n result into partition 0 if (partition_num > 1) { @@ -514,21 +457,19 @@ void paged_attention_kernel( if (partition_id * PARTITION_SIZE >= context_len) break; auto tmp_out_start = tmp_out_ptr + seq_id * tmp_out_strideN + - head_id * tmp_out_strideH + partition_id * tmp_out_strideS; - auto max_logit = max_logits_ptr - [seq_id * max_logits_strideN + head_id * max_logits_strideH + - partition_id]; - auto exp_sum = exp_sum_ptr - [seq_id * exp_sum_strideN + head_id * exp_sum_strideH + - partition_id]; + head_id * tmp_out_strideH + + partition_id * tmp_out_strideS; + auto max_logit = + max_logits_ptr[seq_id * max_logits_strideN + + head_id * max_logits_strideH + partition_id]; + auto exp_sum = exp_sum_ptr[seq_id * exp_sum_strideN + + head_id * exp_sum_strideH + partition_id]; exp_val = expf(max_logit - global_max); global_exp_sum += exp_sum * exp_val; at::vec::Vectorized exp_val_vec(exp_val); at::vec::map2( [&](auto a, auto b) { return a + exp_val_vec * b; }, - partition0_out_start, - partition0_out_start, - tmp_out_start, + partition0_out_start, partition0_out_start, tmp_out_start, head_size); } } @@ -539,44 +480,36 @@ void paged_attention_kernel( float inverse_global_sum = 1.0 / (global_exp_sum + 1e-8); at::vec::Vectorized inverse_global_sum_vec(inverse_global_sum); // rescale the partition 0 result with global exp_sum - at::vec::map( - [&](auto a) { return a * inverse_global_sum_vec; }, - 
partition0_out_start, - partition0_out_start, - head_size); + at::vec::map([&](auto a) { return a * inverse_global_sum_vec; }, + partition0_out_start, partition0_out_start, + head_size); // copy the partition 0 result into attn_outs - at::vec::map( - [&](auto a) { return a; }, - attn_out_start, - partition0_out_start, - head_size); + at::vec::map([&](auto a) { return a; }, attn_out_start, + partition0_out_start, head_size); } } } // paged_attention_kernel void paged_attention_kernel_impl( - at::Tensor& out, // [num_seqs, 1, num_heads, head_size] - at::Tensor& query, // [num_seqs, 1, num_heads, head_size] - at::Tensor& key_cache, // [num_blocks, block_size, num_heads, head_size] - at::Tensor& value_cache, // [num_blocks, block_size, num_heads, head_size] - at::Tensor& head_mapping, // [num_heads] + at::Tensor &out, // [num_seqs, 1, num_heads, head_size] + at::Tensor &query, // [num_seqs, 1, num_heads, head_size] + at::Tensor &key_cache, // [num_blocks, block_size, num_heads, head_size] + at::Tensor &value_cache, // [num_blocks, block_size, num_heads, head_size] + at::Tensor &head_mapping, // [num_heads] const double scale, - at::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] - at::Tensor& context_lens, // [num_seqs] - int64_t block_size, - c10::optional attn_mask) { - TORCH_CHECK( - query.size(2) == 1, - "Paged attention: only seqlen 1 is supported for query"); - TORCH_CHECK( - query.scalar_type() == key_cache.scalar_type() && - query.scalar_type() == value_cache.scalar_type(), - "Paged attention: Q/K/V should have the same data type"); - TORCH_CHECK( - !attn_mask.has_value() || - query.scalar_type() == attn_mask.value().scalar_type() || - attn_mask.value().scalar_type() != at::ScalarType::Bool, - "Paged attention: Mask should have the same data type as Q/K/V and should not be Bool"); + at::Tensor &block_tables, // [num_seqs, max_num_blocks_per_seq] + at::Tensor &context_lens, // [num_seqs] + int64_t block_size, c10::optional attn_mask) { + TORCH_CHECK(query.size(2) == 1, + "Paged attention: only seqlen 1 is supported for query"); + TORCH_CHECK(query.scalar_type() == key_cache.scalar_type() && + query.scalar_type() == value_cache.scalar_type(), + "Paged attention: Q/K/V should have the same data type"); + TORCH_CHECK(!attn_mask.has_value() || + query.scalar_type() == attn_mask.value().scalar_type() || + attn_mask.value().scalar_type() != at::ScalarType::Bool, + "Paged attention: Mask should have the same data type as Q/K/V " + "and should not be Bool"); TORCH_CHECK( query.dim() == 4 && key_cache.dim() == 4 && value_cache.dim() == 4, "Paged attention: Accept only 4 dims inputs shape of {B, H, T, K}"); @@ -587,26 +520,13 @@ void paged_attention_kernel_impl( "Paged attention: Q/KV cache/Mask should be continuous on the last dim"); AT_DISPATCH_FLOATING_TYPES_AND2( - at::kBFloat16, - at::kHalf, - query.scalar_type(), - "paged_attention", - [&] { - paged_attention_kernel( - out, - query, - key_cache, - value_cache, - head_mapping, - scale, - block_tables, - context_lens, - block_size, - attn_mask); + at::kBFloat16, at::kHalf, query.scalar_type(), "paged_attention", [&] { + paged_attention_kernel(out, query, key_cache, value_cache, + head_mapping, scale, block_tables, + context_lens, block_size, attn_mask); }); } - } // namespace TORCH_LIBRARY_IMPL(torchao, CPU, m) { m.impl("torchao::paged_attention", &paged_attention_kernel_impl); From 302dd70261760bc33b860c1907bbf090e0f77cc8 Mon Sep 17 00:00:00 2001 From: "Zhang, Liangang" Date: Wed, 5 Jun 2024 12:38:31 +0800 Subject: [PATCH 09/19] 
Update test_paged_attention.py --- test/kernel/test_paged_attention.py | 217 +--------------------------- 1 file changed, 1 insertion(+), 216 deletions(-) diff --git a/test/kernel/test_paged_attention.py b/test/kernel/test_paged_attention.py index 6e68ad5516..498ee01942 100644 --- a/test/kernel/test_paged_attention.py +++ b/test/kernel/test_paged_attention.py @@ -202,220 +202,5 @@ def test_paged_attention_kv_cache(self): ) -# class PagedAttentionTest(unittest.TestCase): -# def create_kv_caches( -# self, -# num_blocks: int, -# block_size: int, -# num_layer: int, -# num_head: int, -# head_size: int, -# dtype: torch.dtype, -# seed: int, -# ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: -# torch.random.manual_seed(seed) -# torch.manual_seed(seed) - -# scale = head_size**-0.5 -# key_cache_shape = (num_blocks, num_head, block_size, head_size) -# key_caches = [] -# for _ in range(num_layer): -# key_cache = torch.empty(size=key_cache_shape, dtype=dtype) -# key_cache.uniform_(-scale, scale) -# key_caches.append(key_cache) - -# value_cache_shape = (num_blocks, num_head, block_size, head_size) -# value_caches = [] -# for _ in range(num_layer): -# value_cache = torch.empty(size=value_cache_shape, dtype=dtype) -# value_cache.uniform_(-scale, scale) -# value_caches.append(value_cache) -# return key_caches, value_caches - -# def ref_masked_attention( -# self, -# query: torch.Tensor, -# key: torch.Tensor, -# value: torch.Tensor, -# scale: float, -# attn_mask: Optional[torch.Tensor] = None, -# ) -> torch.Tensor: -# attn_weights = torch.einsum("qhd,khd->hqk", query, key).float() -# attn_weights = attn_weights * scale -# if attn_mask is not None: -# attn_weights = attn_weights + attn_mask.float() -# attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) -# out = torch.einsum("hqk,khd->qhd", attn_weights, value) -# return out - -# def ref_single_query_cached_kv_attention( -# self, -# output: torch.Tensor, -# query: torch.Tensor, -# num_queries_per_kv: int, -# key_cache: torch.Tensor, -# value_cache: torch.Tensor, -# block_tables: torch.Tensor, -# context_lens: torch.Tensor, -# scale: float, -# attn_mask: Optional[torch.Tensor], -# ) -> None: -# num_query_heads = query.shape[1] -# num_kv_head = value_cache.shape[1] -# head_size = value_cache.shape[3] -# block_size = value_cache.shape[2] -# num_seqs = query.shape[0] - -# block_tables = block_tables.cpu().tolist() -# context_lens = context_lens.cpu().tolist() -# for i in range(num_seqs): -# q = query[i].unsqueeze(0) -# block_table = block_tables[i] -# context_len = int(context_lens[i]) - -# keys = [] -# values = [] -# for j in range(context_len): -# key = torch.empty( -# num_kv_head, head_size, dtype=query.dtype, device="cpu" -# ) -# value = torch.empty( -# num_kv_head, head_size, dtype=query.dtype, device="cpu" -# ) -# for k in range(num_kv_head): -# block_number = int(block_table[j // block_size]) -# block_offset = j % block_size -# key[k, :] = key_cache[block_number, k, block_offset, :] -# value[k, :] = value_cache[block_number, k, block_offset, :] -# keys.append(key) -# values.append(value) -# keys = torch.stack(keys, dim=0) -# values = torch.stack(values, dim=0) -# if num_queries_per_kv > 1: -# # Handle MQA and GQA -# keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1) -# values = torch.repeat_interleave(values, num_queries_per_kv, dim=1) -# # out = self.ref_masked_attention(q, keys, values, scale, attn_mask[i]) -# out = self.ref_masked_attention(q, keys, values, scale, None) -# out = out.view(num_query_heads, head_size) -# 
output[i].copy_(out, non_blocking=True) - -# def _test_paged_attention_func( -# self, -# num_seqs: int, -# num_head: Tuple[int, int], -# head_size: int, -# num_blocks: int, -# block_size: int, -# dtype: torch.dtype, -# seed: int, -# ) -> None: -# random.seed(seed) -# torch.random.manual_seed(seed) -# torch.manual_seed(seed) -# max_seq_len = 512 -# scale = float(1.0 / (head_size**0.5)) -# num_query_heads, num_kv_head = num_head -# query = torch.empty( -# num_seqs, num_query_heads, head_size, dtype=dtype, device="cpu" -# ) -# query.uniform_(-scale, scale) -# assert num_query_heads % num_kv_head == 0 -# num_queries_per_kv = num_query_heads // num_kv_head -# head_mapping = torch.repeat_interleave( -# torch.arange(num_kv_head, dtype=torch.int32, device="cpu"), -# num_queries_per_kv, -# ) -# attn_mask = None -# context_lens = [random.randint(1, max_seq_len) for _ in range(num_seqs)] -# context_lens[-1] = max_seq_len -# max_context_len = max_seq_len # max(context_lens) -# attn_mask = torch.zeros(num_seqs, 1, 1, max_context_len, dtype=dtype) -# for i in range(num_seqs): -# attn_mask[i, :, :, context_lens[i] :].fill_(-10000.0) -# paded_context_lens = torch.tensor( -# [max_context_len for _ in range(num_seqs)] -# ).to(torch.int32) -# context_lens = torch.tensor(context_lens, dtype=torch.int, device="cpu") - -# # Create the block tables.NUM_PREFILL_SEQS -# max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size -# block_tables = [] -# for _ in range(num_seqs): -# block_table = [ -# random.randint(0, num_blocks - 1) for _ in range(max_num_blocks_per_seq) -# ] -# block_tables.append(block_table) -# block_tables = torch.tensor(block_tables, dtype=torch.int, device="cpu") - -# # Create the KV caches. -# key_caches, value_caches = self.create_kv_caches( -# num_blocks, block_size, 1, num_kv_head, head_size, dtype, seed -# ) -# key_cache, value_cache = key_caches[0], value_caches[0] -# output = torch.empty_like(query.unsqueeze(2)) -# torch.ops.torchao.paged_attention( -# output, -# query.unsqueeze(2), -# key_cache, -# value_cache, -# head_mapping, -# scale, -# block_tables, -# paded_context_lens, -# block_size, -# attn_mask, -# ) -# output = output.squeeze(2) -# #Run the reference implementation. 
-# ref_output = torch.empty_like(query) -# self.ref_single_query_cached_kv_attention( -# ref_output, -# query, -# num_queries_per_kv, -# key_cache, -# value_cache, -# block_tables, -# context_lens, -# scale, -# attn_mask, -# ) -# assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-3) - -# def test_paged_attention(self): -# num_blocks = 128 -# dtypes = [torch.bfloat16, torch.float, torch.float16] -# num_gen_seqs = [2] # Arbitrary values for testing -# num_heads = [(40, 40), (64, 16)] # Arbitrary values for testing -# head_sizes = [64, 256] -# block_sizes = [16, 32] -# seeds = [0] -# for ( -# num_seqs, -# num_head, -# head_size, -# block_size, -# dtype, -# seed, -# ) in product( -# num_gen_seqs, -# num_heads, -# head_sizes, -# block_sizes, -# dtypes, -# seeds, -# ): -# pass -# self._test_paged_attention_func( -# num_seqs, -# num_head, -# head_size, -# num_blocks, -# block_size, -# dtype, -# seed, -# ) - - if __name__ == "__main__": - test = unittest.main() \ No newline at end of file + test = unittest.main() From 52d5924f678b3e925cf0872b66f267c95f59b5ea Mon Sep 17 00:00:00 2001 From: "Zhang, Liangang" Date: Thu, 13 Jun 2024 13:16:13 +0800 Subject: [PATCH 10/19] Update test_paged_attention.py --- test/kernel/test_paged_attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/kernel/test_paged_attention.py b/test/kernel/test_paged_attention.py index 498ee01942..b5dc3573f7 100644 --- a/test/kernel/test_paged_attention.py +++ b/test/kernel/test_paged_attention.py @@ -166,7 +166,7 @@ def test_paged_attention_kv_cache(self): num_query_heads = [40] num_key_value_heads = [40, 10, 1] head_dim = [64, 128] - device = ["cpu"] + device = ['cpu'] dtypes = [torch.float, torch.float16, torch.bfloat16] batch_size = [1, 8] beam_size = [1, 4] From 0511027ab207c16b496192a3c2330b778cf9b7c3 Mon Sep 17 00:00:00 2001 From: "Zhang, Liangang" Date: Wed, 19 Jun 2024 10:04:13 +0800 Subject: [PATCH 11/19] Update test/kernel/test_paged_attention.py Co-authored-by: Jiong Gong --- test/kernel/test_paged_attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/kernel/test_paged_attention.py b/test/kernel/test_paged_attention.py index b5dc3573f7..ed9ba1451e 100644 --- a/test/kernel/test_paged_attention.py +++ b/test/kernel/test_paged_attention.py @@ -6,7 +6,7 @@ import torchao from torchao.kv_cache import PagedAttentionCache, PagedTensor -class NiaveCache: +class NaiveCache: def __init__(self): self.past_key = None self.past_value = None From 51ad5cb5f53c36887ce79b8e2a7e5fa5e0709212 Mon Sep 17 00:00:00 2001 From: "Zhang, Liangang" Date: Tue, 4 Jun 2024 22:04:37 -0700 Subject: [PATCH 12/19] Refine code --- test/kernel/test_paged_attention.py | 3 +- torchao/csrc/cpu/paged_attention_kernel.cpp | 73 +++++++++------------ torchao/csrc/paged_attention.cpp | 2 +- torchao/kv_cache.py | 8 --- 4 files changed, 32 insertions(+), 54 deletions(-) diff --git a/test/kernel/test_paged_attention.py b/test/kernel/test_paged_attention.py index ed9ba1451e..f9c7ba2a5d 100644 --- a/test/kernel/test_paged_attention.py +++ b/test/kernel/test_paged_attention.py @@ -71,7 +71,7 @@ def forward(self, inputs, kv_cache): ) return output - +@unittest.skipIf(torch.cuda.is_available(), "CUDA is not enabled yet") class PagedAttentionCachePagedTensorTest(unittest.TestCase): def _test_paged_attention_cache( self, @@ -201,6 +201,5 @@ def test_paged_attention_kv_cache(self): beam_size, ) - if __name__ == "__main__": test = unittest.main() diff --git a/torchao/csrc/cpu/paged_attention_kernel.cpp 
b/torchao/csrc/cpu/paged_attention_kernel.cpp
index 93e6c4af91..cec776dac7 100644
--- a/torchao/csrc/cpu/paged_attention_kernel.cpp
+++ b/torchao/csrc/cpu/paged_attention_kernel.cpp
@@ -215,39 +215,24 @@ void reshape_attn_mask_to_4d(at::Tensor &attn_mask, int64_t batchSize,
 }
 
 /**
- * Performs scale-dot-product for the next token based on cached key-value
- * attention.
- *
- * This function computes the attention weights and applies the attention
- * mechanism to obtain the final output. It takes in tensors representing the
- * query, key cache, value cache, head mapping, scale, block tables, context
- * lengths, block size The
- * output tensor is updated with the computed attention values.
- *
- * @param out Output tensor [num_seqs, 1, num_heads, head_size].
- * @param query Query tensor [num_seqs, 1, num_heads, head_size].
+ * Performs scaled dot-product attention for the next token based on the paged key-value cache.
+ * @param out Output tensor [batch_size, num_heads, 1, head_size].
+ * @param query Query tensor [batch_size, num_heads, 1, head_size].
  * @param key_cache The pre-allocated buffer to store the key cache. The
- * shape should be [num_blocks, block_size, num_heads, head_size].
+ * shape should be [num_blocks, num_heads, block_size, head_size].
  * @param value_cache The pre-allocated buffer to store the value cache. The
- * shape should be [num_blocks, block_size, num_heads, head_size].
- * @param head_mapping Head mapping tensor [num_heads]. The mapping from the
- * query head to the kv head to support GQA/MQA. The shape should be the number
- * of query heads.
+ * shape should be [num_blocks, num_heads, block_size, head_size].
  * @param scale Scaling factor for attention weights. In general, it is:
  * float(1.0 / (head_size ** 0.5)).
- * @param block_tables Block tables tensor [num_seqs, max_num_blocks_per_seq].
- * @param context_lens Context lengths tensor [num_seqs].
- * @param block_size The block size which means the number of token in every
- * block.
- * @param max_context_len Maximum context length.
+ * @param block_tables Block tables tensor [batch_size, max_num_blocks_per_seq].
+ * @param context_lens Context lengths tensor [batch_size].
* @param attn_mask Optional tensor of attention_mask */ template void paged_attention_kernel(at::Tensor &out, at::Tensor &query, at::Tensor &key_cache, at::Tensor &value_cache, - at::Tensor &head_mapping, const double scale, - at::Tensor &block_tables, at::Tensor &context_lens, - int64_t block_size, + const double scale, at::Tensor &block_tables, + at::Tensor &context_lens, c10::optional attn_mask) { using accum_t = at::opmath_type; @@ -255,17 +240,18 @@ void paged_attention_kernel(at::Tensor &out, at::Tensor &query, const auto dtype = query.scalar_type(); const auto accumulate_dtype = at::toOpMathType(dtype); auto max_context_len = context_lens.max().item(); - auto num_seqs = query.size(0); + auto batch_size = query.size(0); auto q_len = query.size(2); auto num_heads = query.size(1); auto head_size = query.size(3); + auto block_size = key_cache.size(2); auto num_kv_heads = key_cache.size(1); auto max_num_blocks_per_seq = block_tables.size(1); auto kv_head_group_size = num_heads / num_kv_heads; bool has_attn_mask = attn_mask.has_value() && attn_mask.value().numel(); if (has_attn_mask) { attn_mask.value() = attn_mask.value().to(at::kFloat); - reshape_attn_mask_to_4d(attn_mask.value(), num_seqs, num_heads, q_len, + reshape_attn_mask_to_4d(attn_mask.value(), batch_size, num_heads, q_len, attn_mask.value().size(-1)); } @@ -298,17 +284,17 @@ void paged_attention_kernel(at::Tensor &out, at::Tensor &query, ? attn_mask.value().stride(1) : 0; int64_t mStrideM = has_attn_mask ? attn_mask.value().stride(2) : 0; - + auto max_num_partitions = (max_context_len + PARTITION_SIZE - 1) / PARTITION_SIZE; - auto max_logits = at::empty({num_seqs, num_heads, max_num_partitions + 1}, + auto max_logits = at::empty({batch_size, num_heads, max_num_partitions + 1}, query.options().dtype(accumulate_dtype)); - auto exp_sum = at::empty({num_seqs, num_heads, max_num_partitions + 1}, + auto exp_sum = at::empty({batch_size, num_heads, max_num_partitions + 1}, query.options().dtype(accumulate_dtype)); - auto tmp_out = at::empty({num_seqs, num_heads, max_num_partitions, head_size}, + auto tmp_out = at::empty({batch_size, num_heads, max_num_partitions, head_size}, query.options().dtype(accumulate_dtype)); auto tmp_out_ptr = tmp_out.data_ptr(); @@ -326,7 +312,7 @@ void paged_attention_kernel(at::Tensor &out, at::Tensor &query, for (auto partition_id = 0; partition_id < max_num_partitions; partition_id++) { for (auto head_id = 0; head_id < num_heads; head_id++) { - for (auto seq_id = 0; seq_id < num_seqs; seq_id++) { + for (auto seq_id = 0; seq_id < batch_size; seq_id++) { auto context_len = context_lens_ptr[seq_id]; auto partition_start = partition_id * PARTITION_SIZE; if (partition_start >= context_len) @@ -347,7 +333,7 @@ void paged_attention_kernel(at::Tensor &out, at::Tensor &query, auto tmp_out_start = tmp_out_ptr + seq_id * tmp_out_strideN + head_id * tmp_out_strideH + partition_id * tmp_out_strideS; - accum_t logits[PARTITION_SIZE] __attribute__((aligned(64))) = {0}; + accum_t alignas(64) logits[PARTITION_SIZE] = {0}; auto logits_position = 0; // 1)calculate the matmul(query, key) for this partition for (auto logical_block_id = logical_block_start; @@ -420,7 +406,7 @@ void paged_attention_kernel(at::Tensor &out, at::Tensor &query, // calculate the final output #pragma omp parallel for collapse(2) - for (auto seq_id = 0; seq_id < num_seqs; seq_id++) { + for (auto seq_id = 0; seq_id < batch_size; seq_id++) { for (auto head_id = 0; head_id < num_heads; head_id++) { auto global_max = -std::numeric_limits::infinity(); auto 
global_exp_sum = 0.0; @@ -491,15 +477,16 @@ void paged_attention_kernel(at::Tensor &out, at::Tensor &query, } // paged_attention_kernel void paged_attention_kernel_impl( - at::Tensor &out, // [num_seqs, 1, num_heads, head_size] - at::Tensor &query, // [num_seqs, 1, num_heads, head_size] - at::Tensor &key_cache, // [num_blocks, block_size, num_heads, head_size] - at::Tensor &value_cache, // [num_blocks, block_size, num_heads, head_size] - at::Tensor &head_mapping, // [num_heads] + at::Tensor &out, // [batch_size, num_heads, 1, head_size] + at::Tensor &query, // [batch_size, num_heads, 1, head_size] + at::Tensor &key_cache, // [num_blocks, num_heads, block_size, head_size] + at::Tensor &value_cache, // [num_blocks, num_heads, block_size, head_size] const double scale, - at::Tensor &block_tables, // [num_seqs, max_num_blocks_per_seq] - at::Tensor &context_lens, // [num_seqs] - int64_t block_size, c10::optional attn_mask) { + at::Tensor &block_tables, // [batch_size, max_num_blocks_per_seq] + at::Tensor &context_lens, // [batch_size] + c10::optional attn_mask) { + TORCH_CHECK(PARTITION_SIZE % key_cache.size(2) == 0, + "Paged attention: The PARTION_SIZE:%d should be divisible by block_size: %d", PARTITION_SIZE, key_cache.size(2)); TORCH_CHECK(query.size(2) == 1, "Paged attention: only seqlen 1 is supported for query"); TORCH_CHECK(query.scalar_type() == key_cache.scalar_type() && @@ -522,8 +509,8 @@ void paged_attention_kernel_impl( AT_DISPATCH_FLOATING_TYPES_AND2( at::kBFloat16, at::kHalf, query.scalar_type(), "paged_attention", [&] { paged_attention_kernel(out, query, key_cache, value_cache, - head_mapping, scale, block_tables, - context_lens, block_size, attn_mask); + scale, block_tables, + context_lens, attn_mask); }); } diff --git a/torchao/csrc/paged_attention.cpp b/torchao/csrc/paged_attention.cpp index 987c3a55fa..f6890efe22 100644 --- a/torchao/csrc/paged_attention.cpp +++ b/torchao/csrc/paged_attention.cpp @@ -6,6 +6,6 @@ TORCH_LIBRARY_FRAGMENT(torchao, m) { m.impl_abstract_pystub("torchao.ops"); m.def( "paged_attention(Tensor (a!)out, Tensor (a!)query, Tensor (a!)key_cache, Tensor (a!)value_cache,\ - Tensor(a!) head_mapping, float scale, Tensor(a!) block_tables, Tensor(a!) context_lens, int block_size,\ + float scale, Tensor(a!) block_tables, Tensor(a!) context_lens, \ Tensor? attn_mask)-> ()"); } \ No newline at end of file diff --git a/torchao/kv_cache.py b/torchao/kv_cache.py index 8a1257f0bd..94921c66a2 100644 --- a/torchao/kv_cache.py +++ b/torchao/kv_cache.py @@ -52,13 +52,7 @@ def scaled_dot_product_attention( query = input key_cache = key_tensor.cache value_cache = value_tensor.cache - num_kv_head = key_cache.size(1) - num_queries_per_kv = query.size(1) // num_kv_head - block_size = key_cache.size(2) block_tables = key_tensor.block_tables - head_mapping = torch.repeat_interleave( - torch.arange(num_kv_head, dtype=torch.int32, device="cpu"), num_queries_per_kv - ) context_lens = key_tensor.context_lens output = torch.empty_like(query) torch.ops.torchao.paged_attention( @@ -66,11 +60,9 @@ def scaled_dot_product_attention( query, key_cache, value_cache, - head_mapping, scale, block_tables, context_lens, - block_size, attn_mask, ) return output From b48b5ff30730ffc417a65b60d6fbb1156c9a168e Mon Sep 17 00:00:00 2001 From: "Zhang, Liangang" Date: Sun, 23 Jun 2024 19:06:17 -0700 Subject: [PATCH 13/19] Update according to the review suggestions. 
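
Reviewer-facing note: after the interface cleanup above, the operator takes only the tensors plus the scale and an optional mask; the block size and the GQA/MQA head grouping are derived inside the kernel from key_cache.size(2) and the query/key head ratio. The snippet below is an illustrative sketch of the expected calling convention, assuming torchao has been built with this CPU kernel; the concrete shapes, dtypes and random data are arbitrary choices and not part of this patch.

    import torch
    import torchao  # registers torch.ops.torchao.paged_attention (CPU)

    batch_size, num_heads, head_size = 2, 8, 64   # query/output layout: [B, H, 1, D]
    num_blocks, block_size = 32, 16               # cache layout: [num_blocks, H_kv, block_size, D]
    max_blocks_per_seq = 4                        # covers up to 64 cached tokens per sequence

    query = torch.randn(batch_size, num_heads, 1, head_size)
    out = torch.empty_like(query)
    key_cache = torch.randn(num_blocks, num_heads, block_size, head_size)
    value_cache = torch.randn(num_blocks, num_heads, block_size, head_size)
    # Maps each sequence's logical blocks to physical blocks inside the caches.
    block_tables = torch.randint(0, num_blocks, (batch_size, max_blocks_per_seq), dtype=torch.int32)
    # Number of cached tokens per sequence.
    context_lens = torch.tensor([48, 64], dtype=torch.int32)
    scale = head_size ** -0.5

    torch.ops.torchao.paged_attention(
        out, query, key_cache, value_cache, scale,
        block_tables, context_lens, None,  # attn_mask is optional
    )

This mirrors the call made from torchao/kv_cache.py after this change; only the operator's own signature is taken from the patch, everything else in the sketch is illustrative.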
--- torchao/csrc/cpu/paged_attention_kernel.cpp | 34 +++++++++++---------- torchao/kv_cache.py | 6 ++-- 2 files changed, 21 insertions(+), 19 deletions(-) diff --git a/torchao/csrc/cpu/paged_attention_kernel.cpp b/torchao/csrc/cpu/paged_attention_kernel.cpp index cec776dac7..a094b9a058 100644 --- a/torchao/csrc/cpu/paged_attention_kernel.cpp +++ b/torchao/csrc/cpu/paged_attention_kernel.cpp @@ -21,7 +21,7 @@ #include #include -#define PARTITION_SIZE 256 +#define SEQ_PARTITION_SIZE 256 namespace torchao { @@ -234,7 +234,9 @@ void paged_attention_kernel(at::Tensor &out, at::Tensor &query, const double scale, at::Tensor &block_tables, at::Tensor &context_lens, c10::optional attn_mask) { - + + TORCH_CHECK(query.size(2) == 1, + "Paged attention: only seqlen 1 is supported for query"); using accum_t = at::opmath_type; using Vec = at::vec::Vectorized; const auto dtype = query.scalar_type(); @@ -285,16 +287,16 @@ void paged_attention_kernel(at::Tensor &out, at::Tensor &query, : 0; int64_t mStrideM = has_attn_mask ? attn_mask.value().stride(2) : 0; - auto max_num_partitions = - (max_context_len + PARTITION_SIZE - 1) / PARTITION_SIZE; + auto max_num_seq_partitions = + (max_context_len + SEQ_PARTITION_SIZE - 1) / SEQ_PARTITION_SIZE; - auto max_logits = at::empty({batch_size, num_heads, max_num_partitions + 1}, + auto max_logits = at::empty({batch_size, num_heads, max_num_seq_partitions + 1}, query.options().dtype(accumulate_dtype)); - auto exp_sum = at::empty({batch_size, num_heads, max_num_partitions + 1}, + auto exp_sum = at::empty({batch_size, num_heads, max_num_seq_partitions + 1}, query.options().dtype(accumulate_dtype)); - auto tmp_out = at::empty({batch_size, num_heads, max_num_partitions, head_size}, + auto tmp_out = at::empty({batch_size, num_heads, max_num_seq_partitions, head_size}, query.options().dtype(accumulate_dtype)); auto tmp_out_ptr = tmp_out.data_ptr(); @@ -309,16 +311,16 @@ void paged_attention_kernel(at::Tensor &out, at::Tensor &query, auto tmp_out_strideH = tmp_out.stride(1); auto tmp_out_strideS = tmp_out.stride(2); #pragma omp parallel for collapse(3) schedule(static, 1) - for (auto partition_id = 0; partition_id < max_num_partitions; + for (auto partition_id = 0; partition_id < max_num_seq_partitions; partition_id++) { for (auto head_id = 0; head_id < num_heads; head_id++) { for (auto seq_id = 0; seq_id < batch_size; seq_id++) { auto context_len = context_lens_ptr[seq_id]; - auto partition_start = partition_id * PARTITION_SIZE; + auto partition_start = partition_id * SEQ_PARTITION_SIZE; if (partition_start >= context_len) continue; auto partition_end = - std::min(partition_start + PARTITION_SIZE, context_len); + std::min(partition_start + SEQ_PARTITION_SIZE, context_len); auto token_num = partition_end - partition_start; auto block_num = (token_num + block_size - 1) / block_size; auto logical_block_start = partition_start / block_size; @@ -333,7 +335,7 @@ void paged_attention_kernel(at::Tensor &out, at::Tensor &query, auto tmp_out_start = tmp_out_ptr + seq_id * tmp_out_strideN + head_id * tmp_out_strideH + partition_id * tmp_out_strideS; - accum_t alignas(64) logits[PARTITION_SIZE] = {0}; + accum_t alignas(64) logits[SEQ_PARTITION_SIZE] = {0}; auto logits_position = 0; // 1)calculate the matmul(query, key) for this partition for (auto logical_block_id = logical_block_start; @@ -411,9 +413,9 @@ void paged_attention_kernel(at::Tensor &out, at::Tensor &query, auto global_max = -std::numeric_limits::infinity(); auto global_exp_sum = 0.0; auto context_len = 
context_lens_ptr[seq_id];
-      auto partition_num = (context_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
+      auto partition_num = (context_len + SEQ_PARTITION_SIZE - 1) / SEQ_PARTITION_SIZE;
       // calculate the global max and exp_sum for this head
-      for (auto partition_id = 0; partition_id < max_num_partitions;
+      for (auto partition_id = 0; partition_id < max_num_seq_partitions;
            partition_id++) {
         if (partition_id >= partition_num)
           break;
@@ -440,7 +442,7 @@ void paged_attention_kernel(at::Tensor &out, at::Tensor &query,
       if (partition_num > 1) {
         for (auto partition_id = 1; partition_id < partition_num;
              partition_id++) {
-          if (partition_id * PARTITION_SIZE >= context_len)
+          if (partition_id * SEQ_PARTITION_SIZE >= context_len)
             break;
           auto tmp_out_start = tmp_out_ptr + seq_id * tmp_out_strideN +
                                head_id * tmp_out_strideH +
@@ -485,8 +487,8 @@ void paged_attention_kernel_impl(
     at::Tensor &block_tables, // [batch_size, max_num_blocks_per_seq]
     at::Tensor &context_lens, // [batch_size]
     c10::optional attn_mask) {
-  TORCH_CHECK(PARTITION_SIZE % key_cache.size(2) == 0,
-              "Paged attention: The PARTION_SIZE:%d should be divisible by block_size: %d", PARTITION_SIZE, key_cache.size(2));
+  TORCH_CHECK(SEQ_PARTITION_SIZE % key_cache.size(2) == 0,
+              "Paged attention: SEQ_PARTITION_SIZE (", SEQ_PARTITION_SIZE, ") should be divisible by the block size (", key_cache.size(2), ")");
   TORCH_CHECK(query.size(2) == 1,
               "Paged attention: only seqlen 1 is supported for query");
   TORCH_CHECK(query.scalar_type() == key_cache.scalar_type() &&
diff --git a/torchao/kv_cache.py b/torchao/kv_cache.py
index 94921c66a2..9f673c9ef7 100644
--- a/torchao/kv_cache.py
+++ b/torchao/kv_cache.py
@@ -10,9 +10,9 @@ class PagedTensor(object):
 
     def __init__(
         self,
-        cache: torch.Tensor,
-        block_tables: torch.Tensor,
-        context_lens: torch.Tensor,
+        cache: torch.Tensor, #The cache tensor from the PagedAttentionCache object, which is shared accross iterations.
+        block_tables: torch.Tensor,#The block tables for each sequence in the batch which is used to mapping logical block to physical blocks.
+        context_lens: torch.Tensor,#The context lens for each sequence in the batch.
): self.block_tables = block_tables self.cache = cache From fcabfcea7fc5f7ae7fe0cda159466de7ca6fc6c3 Mon Sep 17 00:00:00 2001 From: "Zhang, Liangang" Date: Mon, 24 Jun 2024 10:21:26 +0800 Subject: [PATCH 14/19] Update test_ops.py --- test/test_ops.py | 209 ----------------------------------------------- 1 file changed, 209 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index a23a8b3066..2280591c9d 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -112,215 +112,6 @@ def test_fp6_matmul_correctness(self, BS, OC, IC, splitK): relative_error = error / results_fp16.abs() assert relative_error.mean() < 1e-2 - def _create_kv_caches( - self, - num_blocks: int, - block_size: int, - num_layer: int, - num_head: int, - head_size: int, - dtype: torch.dtype, - seed: int, - ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: - torch.random.manual_seed(seed) - torch.manual_seed(seed) - - scale = head_size**-0.5 - key_cache_shape = (num_blocks, block_size, num_head, head_size) - key_caches = [] - for _ in range(num_layer): - key_cache = torch.empty(size=key_cache_shape, dtype=dtype) - key_cache.uniform_(-scale, scale) - key_caches.append(key_cache) - - value_cache_shape = (num_blocks, block_size, num_head, head_size) - value_caches = [] - for _ in range(num_layer): - value_cache = torch.empty(size=value_cache_shape, dtype=dtype) - value_cache.uniform_(-scale, scale) - value_caches.append(value_cache) - return key_caches, value_caches - - def _ref_masked_attention( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - scale: float, - attn_mask: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - attn_weights = torch.einsum("qhd,khd->hqk", query, key).float() - attn_weights = attn_weights * scale - if attn_mask is not None: - attn_weights = attn_weights + attn_mask.float() - attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) - out = torch.einsum("hqk,khd->qhd", attn_weights, value) - return out - - def _ref_paged_attention( - self, - output: torch.Tensor, - query: torch.Tensor, - num_queries_per_kv: int, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - block_tables: torch.Tensor, - context_lens: torch.Tensor, - scale: float, - ) -> None: - num_query_heads = query.shape[1] - num_kv_head = value_cache.shape[2] - head_size = value_cache.shape[3] - block_size = value_cache.shape[1] - num_seqs = query.shape[0] - - block_tables = block_tables.cpu().tolist() - context_lens = context_lens.cpu().tolist() - for i in range(num_seqs): - q = query[i].unsqueeze(0) - block_table = block_tables[i] - context_len = int(context_lens[i]) - - keys = [] - values = [] - for j in range(context_len): - block_number = int(block_table[j // block_size]) - block_offset = j % block_size - - k = key_cache[block_number, block_offset, :, :] - k = k.reshape(num_kv_head, head_size) - keys.append(k) - - v = value_cache[block_number, block_offset, :, :] - values.append(v) - keys = torch.stack(keys, dim=0) - values = torch.stack(values, dim=0) - if num_queries_per_kv > 1: - # Handle MQA and GQA - keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1) - values = torch.repeat_interleave(values, num_queries_per_kv, dim=1) - # out = self._ref_masked_attention(q, keys, values, scale, attn_mask[i]) - out = self._ref_masked_attention(q, keys, values, scale, None) - out = out.view(num_query_heads, head_size) - output[i].copy_(out, non_blocking=True) - - def _test_paged_attention_func( - self, - num_seqs: int, - num_head: Tuple[int, int], - head_size: int, - num_blocks: int, 
- block_size: int, - dtype: torch.dtype, - seed: int, - ) -> None: - random.seed(seed) - torch.random.manual_seed(seed) - torch.manual_seed(seed) - max_seq_len = 512 - scale = float(1.0 / (head_size**0.5)) - num_query_heads, num_kv_head = num_head - query = torch.empty( - num_seqs, num_query_heads, head_size, dtype=dtype, device="cpu" - ) - query.uniform_(-scale, scale) - assert num_query_heads % num_kv_head == 0 - num_queries_per_kv = num_query_heads // num_kv_head - head_mapping = torch.repeat_interleave( - torch.arange(num_kv_head, dtype=torch.int32, device="cpu"), - num_queries_per_kv, - ) - attn_mask = None - context_lens = [random.randint(1, max_seq_len) for _ in range(num_seqs)] - context_lens[-1] = max_seq_len - max_context_len = max_seq_len # max(context_lens) - attn_mask = torch.zeros(num_seqs, 1, 1, max_context_len, dtype=dtype) - for i in range(num_seqs): - attn_mask[i, :, :, context_lens[i] :].fill_(-10000.0) - paded_context_lens = torch.tensor( - [max_context_len for _ in range(num_seqs)] - ).to(torch.int32) - context_lens = torch.tensor(context_lens, dtype=torch.int, device="cpu") - - # Create the block tables.NUM_PREFILL_SEQS - max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size - block_tables = [] - for _ in range(num_seqs): - block_table = [ - random.randint(0, num_blocks - 1) for _ in range(max_num_blocks_per_seq) - ] - block_tables.append(block_table) - block_tables = torch.tensor(block_tables, dtype=torch.int, device="cpu") - - # Create the KV caches. - key_caches, value_caches = self._create_kv_caches( - num_blocks, block_size, 1, num_kv_head, head_size, dtype, seed - ) - key_cache, value_cache = key_caches[0], value_caches[0] - - output = torch.empty_like(query) - torch.ops.torchao.paged_attention( - output, - query.unsqueeze(1), - key_cache, - value_cache, - head_mapping, - scale, - block_tables, - paded_context_lens, - block_size, - attn_mask, - ) - - # Run the reference implementation. 
- ref_output = torch.empty_like(query) - self._ref_paged_attention( - ref_output, - query, - num_queries_per_kv, - key_cache, - value_cache, - block_tables, - context_lens, - scale, - ) - assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-3) - - @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") - def test_paged_attention(self): - num_blocks = 128 - dtypes = [torch.bfloat16, torch.float, torch.float16] - num_gen_seqs = [2] # Arbitrary values for testing - num_heads = [(40, 40), (64, 16)] # Arbitrary values for testing - head_sizes = [64, 256] - block_sizes = [16, 32] - seeds = [0] - for ( - num_seqs, - num_head, - head_size, - block_size, - dtype, - seed, - ) in product( - num_gen_seqs, - num_heads, - head_sizes, - block_sizes, - dtypes, - seeds, - ): - self._test_paged_attention_func( - num_seqs, - num_head, - head_size, - num_blocks, - block_size, - dtype, - seed, - ) - - if __name__ == "__main__": unittest.main() From 11d4f47f683105c4412e4324e46960422c5b9308 Mon Sep 17 00:00:00 2001 From: "Zhang, Liangang" Date: Mon, 24 Jun 2024 10:21:47 +0800 Subject: [PATCH 15/19] Update test_paged_attention.py --- test/kernel/test_paged_attention.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/kernel/test_paged_attention.py b/test/kernel/test_paged_attention.py index f9c7ba2a5d..5cedb4fc43 100644 --- a/test/kernel/test_paged_attention.py +++ b/test/kernel/test_paged_attention.py @@ -1,7 +1,6 @@ import torch import unittest import random -from typing import List, Optional, Tuple from itertools import product import torchao from torchao.kv_cache import PagedAttentionCache, PagedTensor From ad7caaa480cf2ee6a57a189fbfbb96595ea5e9fc Mon Sep 17 00:00:00 2001 From: "Zhang, Liangang" Date: Sun, 23 Jun 2024 19:24:01 -0700 Subject: [PATCH 16/19] Remove redundant test --- test/test_ops.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 2280591c9d..c5f66926fb 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1,11 +1,9 @@ import torch from torch.testing._internal.common_utils import TestCase, IS_FBCODE from torch.testing._internal.optests import opcheck -from typing import List, Optional, Tuple import torchao from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4 import unittest -import random from parameterized import parameterized import pytest @@ -13,6 +11,7 @@ import torchao.ops except RuntimeError: pytest.skip("torchao.ops not available") + # torch.testing._internal.optests.generate_tests.OpCheckError: opcheck(op, ...): # test_faketensor failed with module 'torch' has no attribute '_custom_ops' (scroll up for stack trace) @@ -111,7 +110,7 @@ def test_fp6_matmul_correctness(self, BS, OC, IC, splitK): error = (results_fp6 - results_fp16).abs() relative_error = error / results_fp16.abs() assert relative_error.mean() < 1e-2 - + if __name__ == "__main__": unittest.main() From 1daba57d326184f6dfe3f809cf6248a5fd11bd88 Mon Sep 17 00:00:00 2001 From: "Zhang, Liangang" Date: Sun, 23 Jun 2024 19:25:13 -0700 Subject: [PATCH 17/19] update --- test/test_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_ops.py b/test/test_ops.py index c5f66926fb..1e3b2a5ee3 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -11,7 +11,7 @@ import torchao.ops except RuntimeError: pytest.skip("torchao.ops not available") - + # torch.testing._internal.optests.generate_tests.OpCheckError: opcheck(op, ...): # test_faketensor failed with module 'torch' has no attribute '_custom_ops' (scroll up 
for stack trace) From 93c5a7d1ff8e370e1239e7e054349fdb388174ff Mon Sep 17 00:00:00 2001 From: "Zhang, Liangang" Date: Sun, 23 Jun 2024 19:26:04 -0700 Subject: [PATCH 18/19] update format --- test/test_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_ops.py b/test/test_ops.py index 1e3b2a5ee3..b20e029380 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -11,7 +11,7 @@ import torchao.ops except RuntimeError: pytest.skip("torchao.ops not available") - + # torch.testing._internal.optests.generate_tests.OpCheckError: opcheck(op, ...): # test_faketensor failed with module 'torch' has no attribute '_custom_ops' (scroll up for stack trace) From 2e4c0486dc79fa6cb09032c67fdeed8f4fda0d49 Mon Sep 17 00:00:00 2001 From: "Zhang, Liangang" Date: Wed, 4 Sep 2024 01:48:11 -0700 Subject: [PATCH 19/19] Enable subclassing for paged attention design --- test/kernel/test_paged_attention.py | 10 +++---- torchao/kv_cache.py | 44 +++++++++++++++-------------- 2 files changed, 28 insertions(+), 26 deletions(-) diff --git a/test/kernel/test_paged_attention.py b/test/kernel/test_paged_attention.py index 5cedb4fc43..6892f1ec35 100644 --- a/test/kernel/test_paged_attention.py +++ b/test/kernel/test_paged_attention.py @@ -58,17 +58,17 @@ def forward(self, inputs, kv_cache): 1, 2 ) updated_key, updated_value = kv_cache.update(key, value, 0) - if isinstance(updated_key, torch.Tensor): + if not isinstance(updated_key, PagedTensor): updated_key = updated_key.repeat_interleave( self.num_heads // self.num_kv_heads, dim=1 ) updated_value = updated_value.repeat_interleave( self.num_heads // self.num_kv_heads, dim=1 ) - output = torch.nn.functional.scaled_dot_product_attention( + return torch.nn.functional.scaled_dot_product_attention( query, updated_key, updated_value, scale=self.scale ) - return output + @unittest.skipIf(torch.cuda.is_available(), "CUDA is not enabled yet") class PagedAttentionCachePagedTensorTest(unittest.TestCase): @@ -89,7 +89,7 @@ def _test_paged_attention_cache( mha_model = MHAModule(head_dim, num_query_heads, num_key_value_heads).to( device=device, dtype=dtype ) - naive_cache = NiaveCache() + naive_cache = NaiveCache() pagedcache = PagedAttentionCache( num_blocks, block_size, @@ -166,7 +166,7 @@ def test_paged_attention_kv_cache(self): num_key_value_heads = [40, 10, 1] head_dim = [64, 128] device = ['cpu'] - dtypes = [torch.float, torch.float16, torch.bfloat16] + dtypes = [torch.bfloat16, torch.float16] batch_size = [1, 8] beam_size = [1, 4] for ( diff --git a/torchao/kv_cache.py b/torchao/kv_cache.py index 9f673c9ef7..528e71f587 100644 --- a/torchao/kv_cache.py +++ b/torchao/kv_cache.py @@ -7,44 +7,46 @@ HANDLED_FUNCTIONS = {} -class PagedTensor(object): +class PagedTensor(torch.Tensor): + @staticmethod + def __new__(cls, size, cache, block_table, *args, **kwargs): + return torch.Tensor._make_wrapper_subclass(cls, size, dtype=cache.dtype, *args, **kwargs) + def __init__( self, - cache: torch.Tensor, #The cache tensor from the PagedAttentionCache object, which is shared accross iterations. + size: Tuple[int, int, int, int],#The size of the cached tensor[bs, num_key_value_heads, seq_lens, head_dim]. + cache: torch.Tensor, #The cache tensor from the PagedAttentionCache object, which is shared accross iterations. block_tables: torch.Tensor,#The block tables for each sequence in the batch which is used to mapping logical block to physical blocks. - context_lens: torch.Tensor,#The context lens for each sequence in the batch. 
): self.block_tables = block_tables - self.cache = cache - self.context_lens = context_lens + self.cache = cache def __repr__(self): - return f"PagedTensor({self.cache.shape})" + return f"PagedTensor(buffer shape: {self.cache.shape}, k/v cache shape:{self.shape}" + @staticmethod + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + return NotImplemented + @classmethod - def __torch_function__(cls, func, types, args=(), kwargs=None): - if kwargs is None: - kwargs = {} + def __torch_function__(cls, func, types, args, kwargs=None): if func not in HANDLED_FUNCTIONS or not all( - issubclass(t, (torch.Tensor, PagedTensor)) for t in types + issubclass(t, (torch.Tensor, PagedTensor)) + for t in types ): - return NotImplementedError( - "{} is not supported by PagedTensor".format(func) - ) + if kwargs is None: + kwargs = {} + return super().__torch_function__(func, types, args, kwargs) return HANDLED_FUNCTIONS[func](*args, **kwargs) - def implements(torch_function): - """Register a torch function override for PagedTensor""" - + """Register a torch function override for ScalarTensor""" def decorator(func): functools.update_wrapper(func, torch_function) HANDLED_FUNCTIONS[torch_function] = func return func - return decorator - @implements(torch.nn.functional.scaled_dot_product_attention) def scaled_dot_product_attention( input, key_tensor, value_tensor, attn_mask=None, scale=None @@ -53,7 +55,7 @@ def scaled_dot_product_attention( key_cache = key_tensor.cache value_cache = value_tensor.cache block_tables = key_tensor.block_tables - context_lens = key_tensor.context_lens + context_lens = torch.tensor([key_tensor.shape[2] for _ in range(key_tensor.shape[0])], dtype=torch.int32) output = torch.empty_like(query) torch.ops.torchao.paged_attention( output, @@ -382,8 +384,8 @@ def update( block_tables_t, dtype=torch.int32, device=self.device ) return PagedTensor( - self.key_caches[layer_idx], block_tables_t, context_lens - ), PagedTensor(self.value_caches[layer_idx], block_tables_t, context_lens) + (batch_size, self.num_key_value_heads, context_lens[0].item(), self.head_dim), self.key_caches[layer_idx], block_tables_t, + ), PagedTensor((batch_size, self.num_key_value_heads, context_lens[0].item(), self.head_dim), self.value_caches[layer_idx], block_tables_t) def reorder_cache(self, beam_idx: torch.Tensor) -> None: """