
Commit 57eb0e1

Use GemmaAttention for Gemma (#72)
Change Gemma to use GemmaAttention from model_original. This way it produces more accurate results.
1 parent 811d718 commit 57eb0e1

File tree

9 files changed: +739 −169 lines changed


convert_checkpoints.py

Lines changed: 2 additions & 3 deletions
```diff
@@ -414,7 +414,7 @@ def convert_hf_gemma_weights(
   ckpt_file = list(input_ckpt_dir.glob("*.ckpt"))
   assert len(ckpt_file) == 1, "only expect 1 ckpt file for Gemma model."
   ckpt_file = ckpt_file[0]
-  state_dict = torch.load(ckpt_file, map_location=torch.device("cpu"))[
+  state_dict = torch.load(str(ckpt_file), map_location=torch.device("cpu"))[
       "model_state_dict"
   ]
   model_config = json.loads((input_ckpt_dir / "config.json").read_text())
@@ -447,8 +447,7 @@ def convert_hf_gemma_weights(
       state_dict[new_key.replace("qkv_proj", "wk")] = k
       state_dict[new_key.replace("qkv_proj", "wv")] = v
       continue
-    if "o_proj" in key:
-      new_key = new_key.replace("o_proj", "wo")
+
     if new_key != key:
       state_dict[new_key] = state_dict.pop(key)
   _export_to_local(output_ckpt_dir, model_config, state_dict)
```

default_shardings/gemma.yaml

Lines changed: 1 addition & 1 deletion
```diff
@@ -4,7 +4,7 @@
 # Integer signify axis to shard: 0 <= shard axis < rank

 freqs_cis : -1 # torch.complex64 (16384, 128)
-layers.*.self_attn.wo.weight : 1 # 1, -1] # torch.float32 (2048, 2048)
+layers.*.self_attn.o_proj.weight: 1
 layers.*.self_attn.wq.weight : 0 # -1, 1] # torch.float32 (2048, 2048)
 layers.*.self_attn.wk.weight : 0 # -1, 1] # torch.float32 (256, 2048)
 layers.*.self_attn.wv.weight : 0 # -1, 1] # torch.float32 (256, 2048)
```

docs/add_a_new_model.md

Lines changed: 286 additions & 0 deletions
Add a new Model
===============

This doc is a detailed guide on how to add a new model to jetstream-pytorch.
How complex that is depends highly on the model architecture itself, and right now it is a manual process.

NOTE: Only LLMs that employ autoregressive decoding with a KV cache are suitable for serving with Jetstream. Other models, such as Stable Diffusion, are NOT suitable for the optimization techniques used in Jetstream.

The core part of adding a model is to let the Jetstream serving engine manage the KV cache. This management is abstracted by [`class CacheInterface`](jetstream_pt/cache_manager.py), whose single `update` method abstracts the act of inserting into and then reading from the cache.
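For orientation, a minimal sketch of that contract is shown below. This only illustrates the shape of the interface; the real class lives in `jetstream_pt/cache_manager.py` and carries more state (such as the positions the engine wants to write to).

```python
from typing import Tuple

import torch


class CacheInterface:
    """Sketch of the per-layer KV cache contract (illustrative only)."""

    def update(
        self, key: torch.Tensor, value: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Inserts the new key/value entries for this step, then returns the
        full key and value caches that the attention should read from."""
        raise NotImplementedError
```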
We will walk through this process using [Gemma model](https://github.com/google/gemma_pytorch) as an example.
# Step 0: Get the model code

Jetstream-pytorch stores its models in the jetstream_pt/third_party directory.

The usual convention is:

1. Make a verbatim copy of the model code and supporting files (such as the args class, tokenizers, etc.) in a separate directory. In our case it is [jetstream_pt/third_party/gemma](jetstream_pt/third_party/gemma).

2. Make a copy of `model.py` as `model_original.py`, because we will be modifying it to follow Jetstream's conventions; keeping the original helps with debugging accuracy (and unit tests).

*Optional:* Clean up the model implementation. The easiest models to port are "reference implementations"; models that already contain optimizations and/or custom CUDA kernels need those changes removed first.

In our case, we chose the reference Gemma model from Google's GitHub instead of the HuggingFace version, because the HuggingFace version also contains training code that would need to be removed.

# Step 1: Modify the model to fit the calling conventions expected by Jetstream

The model that Jetstream expects and calls follows this calling convention:

```python
class Model(torch.nn.Module):

  def forward(
      self,
      tokens: torch.Tensor,
      input_pos: torch.Tensor,
      caches: List[CacheInterface],
      mask: torch.Tensor,
  ) -> torch.Tensor:
```

The arguments are:

* `tokens`: An int tensor of shape (batch_size, sequence_length). These are the token ids before embedding.

* `input_pos`: The positions of the tokens in the overall sentence, as an int tensor of shape (batch_size, sequence_length). Note: due to continuous batching, not all batch entries have the same sequence length.

* `caches`: A list of objects implementing `CacheInterface`. `CacheInterface` has a single `update` method.

* `mask`: The mask used in causal attention.

The return value should be a tensor of shape (batch_size, sequence_length, vocab_size) of **logits** (not probabilities) for the next token.
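To make the shapes concrete, here is a small illustration of what a single decode step might pass in; the sizes are made up and the mask construction is omitted.

```python
import torch

batch_size, seq_len, vocab_size = 4, 1, 32000  # made-up sizes

# Token ids before embedding: (batch_size, sequence_length).
tokens = torch.zeros(batch_size, seq_len, dtype=torch.int64)

# Per-sequence positions: with continuous batching each row can differ.
input_pos = torch.tensor([[7], [0], [123], [56]])

# `caches` would be a list with one CacheInterface object per decoder layer,
# and the model returns logits of shape (batch_size, seq_len, vocab_size).
```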
### Gemma example:
Now looking back at our Gemma model reference, there are 2 classes in the original model that could serve as our model: [GemmaModel](https://github.com/google/gemma_pytorch/blob/main/gemma/model.py#L353) and [GemmaForCausalLM](https://github.com/google/gemma_pytorch/blob/main/gemma/model.py#L386). Looking at their forward method signatures:

```python
class GemmaModel(nn.Module):
  def forward(
      self,
      hidden_states: torch.Tensor,
      freqs_cis: torch.Tensor,
      kv_write_indices: torch.Tensor,
      kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
      mask: torch.Tensor,
  ) -> torch.Tensor:


class GemmaForCausalLM(nn.Module):
  @torch.no_grad()
  def forward(
      self,
      input_token_ids: torch.Tensor,
      input_positions: torch.Tensor,
      kv_write_indices: torch.Tensor,
      kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
      mask: torch.Tensor,
      output_positions: torch.Tensor,
      temperatures: Union[torch.Tensor, None],
      top_ps: torch.Tensor,
      top_ks: torch.Tensor,
      **kwargs,
  ) -> torch.Tensor:
```

We can see that `GemmaModel` is the closer of the two to port, so we choose it. However, there are a few issues:

1. GemmaModel takes `hidden_states` instead of tokens.
2. GemmaModel returns `hidden_states` after the layers, not logits.

Let's fix those first.

Looking at where `GemmaModel` is called in `model.py`, we find:

```python
# [batch_size, input_len, hidden_size]
hidden_states = self.embedder(input_token_ids)
# Gemma normalizes the embedding by sqrt(hidden_size).
hidden_states = hidden_states * (self.config.hidden_size**0.5)
```

So the input tokens are embedded with `self.embedder` and scaled before `GemmaModel` is called; let's move these bits inside `GemmaModel`.

Now, looking at where the output of `GemmaModel` is consumed, we see it is fed to `self.sampler`. `self.sampler` is of class `Sampler`, and its forward contains:

```python
hidden_states = hidden_states.index_select(
    1, output_positions).squeeze(dim=1)
logits = torch.matmul(hidden_states, embedding.t())
if embedding_bias is not None:
  logits += embedding_bias

if temperatures is None:
  return torch.argmax(logits, dim=-1).squeeze(dim=-1)
...
```

It performs some math on the hidden states to produce logits, which is what GemmaModel should return, so let's move these bits into `GemmaModel` as well.
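Putting the two moves together, the adapted model ends up shaped roughly like the toy sketch below. This is not the final ported file: the real decoder layers, rotary embeddings, final norm, and cache handling are elided here.

```python
import torch
from torch import nn


class GemmaModelSketch(nn.Module):
    """Toy illustration of where the embedding and logit math end up."""

    def __init__(self, vocab_size: int = 256, hidden_size: int = 32, num_layers: int = 2):
        super().__init__()
        self.hidden_size = hidden_size
        self.embedder = nn.Embedding(vocab_size, hidden_size)
        # Stand-ins for GemmaDecoderLayer; the real layers also take
        # freqs_cis, a cache object, and the mask.
        self.layers = nn.ModuleList(nn.Identity() for _ in range(num_layers))

    def forward(self, tokens, input_pos, caches, mask):
        # Moved in from GemmaForCausalLM: embed and scale by sqrt(hidden_size).
        hidden_states = self.embedder(tokens) * (self.hidden_size**0.5)
        for layer, cache in zip(self.layers, caches):
            hidden_states = layer(hidden_states)
        # Moved in from Sampler: project back onto the embedding table.
        return torch.matmul(hidden_states, self.embedder.weight.t())
```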
Lastly, GemmaModel takes a list of tuples of torch.Tensor as its cache input; we need to replace that with our cache objects.

The cache is plumbed all the way through to `GemmaAttention`, and the [following lines](https://github.com/google/gemma_pytorch/blob/main/gemma/model.py#L264C1-L268C53):

```python
# Write new kv cache.
# [batch_size, input_len, n_local_kv_heads, head_dim]
k_cache, v_cache = kv_cache
k_cache.index_copy_(1, kv_write_indices, xk)
v_cache.index_copy_(1, kv_write_indices, xv)
```

are precisely the cache update. So we need to replace those lines with:

```python
xk = xk.transpose(1, 2)
xv = xv.transpose(1, 2)
k_cache, v_cache = cache.update(xk, xv)
```

The transposes are needed because the cache interface's `update` method expects a shape of (batch, num_heads, sequence_length, head_dim) instead of the (batch, sequence_length, num_heads, head_dim) that GemmaAttention produces.
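For orientation, here is roughly how the updated cache then flows into the rest of the attention math. This is a simplified fragment of standard scaled-dot-product attention, not the exact contents of the ported `GemmaAttention`; grouped-query head repetition and the output projection are omitted, and `self.scaling` is assumed to be the usual 1/sqrt(head_dim).

```python
# xq, xk, xv: (batch, seq_len, n_heads, head_dim) right after the projections
# and rotary embedding, as in the original GemmaAttention.
xq = xq.transpose(1, 2)                # -> (batch, n_heads, seq_len, head_dim)
xk = xk.transpose(1, 2)
xv = xv.transpose(1, 2)

key, value = cache.update(xk, xv)      # insert this step, read back the full caches

scores = torch.matmul(xq, key.transpose(2, 3)) * self.scaling
scores = scores + mask                 # additive causal mask
scores = torch.softmax(scores.float(), dim=-1).type_as(xq)
output = torch.matmul(scores, value)   # (batch, n_heads, seq_len, head_dim)
```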
In our case, because the attention math is the standard one, we can just call out to `AttentionKernel` defined in [layers.py](jetstream_pt/layers.py). `AttentionKernel` also handles reading and writing of `cache` automatically.

At this point the model should be runnable. However, to run it on a realistic TPU we need to add model parallelism.

# Step 2: Add model parallelism

Model parallelism is often necessary to run on TPUs. The typical setup for running inference workloads is a TPU `v5light-8`, which has 8 TPU chips with 16GB of high-bandwidth memory (HBM) each; a typical `7B` model won't fit on a single chip.

So we need model parallelism so that the model weights are sharded among the 8 devices. This is necessary for larger models, such as 70Bs, even on high-memory chips (v5p), so it's good practice to do it right away.

Jetstream uses GSPMD for tensor parallelism; the only information we need to give it is, for every weight tensor, which axis to shard. We do so by writing a sharding config file.

## Generate a sharding config

The keys of the sharding file are the names of the weights (with numeric layer indices replaced by *), and the values are the axes to shard. For Gemma, we can generate such a file by printing out the keys in its `state_dict`; see [create_empty_sharding_map.py](scripts/create_empty_sharding_map.py) for an example (a small sketch of this is shown after the config below).

Below:

```yaml
freqs_cis : -1 # torch.complex64 (16384, 128)
layers.*.self_attn.qkv_proj.weight: 0
layers.*.self_attn.o_proj.weight: 1
layers.*.self_attn.wo.weight : 1 # 1, -1] # torch.float32 (2048, 2048)
layers.*.self_attn.wq.weight : 0 # -1, 1] # torch.float32 (2048, 2048)
layers.*.self_attn.wk.weight : 0 # -1, 1] # torch.float32 (256, 2048)
layers.*.self_attn.wv.weight : 0 # -1, 1] # torch.float32 (256, 2048)
layers.*.mlp.gate_proj.weight : 0 # -1, 1] # torch.float32 (16384, 2048)
layers.*.mlp.gate_proj.bias : 0 # -1] # torch.float32 (16384,)
layers.*.mlp.up_proj.weight : 0 # -1, 1] # torch.float32 (16384, 2048)
layers.*.mlp.up_proj.bias : 0 # -1] # torch.float32 (16384,)
layers.*.mlp.down_proj.weight : 1 # 1, -1] # torch.float32 (2048, 16384)
layers.*.mlp.down_proj.bias : -1 # torch.float32 (2048,)
layers.*.input_layernorm.weight : -1 # torch.float32 (2048,)
layers.*.post_attention_layernorm.weight : -1 # torch.float32 (2048,)
norm.weight : -1 # torch.float32 (2048,)
embedder.weight : 1 # # 1, -1] # torch.float32 (256000, 2048)
```

Here the weights `layers.*.self_attn.qkv_proj.weight`, where `*` ranges over the 28 layer indices, are sharded along axis 0 (0-based indexing), and so on; -1 means "replicated".
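As mentioned above, a quick way to produce a starting point for this file, similar in spirit to `create_empty_sharding_map.py` (the helper below is our own illustration, not the script's API), is to dump every unique weight name with a replicated default:

```python
import re

import torch


def print_empty_sharding_map(model: torch.nn.Module) -> None:
    """Prints one replicated (-1) entry per unique weight name."""
    seen = set()
    for name, weight in model.state_dict().items():
        # Collapse the numeric layer index so one line covers all layers.
        key = re.sub(r"\.\d+\.", ".*.", name)
        if key not in seen:
            seen.add(key)
            print(f"{key} : -1  # {weight.dtype} {tuple(weight.shape)}")
```

You can then edit the printed axes by hand, following the hints below.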
Theoretically, any valid sharding would work. To find a sharding that performs well, one can usually get hints from the original model implementation.

For example, in the case of Gemma, the authors also provide a TPU version: https://github.com/google/gemma_pytorch/blob/main/gemma/model_xla.py

In that file, the weights used with `ColumnParallelLinear` should be sharded on dimension 0, those used with `RowParallelLinear` should be sharded on dimension 1, and the others should be replicated.

# Step 3: Activation sharding and quantization

Sometimes we would like to specify shardings for activations, because GSPMD cannot fully infer all of them.

The typical example happens after a reshape. For instance, if we have a matrix of shape [A, B * C] whose second dim is sharded, and we reshape it to shape [A, B, C], the compiler knows that one of the dims B or C is sharded but cannot know which one. In this case, it is helpful to add a sharding constraint.

This is done by calling `env.apply_sharding(tensor, axis=1)` on the tensor.

The `env` object is an instance of the `Environment` class that is passed into the model constructor. It also contains some common configuration (such as whether the user wants quantization) that is useful to the models. For that reason, we store it in `self.env` and use it when needed.
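As an illustration, a constraint for the reshape case above could look like the fragment below, taken from inside a module's forward; the tensor names and dimension choices here are ours, not code from the repository.

```python
# wo_input: (batch_size, seq_len, n_heads * head_dim), last dimension sharded.
xs = wo_input.reshape(batch_size, seq_len, n_heads, head_dim)
# After the reshape GSPMD cannot tell whether n_heads or head_dim carries the
# sharding, so state it explicitly: the heads axis (axis 2) is the sharded one.
self.env.apply_sharding(xs, axis=2)
```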
For quantization, it suffices to swap the `nn.Linear` layers with `Int8QuantizedLinear` defined in `layers.py`.

# Step 4: Wiring everything up

The last step is to modify [engine.py](https://github.com/google/jetstream-pytorch/blob/main/jetstream_pt/engine.py#L738) and add an if branch to this function.

This function receives information about the model name and size, and here it should instantiate the model object itself. It also needs to tell the environment about the caches to allocate: notably how many layers there are and the shape of the cache. The shape is expected to be (batch_size, num_kv_heads, sequence_length, head_dim).
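The new branch ends up looking something like the sketch below. Treat the helper and field names (`get_model_config`, `num_layers`, `cache_shape`, ...) as placeholders and mirror whatever the existing branches in `engine.py` actually use.

```python
elif model_name == "gemma":
    model_args = gemma_config.get_model_config(param_size)  # placeholder helper
    # Tell the environment how many layers of cache to allocate and their shape:
    # (batch_size, num_kv_heads, sequence_length, head_dim).
    env_data.num_layers = model_args.num_hidden_layers       # placeholder field names
    env_data.cache_shape = (
        batch_size,
        model_args.num_key_value_heads,
        max_cache_length,
        model_args.head_dim,
    )
    env = JetEngineEnvironment(env_data)
    pt_model = gemma_model.GemmaModel(model_args, env)
```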
## Test it out

After these steps you should be able to run your model using:

```bash
python run_interactive.py --size=7b --batch_size=128 --max_cache_length=2048 --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/$model_name.yaml --model=gemma
```

If you run it without checkpoint_path, it will use random weights, so you can verify that the code actually runs.

# Step 5: Weight conversion

Because we modified the model, the names of the variables in the model might have changed. If so, we also need to modify the `convert_weights.py` script to map the original weights to the modified names.

For example, we split the qkv projection into 3 separate projections, which helps with performance in a sharded environment, so we need to make the `convert_weights` script able to split the weights as well.
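A sketch of that split inside the conversion loop could look like the following; the head counts come from the model config, and the variable names here are illustrative rather than the script's exact code.

```python
if "qkv_proj" in key:
    q_dim = num_heads * head_dim
    kv_dim = num_kv_heads * head_dim
    qkv = state_dict.pop(key)
    q, k, v = qkv.split([q_dim, kv_dim, kv_dim], dim=0)
    state_dict[key.replace("qkv_proj", "wq")] = q
    state_dict[key.replace("qkv_proj", "wk")] = k
    state_dict[key.replace("qkv_proj", "wv")] = v
```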

install_everything.sh

Lines changed: 7 additions & 7 deletions
```diff
@@ -16,15 +16,15 @@ TORCHXLA_TAG=jetstream-pytorch
 JETSTREAM_TAG=v0.2.1

 # Uninstall existing jax
-pip3 show jax && pip3 uninstall -y jax
-pip3 show jaxlib && pip3 uninstall -y jaxlib
-pip3 show libtpu-nightly && pip3 uninstall -y libtpu-nightly
+pip show jax && pip uninstall -y jax
+pip show jaxlib && pip uninstall -y jaxlib
+pip show libtpu-nightly && pip uninstall -y libtpu-nightly

-pip3 install pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
+pip install pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
 # torch cpu
-pip3 install torch --index-url https://download.pytorch.org/whl/cpu
-pip3 install tensorflow flatbuffers absl-py flax sentencepiece seqio google-cloud-storage
-pip3 install safetensors colorama coverage ray[default] humanize
+pip install torch --index-url https://download.pytorch.org/whl/cpu
+pip install tensorflow flatbuffers absl-py flax sentencepiece seqio google-cloud-storage
+pip install safetensors colorama coverage ray[default] humanize

 mkdir -p deps
 pushd deps
```

jetstream_pt/engine.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -17,6 +17,7 @@
 from typing import Any, List, Optional, Tuple, Union
 import threading
 import functools
+import os

 from etils import epath
 from flax import struct
@@ -703,6 +704,9 @@ def create_pytorch_engine(
   tokenizer = token_utils.load_vocab(tokenizer_path)
   pt_model = None

+  if not sharding_config:
+    sharding_config = os.path.join("default_shardings", model_name + ".yaml")
+
   env_data = JetEngineEnvironmentData(
       tokenizer_path=tokenizer_path,
       checkpoint_path=checkpoint_path,
```
