Add a new Model
===============

This doc is a detailed guide on how to add a new model to jetstream-pytorch.
The complexity of adding a new model depends highly on the model architecture itself,
and right now it is a manual process.

NOTE: Only LLMs that use autoregressive decoding with a KV cache are suitable
for serving with Jetstream. Other models, such as Stable Diffusion, are NOT a good fit
for the optimization techniques used in Jetstream.

The core part of adding a model is letting the Jetstream serving engine manage
the KV cache. This management is abstracted by [`class CacheInterface`](jetstream_pt/cache_manager.py). This interface has a single `update` method that abstracts
the act of inserting into and then reading from the cache.
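
For orientation, the shape of the interface is roughly the following. This is only a sketch; see `jetstream_pt/cache_manager.py` for the real classes and their constructors:

```python
import torch


class CacheInterface:
  """Sketch of the cache abstraction that models are written against."""

  def update(self, key: torch.Tensor, value: torch.Tensor):
    """Inserts the new key/value slices for the current positions and returns
    the full key and value caches that attention should read from."""
    raise NotImplementedError
```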

We will walk through this process using the [Gemma model](https://github.com/google/gemma_pytorch) as an example.

# Step 0: Get the model code

Jetstream-pytorch stores its models in the `jetstream_pt/third_party` directory.

The usual convention is:

1. Make a verbatim copy of the model code and supporting files
   (such as the args class, tokenizers, etc.) in a separate directory. In our case
   that is [jetstream_pt/third_party/gemma](jetstream_pt/third_party/gemma).

2. Make a copy of `model.py` named `model_original.py`, because we will be modifying
   `model.py` to follow Jetstream's conventions, and keeping the original helps with
   debugging accuracy (and with unit tests).

*Optional:* Clean up the model implementation. The easiest models to port are
 "reference implementations". Models that already contain optimizations and/or custom
 CUDA kernels need to have those changes removed.

In our case, we chose the reference Gemma model from Google's GitHub instead
of the HuggingFace version, because the HuggingFace version also contains training code
that would need to be removed.

# Step 1: Modify the model to fit the calling convention expected by Jetstream

The model that Jetstream expects and calls follows this calling convention:

```python
class Model(torch.nn.Module):

  def forward(
      self,
      tokens: torch.Tensor,
      input_pos: torch.Tensor,
      caches: List[CacheInterface],
      mask: torch.Tensor,
  ) -> torch.Tensor:
    ...
```

The arguments are:

* `tokens`: An int tensor of shape (batch_size, sequence_length). These are the token IDs
  before embedding.

* `input_pos`: The positions of the tokens in the overall sentence. This is an int
  tensor of shape (batch_size, sequence_length). Note: due to continuous batching,
  not all batch entries have the same sequence length.

* `caches`: A list of objects implementing the `CacheInterface`. `CacheInterface` has a
  single `update` method.

* `mask`: The mask used in causal attention.

The return value should be a tensor of shape (batch_size, sequence_length, vocab_size)
of **logits** (not probabilities) for the next token.

### Gemma example:

Now, looking back at our Gemma reference model, there are 2 classes in the original
model that could serve as our model: [GemmaModel](https://github.com/google/gemma_pytorch/blob/main/gemma/model.py#L353) and [GemmaForCausalLM](https://github.com/google/gemma_pytorch/blob/main/gemma/model.py#L386). Looking at their forward method signatures:

```python

class GemmaModel(nn.Module):
    def forward(
        self,
        hidden_states: torch.Tensor,
        freqs_cis: torch.Tensor,
        kv_write_indices: torch.Tensor,
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
        mask: torch.Tensor,
    ) -> torch.Tensor:

class GemmaForCausalLM(nn.Module):
    @torch.no_grad()
    def forward(
        self,
        input_token_ids: torch.Tensor,
        input_positions: torch.Tensor,
        kv_write_indices: torch.Tensor,
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
        mask: torch.Tensor,
        output_positions: torch.Tensor,
        temperatures: Union[torch.Tensor, None],
        top_ps: torch.Tensor,
        top_ks: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
```

we can see that `GemmaModel` is probably the closest one to port, so we choose it.
However, there are a few issues:

1. `GemmaModel` takes `hidden_states` instead of tokens.
2. `GemmaModel` returns `hidden_states` after the layers, not logits.

Let's fix those first.

Looking at where `GemmaModel` is called in `model.py`, we find:

```python
    # [batch_size, input_len, hidden_size]
    hidden_states = self.embedder(input_token_ids)
    # Gemma normalizes the embedding by sqrt(hidden_size).
    hidden_states = hidden_states * (self.config.hidden_size**0.5)
```

So the input tokens are embedded with `self.embedder` and scaled before `GemmaModel`
is called. Let's move these bits inside `GemmaModel`.
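
After the move, the top of the modified `GemmaModel.forward` looks roughly like this (a sketch; it assumes the `embedder` module is also moved onto `GemmaModel`, and the names follow the original file):

```python
    def forward(self, tokens, input_pos, caches, mask):
        # Embedding moved in from GemmaForCausalLM.forward.
        # [batch_size, input_len, hidden_size]
        hidden_states = self.embedder(tokens)
        # Gemma normalizes the embedding by sqrt(hidden_size).
        hidden_states = hidden_states * (self.config.hidden_size**0.5)
        ...
```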

Now, looking at where the output of `GemmaModel` is consumed, we see that it is fed to `self.sampler`.

`self.sampler` is an instance of the `Sampler` class, and its `forward` contains:

```python
    hidden_states = hidden_states.index_select(
        1, output_positions).squeeze(dim=1)
    logits = torch.matmul(hidden_states, embedding.t())
    if embedding_bias is not None:
        logits += embedding_bias

    if temperatures is None:
        return torch.argmax(logits, dim=-1).squeeze(dim=-1)
    ...
```

It performs some math on the hidden states to produce logits, which is what
`GemmaModel` should return. So let's move these bits into `GemmaModel` as well.
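
At the end of the modified `GemmaModel.forward`, that looks roughly like the sketch below. It assumes `self.embedder` is an `nn.Embedding` whose weight doubles as the output projection, which is what `Sampler` was doing with its `embedding` argument. Note that the `index_select(1, output_positions)` is dropped, because Jetstream expects logits for every position:

```python
        # After the decoder layers and the final norm:
        hidden_states = self.norm(hidden_states)
        # Reuse the (tied) embedding matrix as the output projection.
        logits = torch.matmul(hidden_states, self.embedder.weight.t())
        return logits
```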

Lastly, `GemmaModel` takes a list of tuples of `torch.Tensor` as its caches;
we need to replace that with the cache object.

This cache is plumbed all the way through to `GemmaAttention`, and the [following lines](https://github.com/google/gemma_pytorch/blob/main/gemma/model.py#L264C1-L268C53):

```python
    # Write new kv cache.
    # [batch_size, input_len, n_local_kv_heads, head_dim]
    k_cache, v_cache = kv_cache
    k_cache.index_copy_(1, kv_write_indices, xk)
    v_cache.index_copy_(1, kv_write_indices, xv)
```

are precisely the cache update.
So we need to replace those lines with:

```python
xk = xk.transpose(1, 2)
xv = xv.transpose(1, 2)
k_cache, v_cache = cache.update(xk, xv)
```

The transpose is needed because the cache interface's `update` method expects the
shape (batch, num_heads, sequence_length, head_dim) instead of the
(batch, sequence_length, num_heads, head_dim) that `GemmaAttention` produces.
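
For context, here is a rough sketch of how the returned `k_cache` and `v_cache` feed into the rest of the attention computation. This is standard scaled dot-product attention with names following `GemmaAttention`; it is simplified compared to the original code and ignores grouped-query head broadcasting:

```python
        xq = xq.transpose(1, 2)  # (batch, num_heads, seq_len, head_dim)
        xk = xk.transpose(1, 2)
        xv = xv.transpose(1, 2)
        # Insert the new entries and get back the full caches to attend over.
        k_cache, v_cache = cache.update(xk, xv)

        scores = torch.matmul(xq, k_cache.transpose(2, 3)) / self.head_dim**0.5
        scores = scores + mask  # causal mask passed down from the model's forward
        scores = torch.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, v_cache)
```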

In our case, because the attention math is the standard one, we can just call out
to the `AttentionKernel` defined in [layers.py](jetstream_pt/layers.py). `AttentionKernel`
also handles reading and writing the `cache` automatically.

At this point, the model should be runnable. However, to run it on a realistic TPU,
we need to add model parallelism.

# Step 2: Add model parallelism

Model parallelism is often necessary to run on TPUs. The typical setup for running
inference workloads is a TPU `v5e-8` (v5litepod-8), which has 8 TPU chips with 16GB of
high-bandwidth memory (HBM) each. A typical `7B` model won't fit on a single chip.

So we need to add model parallelism so that the model weights are sharded among the 8 devices.
This is also necessary for larger models, such as 70B models, even on chips with more memory (v5p),
so it's good practice to do it right away.

Jetstream uses GSPMD for tensor parallelism; the only information we need to
give it is, for every weight tensor, which axis to shard. We do so by writing
a sharding config file.

## Generate a sharding config

The keys of the sharding file are the weight names (with numeric layer indices replaced by `*`),
and the values are the axes to shard.
For Gemma, we can generate such a file by printing out the keys of its `state_dict`.
See [create_empty_sharding_map.py](scripts/create_empty_sharding_map.py) for an example.
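
The core of such a script is roughly the following (a sketch, not the actual contents of `create_empty_sharding_map.py`): instantiate the model (random weights are fine), collapse the layer index in each weight name to `*`, and print one replicated (`-1`) entry per unique name, along with its dtype and shape.

```python
import re

import torch


def dump_empty_sharding_map(model: torch.nn.Module) -> None:
  """Prints a 'replicate everything' sharding map for the given model."""
  seen = set()
  for name, weight in model.state_dict().items():
    key = re.sub(r"\.\d+\.", ".*.", name)  # e.g. layers.3.mlp -> layers.*.mlp
    if key in seen:
      continue
    seen.add(key)
    print(f"{key} : -1  # {weight.dtype} {tuple(weight.shape)}")
```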

The filled-in sharding config for Gemma looks like this:

```yaml
freqs_cis : -1 # torch.complex64 (16384, 128)
layers.*.self_attn.qkv_proj.weight: 0
layers.*.self_attn.o_proj.weight: 1
layers.*.self_attn.wo.weight : 1 # torch.float32 (2048, 2048)
layers.*.self_attn.wq.weight : 0 # torch.float32 (2048, 2048)
layers.*.self_attn.wk.weight : 0 # torch.float32 (256, 2048)
layers.*.self_attn.wv.weight : 0 # torch.float32 (256, 2048)
layers.*.mlp.gate_proj.weight : 0 # torch.float32 (16384, 2048)
layers.*.mlp.gate_proj.bias : 0 # torch.float32 (16384,)
layers.*.mlp.up_proj.weight : 0 # torch.float32 (16384, 2048)
layers.*.mlp.up_proj.bias : 0 # torch.float32 (16384,)
layers.*.mlp.down_proj.weight : 1 # torch.float32 (2048, 16384)
layers.*.mlp.down_proj.bias : -1 # torch.float32 (2048,)
layers.*.input_layernorm.weight : -1 # torch.float32 (2048,)
layers.*.post_attention_layernorm.weight : -1 # torch.float32 (2048,)
norm.weight : -1 # torch.float32 (2048,)
embedder.weight : 1 # torch.float32 (256000, 2048)
```

Here, the weights `layers.*.self_attn.qkv_proj.weight` (where `*` ranges over the 28 layers)
are sharded along dimension 0 (0-based indexing), and so on; `-1` means "replicated".

Theoretically, any valid sharding would work. To find a sharding that performs well, one
can usually get hints from the original model implementation.

For example, in the case of Gemma, the authors also provide a TPU version: https://github.com/google/gemma_pytorch/blob/main/gemma/model_xla.py

In that file, layers implemented with `ColumnParallelLinear` should be sharded on dimension 0,
those with `RowParallelLinear` should be sharded on dimension 1, and the others should be
replicated.

# Step 3: Activation sharding and quantization

Sometimes we would like to specify shardings for activations, because GSPMD cannot
always infer all of the shardings.

The typical example happens after a reshape: if a matrix of shape [A, B * C] has its
second dim sharded and is reshaped to shape [A, B, C], the compiler knows that one of
the dims B or C is sharded, but cannot know which one.
In this case, it is helpful to specify a sharding constraint.

This is done by calling `env.apply_sharding(tensor, axis=1)` on the tensor.

The `env` object is an instance of the `Environment` class, which is passed to the
model's constructor. It also contains some common configuration (such as whether the
user wants quantization) that is useful to the model.

So we store it in `self.env` and use it wherever needed.
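
For example, a minimal sketch of an activation sharding constraint after an ambiguous reshape; the axis value depends on which dimension you know is sharded, and here we assume it is the heads dimension after splitting `hidden_size` into `num_heads * head_dim`:

```python
        # (batch, seq_len, num_heads * head_dim) -> (batch, seq_len, num_heads, head_dim)
        xq = xq.reshape(batch_size, seq_len, self.num_heads, self.head_dim)
        # Tell GSPMD which of the two new dims carries the sharding.
        self.env.apply_sharding(xq, axis=2)
```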

For quantization, it suffices to swap the `nn.Linear` layers for the `Int8QuantizedLinear`
layer defined in `layers.py`, along the lines of the sketch below.
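
A hypothetical constructor fragment; the exact name of the quantization flag on `env` is an assumption here, so check `environment.py` and `layers.py` for the real names:

```python
        # Pick the linear implementation based on the environment's quantization setting.
        # NOTE: `self.env.quantize_weights` is an illustrative attribute name.
        Linear = (
            layers.Int8QuantizedLinear if self.env.quantize_weights else torch.nn.Linear
        )
        self.qkv_proj = Linear(hidden_size, (num_heads + 2 * num_kv_heads) * head_dim)
```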

# Step 4: Wiring everything up

The last step is to modify [engine.py](https://github.com/google/jetstream-pytorch/blob/main/jetstream_pt/engine.py#L738)
and add an if branch in that function.

The function receives information about the model name and size; here it should
instantiate the model object itself. It also needs to tell the environment
what caches to allocate: notably, how many layers there are and the shape of each
cache. The shape is expected to be (batch_size, num_kv_heads, sequence_length, head_dim).
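
A hypothetical sketch of what that branch might look like; the identifiers below are illustrative, not the actual names used in `engine.py`:

```python
  # Inside the engine-creation function, next to the existing model branches:
  elif model_name == "gemma":
    args = gemma_config.get_model_args(param_size)  # hypothetical config helper
    env_data.num_layers = args.num_hidden_layers  # how many layers of KV cache
    env_data.cache_shape = (  # per-layer cache shape
        batch_size,
        args.num_key_value_heads,
        max_cache_length,
        args.head_dim,
    )
    env = JetEngineEnvironment(env_data)
    model = gemma_model.GemmaModel(args, env)
```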

## Test it out

After these steps, you should be able to run your model using:

```bash
python run_interactive.py --size=7b --batch_size=128 --max_cache_length=2048 --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/$model_name.yaml --model=gemma
```

If you run it without `checkpoint_path`, it will use random weights, so you can
verify that the code actually runs.

# Step 5: Weight conversion

Because we modified the model, the names of the variables in the model might have
changed. If so, we also need to modify the `convert_weights.py` script to map
the original weights to the modified names.

For example, I split the qkv projection into 3 separate projections, which helps with
performance in a sharded environment. So I need to make the `convert_weights` script
able to split the weights as well, as sketched below.
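
A sketch of that split, assuming the fused weight is laid out as `[q; k; v]` along dimension 0 (which is how the reference Gemma `qkv_proj` splits it at runtime):

```python
import torch


def split_qkv(fused: torch.Tensor, q_dim: int, kv_dim: int):
  """Splits a fused [q; k; v] projection weight into three separate weights."""
  wq = fused[:q_dim, :]
  wk = fused[q_dim : q_dim + kv_dim, :]
  wv = fused[q_dim + kv_dim :, :]
  return wq, wk, wv
```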
| 286 | + |