import queue
import functools
from typing import List, Tuple

import jax
import jax.numpy as jnp
import jax.sharding as jsharding


class PageAttentionManager:
  """Manages page blocks.

  This manager maintains a main list of free page blocks and supports:
    1. Reserving pages for prefill insert and decode.
    2. Freeing a slot's pages after decode; freed page indices return to the
       free list.
    3. Getting page-index metadata for all slots.
    4. Transforming and inserting prefill caches into decode caches.
  """

  def __init__(
      self,
      batch_size: int,
      total_num_pages: int,
      page_size: int,
      max_pages_per_sequence: int,
  ):
    self.unused_pages = queue.Queue()
    self.batch_size = batch_size
    # page_indices[slot, i] is the i-th page id assigned to the slot; -1 marks
    # an unassigned entry.
    self.page_indices = jnp.full(
        (batch_size, max_pages_per_sequence), -1, dtype=jnp.int32
    )
    self.lengths = jnp.zeros(batch_size, dtype=jnp.int32)
    self.page_size = page_size
    self.max_pages_per_sequence = max_pages_per_sequence
    # Every page starts out free.
    for i in range(total_num_pages):
      self.unused_pages.put(i, block=False)

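  # Illustrative example (toy sizes, not part of the original code): a manager
  # built with total_num_pages=8 and page_size=4 starts with page ids 0..7 in
  # the free queue and every entry of page_indices set to -1.
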
  # pylint: disable-next=all
  def reserve_pages_insert(self, slot: int, seq_len: int) -> int:
    """Reserves enough pages for seq_len tokens and returns the page count."""
    self.lengths = self.lengths.at[slot].set(seq_len)
    num_pages = seq_len // self.page_size
    if seq_len % self.page_size != 0:
      num_pages = num_pages + 1

    indices = [self.unused_pages.get(block=False) for _ in range(num_pages)]
    self.page_indices = self.page_indices.at[slot, :num_pages].set(indices)
    return num_pages

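  # Worked example (illustrative, assuming page_size=4): seq_len=10 needs
  # ceil(10 / 4) = 3 pages, so three indices are popped from the free queue
  # and written to page_indices[slot, 0:3].
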
  # pylint: disable-next=all
  def reserve_pages_decode(self, slot: int, seq_len: int):
    # When seq_len is a positive multiple of page_size, the next decoded token
    # starts a new page, so reserve one more page for the slot.
    if seq_len > 0 and seq_len % self.page_size == 0:
      index = self.unused_pages.get(block=False)
      num_pages = seq_len // self.page_size
      self.page_indices = self.page_indices.at[slot, num_pages].set(index)

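  # Worked example (illustrative, assuming page_size=4): a slot at seq_len=8
  # has exactly filled pages 0 and 1, so a fresh page is reserved at
  # page_indices[slot, 2] to receive token position 8.
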
  # pylint: disable-next=all
  def prefill_cache_padding(
      self,
      caches: List[Tuple[jax.Array, jax.Array]],
      seq_len: int,
      num_pages: int,
  ) -> List[Tuple[jax.Array, jax.Array]]:
    """Pads prefill KV caches to a whole number of pages along seq_len."""
    pad_width = num_pages * self.page_size - seq_len
    if pad_width == 0:
      return caches

    caches = [
        (self.pad_sequences(k, pad_width), self.pad_sequences(v, pad_width))
        for k, v in caches
    ]
    return caches

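  # Illustrative example (assuming page_size=4): a prefill of seq_len=10 with
  # 3 reserved pages is padded by 3 * 4 - 10 = 2 positions, so each K and V
  # grows from 10 to 12 along the sequence axis.
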
  def insert_prefill_cache(
      self,
      prefill_caches: List[Tuple[jax.Array, jax.Array]],
      decode_caches: List[Tuple[jax.Array, jax.Array]],
      slot: int,
      seq_len: int,
      sharding: jsharding.Sharding,
  ) -> List[Tuple[jax.Array, jax.Array]]:
    """Inserts prefill caches into the decode caches at the given slot.

    Args:
      prefill_caches: List of (K, V) tuples; each K, V is a
        [batch_size, num_heads, seq_len, head_dim] jax.Array.
      decode_caches: List of (K, V) tuples; each K, V is a
        [num_heads, total_num_pages, page_size, head_dim] jax.Array.
      slot: Batch slot in decode to insert into.
      seq_len: Prefill token sequence length.
      sharding: Decode cache sharding.

    Returns:
      Decode caches: List of (K, V) tuples; each K, V is a
        [num_heads, total_num_pages, page_size, head_dim] jax.Array.
    """

    num_pages = self.reserve_pages_insert(slot, seq_len)
    padded_caches = self.prefill_cache_padding(
        prefill_caches, seq_len, num_pages
    )
    # Drop the cache batch dimension: [kv_heads, seq_len, dim]
    squeezed_caches = [
        (jnp.squeeze(k, axis=0), jnp.squeeze(v, axis=0))
        for k, v in padded_caches
    ]
    kv_heads, _, dim = squeezed_caches[0][0].shape
    # Split the sequence axis into pages: [kv_heads, num_pages, page_size, dim]
    paged_caches = [
        (
            jnp.reshape(k, (kv_heads, -1, self.page_size, dim)),
            jnp.reshape(v, (kv_heads, -1, self.page_size, dim)),
        )
        for k, v in squeezed_caches
    ]
    update_indexes = self.page_indices[slot, :num_pages]

    @functools.partial(jax.jit, donate_argnums=(0, 1), inline=True)
    def insert(cache, new_entry):
      # new_entry is [kv_heads, num_pages, page_size, dim]; scatter it into the
      # reserved pages of the decode cache and keep the cache sharding.
      res = cache.at[:, update_indexes, :, :].set(new_entry)
      res = jax.lax.with_sharding_constraint(res, sharding)
      return res

    caches = [
        (insert(k, newk), insert(v, newv))
        for (k, v), (newk, newv) in zip(decode_caches, paged_caches)
    ]

    return caches

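  # Shape walk-through (illustrative values, not from the original code): with
  # a prefill K of [1, 8, 10, 128] and page_size=4, the cache is padded to
  # [1, 8, 12, 128], squeezed to [8, 12, 128], reshaped to [8, 3, 4, 128], and
  # scattered into the 3 reserved pages of the decode cache.
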
  # pylint: disable-next=all
  def get_page_token_indices(self, lens: jax.Array) -> jax.Array:
    """Returns page and token write indices for the next token of each slot.

    The result stacks four rows over the active slots (those whose current
    page index is valid): the page index to update, the token position offset
    by page_size per active slot, the slot id, and the incremented length.
    """
    assert lens.shape == (
        self.batch_size,
        1,
    ), f"lens shape {lens.shape} does not equal ({self.batch_size}, 1)"
    update_page_indices = []
    token_scale_indices = []
    batch_slots = []
    new_lens = []
    offset = 0
    for slot in range(self.batch_size):
      seq_len = lens[slot][0]
      # The next token lands on page (seq_len // page_size) at position
      # (seq_len % page_size) within that page.
      num_pages = seq_len // self.page_size + 1
      token_pos = seq_len % self.page_size
      page_index = self.page_indices[slot, num_pages - 1]
      if page_index < 0:
        continue
      update_page_indices.append(page_index)
      token_scale_indices.append(offset + token_pos)
      batch_slots.append(slot)
      new_lens.append(seq_len + 1)
      offset += self.page_size
    return jnp.stack(
        (
            jnp.asarray(update_page_indices),
            jnp.asarray(token_scale_indices),
            jnp.asarray(batch_slots),
            jnp.asarray(new_lens),
        )
    )

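  # Illustrative result (assuming page_size=4 and two slots with lengths 5 and
  # 8, both with pages reserved): row 0 holds the page indices being written,
  # row 1 is [1, 4], row 2 is [0, 1] (the slot ids), and row 3 is [6, 9].
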
  # pylint: disable-next=all
  def fill_new_pages(self, lens: jax.Array):
    # Reserve a fresh page for every slot whose sequence just filled a page.
    for slot in range(self.batch_size):
      self.reserve_pages_decode(slot, lens[slot])

  # pylint: disable-next=all
  def pad_sequences(self, array, pad_width=10):
    padding_config = [
        (0, 0),
        (0, 0),
        (0, pad_width),
        (0, 0),
    ]  # Pad only along the seq_len axis
    padded_array = jnp.pad(array, padding_config, mode="constant")
    return padded_array

  # pylint: disable-next=all
  def free_pages_resource(self, slot):
    # Return every page assigned to the slot to the free queue, then clear the
    # slot's page indices.
    for i in range(self.max_pages_per_sequence):
      index = self.page_indices[slot, i]
      if index < 0:
        break
      self.unused_pages.put(index, block=False)

    self.page_indices = self.page_indices.at[slot, :].set(-1)
    return None
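

# Minimal usage sketch (illustrative only, not part of the original change);
# it assumes toy sizes and does not exercise the cache-insertion path.
if __name__ == "__main__":
  manager = PageAttentionManager(
      batch_size=2, total_num_pages=16, page_size=4, max_pages_per_sequence=8
  )
  # Reserve pages for a prefill of 10 tokens in slot 0: ceil(10 / 4) = 3 pages.
  reserved = manager.reserve_pages_insert(slot=0, seq_len=10)
  print("pages reserved:", reserved)
  print("page indices for slot 0:", manager.page_indices[0])
  # Slot 0 later reaches 12 tokens (a page boundary), so a new page is
  # reserved; once the sequence finishes, its pages go back to the free queue.
  manager.fill_new_pages(jnp.array([[12], [0]]))
  manager.free_pages_resource(0)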