# Guide on adding HuggingFace checkpoint conversion support

## Prerequisites:
- The model implementation has been added in JetStream-pt.
- Checkpoint conversion from at least one existing format is already supported (or no conversion is needed for the checkpoint).

Please check this [guide](https://github.com/google/jetstream-pytorch/blob/main/docs/add_a_new_model.md) for adding a new model.

## Use case:
The user has a checkpoint for the same model architecture in another format (e.g. HF format for a LLaMA model) and wants JetStream-pt to support this checkpoint format.

## Guide

Converting a public checkpoint to the JetStream-pt format is mostly about finding the mapping between weight keys in the public checkpoint and in the JetStream model implementation. Besides the name mapping, the layout of the weights may differ between checkpoint formats (e.g. weights interleaved differently because of differences in the Rotary Embedding implementation; see the sketch below). These differences are model and checkpoint format specific.

**Note** Model code and checkpoint formats differ from model to model. The following is a general guide; specific models may require additional effort to support checkpoint conversion.

The checkpoint conversion logic lives in the checkpoint conversion script (convert_checkpoints, invoked in Step 6 below).

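For LLaMA in particular, HuggingFace checkpoints store the attention q_proj/k_proj weights in a head-interleaved layout (transformers applies a permutation in its convert_llama_weights_to_hf.py script). The snippet below is a minimal sketch of undoing that permutation; whether it is needed, and for which weights, depends on how the JetStream-pt implementation computes its Rotary Embedding, so treat it as an illustration rather than the actual conversion code.

```python
import torch


def unpermute_hf_proj(w: torch.Tensor, n_heads: int, dim1: int, dim2: int) -> torch.Tensor:
  """Undoes the head-wise interleaving HF's LLaMA conversion applies to q/k projections.

  Inverts the permute helper used when the original Meta weights were converted
  to HF format; purely illustrative.
  """
  return (
      w.view(n_heads, 2, dim1 // n_heads // 2, dim2)
      .transpose(1, 2)
      .reshape(dim1, dim2)
  )


# Example for LLaMA-2 7B attention projections: 32 heads, 4096 x 4096 weights.
wq_hf = torch.randn(4096, 4096)
wq = unpermute_hf_proj(wq_hf, n_heads=32, dim1=4096, dim2=4096)
```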

### Step 1 Find the HuggingFace checkpoint you want to convert
In this example, let's use meta-llama/Llama-2-7b-hf.

You can download the checkpoint to a local folder using:

```
huggingface-cli download meta-llama/Llama-2-7b-hf --local-dir Llama-2-7b-hf
```

**Note** You may need to sign an agreement on the Hugging Face website to get permission to download the model.

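If you prefer to do this from Python, the same download can be done with huggingface_hub (a small sketch; it assumes you have already accepted the license and logged in with huggingface-cli login):

```python
from huggingface_hub import snapshot_download

# Downloads the HF checkpoint to a local folder, mirroring the CLI command above.
snapshot_download(repo_id="meta-llama/Llama-2-7b-hf", local_dir="Llama-2-7b-hf")
```
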
### Step 2 Inspect the weight names in the checkpoint:

Usually there is a model.safetensors.index.json file in the checkpoint that lists all the weight names ([example](https://huggingface.co/meta-llama/Llama-2-7b-hf/blob/main/model.safetensors.index.json)).

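For a quick look at the weight names without loading any shards, the index file can be read directly (a minimal sketch; it assumes the standard weight_map field and the local folder from Step 1):

```python
import json

# Assumes the checkpoint was downloaded to ./Llama-2-7b-hf as in Step 1.
with open("Llama-2-7b-hf/model.safetensors.index.json") as f:
  index = json.load(f)

# weight_map maps each weight name to the shard file that stores it.
for name, shard_file in index["weight_map"].items():
  print(name, "->", shard_file)
```
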
Alternatively, you can load the weights locally and inspect the model key names (the checkpoint is usually in safetensors format and sharded across multiple files).

Example script:
```python
import glob
import os

import torch
from safetensors import safe_open

checkpoint_folder = "/mnt/disks/lsiyuan/llama_weight/Meta-Llama-3-8B-Instruct"

safetensor_files = glob.glob(os.path.join(checkpoint_folder, "*.safetensors"))

for st_f in safetensor_files:
  with safe_open(st_f, framework="pt", device="cpu") as f:
    for key in f.keys():
      weight_tensor = f.get_tensor(key)
      print(f"Weight name {key}, Shape: {weight_tensor.shape}, dtype: {weight_tensor.dtype}")
```

We get the following output:

```
lm_head.weight torch.Size([32000, 4096]) x torch.float16
model.norm.weight torch.Size([4096]) x torch.float16
model.embed_tokens.weight torch.Size([32000, 4096]) x torch.float16
model.layers.0.input_layernorm.weight torch.Size([4096]) x torch.float16
model.layers.0.mlp.down_proj.weight torch.Size([4096, 11008]) x torch.float16
model.layers.0.mlp.gate_proj.weight torch.Size([11008, 4096]) x torch.float16
model.layers.0.mlp.up_proj.weight torch.Size([11008, 4096]) x torch.float16
model.layers.0.post_attention_layernorm.weight torch.Size([4096]) x torch.float16
model.layers.0.self_attn.k_proj.weight torch.Size([4096, 4096]) x torch.float16
model.layers.0.self_attn.o_proj.weight torch.Size([4096, 4096]) x torch.float16
model.layers.0.self_attn.q_proj.weight torch.Size([4096, 4096]) x torch.float16
model.layers.0.self_attn.rotary_emb.inv_freq torch.Size([64]) x torch.float32
model.layers.0.self_attn.v_proj.weight torch.Size([4096, 4096]) x torch.float16
… # Duplicated name for model.layers.x
```

If it's hard to tell what a given weight is for, you can check the HF model class in the checkpoint's config file ([example](https://huggingface.co/meta-llama/Llama-2-7b-hf/blob/main/config.json#L4)), then find the model code in the transformers repo by searching for that class name ([model code](https://github.com/huggingface/transformers/blob/bdf36dcd48106a4a0278ed7f3cc26cd65ab7b066/src/transformers/models/llama/modeling_llama.py#L1084)).

### Step 3 Inspect the weight names in JetStream-pt model implementation:

Run the model in JetStream-pt using benchmarks/run_offline.py. The weight names, shapes, and dtypes are printed in the log (entries for layers 1 and above are omitted below since their names repeat for every layer).

Example:

```
Name: freqs_cis, shape: (2048, 64) x complex64
Name: tok_embeddings.weight, shape: (32000, 4096) x bfloat16
Name: layers.0.attention.wo.weight, shape: (4096, 4096) x bfloat16
Name: layers.0.attention.wq.weight, shape: (4096, 4096) x bfloat16
Name: layers.0.attention.wk.weight, shape: (4096, 4096) x bfloat16
Name: layers.0.attention.wv.weight, shape: (4096, 4096) x bfloat16
Name: layers.0.feed_forward.w1.weight, shape: (11008, 4096) x bfloat16
Name: layers.0.feed_forward.w2.weight, shape: (4096, 11008) x bfloat16
Name: layers.0.feed_forward.w3.weight, shape: (11008, 4096) x bfloat16
Name: layers.0.attention_norm.weight, shape: (4096,) x bfloat16
Name: layers.0.ffn_norm.weight, shape: (4096,) x bfloat16
Name: norm.weight, shape: (4096,) x bfloat16
Name: output.weight, shape: (32000, 4096) x bfloat16
```
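
If you already have the JetStream-pt model object in a Python session, the same listing can be produced with plain PyTorch (a generic helper, not a JetStream-pt-specific API):

```python
import torch


def print_weight_names(model: torch.nn.Module) -> None:
  """Prints every weight name, shape, and dtype in the same style as the log above."""
  for name, param in model.state_dict().items():
    print(f"Name: {name}, shape: {tuple(param.shape)} x {param.dtype}")
```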

If it's hard to tell what a given weight is for, check the model implementation under jetstream_pt/third_party to find out its meaning.

### Step 4 Find the mapping by comparing the weight names or diving into the model code:

In this example:

- HF lm_head.weight -> JetStream-pt output.weight
- HF model.norm.weight -> JetStream-pt norm.weight
- HF model.embed_tokens.weight -> JetStream-pt tok_embeddings.weight
- HF model.layers.X.input_layernorm.weight -> layers.X.attention_norm.weight
- HF model.layers.X.post_attention_layernorm.weight -> layers.X.ffn_norm.weight
- HF model.layers.X.self_attn.{q/k/v/o}_proj.weight -> layers.X.attention.w{q/k/v/o}.weight
- HF model.layers.X.mlp.gate_proj.weight -> layers.X.feed_forward.w1.weight
- HF model.layers.X.mlp.down_proj.weight -> layers.X.feed_forward.w2.weight
- HF model.layers.X.mlp.up_proj.weight -> layers.X.feed_forward.w3.weight

freqs_cis is a special case: in JetStream PyTorch this weight is precomputed during weight loading, so there is no need to map the HuggingFace rotary frequency weight (model.layers.X.self_attn.rotary_emb.inv_freq) over. A sketch of the resulting key-mapping logic is shown below.

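To make the mapping concrete, here is a minimal sketch of how the key renaming could be written. The function names are hypothetical, it ignores weight-layout differences (such as the q/k permutation discussed above), and the real logic lives in the repo's conversion script, so treat it purely as an illustration:

```python
import re

_PER_LAYER_RULES = [
    (r"model\.layers\.(\d+)\.input_layernorm\.weight", r"layers.\1.attention_norm.weight"),
    (r"model\.layers\.(\d+)\.post_attention_layernorm\.weight", r"layers.\1.ffn_norm.weight"),
    (r"model\.layers\.(\d+)\.self_attn\.([qkvo])_proj\.weight", r"layers.\1.attention.w\2.weight"),
    (r"model\.layers\.(\d+)\.mlp\.gate_proj\.weight", r"layers.\1.feed_forward.w1.weight"),
    (r"model\.layers\.(\d+)\.mlp\.down_proj\.weight", r"layers.\1.feed_forward.w2.weight"),
    (r"model\.layers\.(\d+)\.mlp\.up_proj\.weight", r"layers.\1.feed_forward.w3.weight"),
]

_STATIC_RULES = {
    "lm_head.weight": "output.weight",
    "model.norm.weight": "norm.weight",
    "model.embed_tokens.weight": "tok_embeddings.weight",
}


def map_hf_key(hf_key):
  """Returns the JetStream-pt key for an HF key, or None if the weight is dropped."""
  if hf_key in _STATIC_RULES:
    return _STATIC_RULES[hf_key]
  if hf_key.endswith("rotary_emb.inv_freq"):
    return None  # freqs_cis is recomputed at load time, so this weight is dropped.
  for pattern, replacement in _PER_LAYER_RULES:
    if re.fullmatch(pattern, hf_key):
      return re.sub(pattern, replacement, hf_key)
  raise ValueError(f"Unhandled HF weight key: {hf_key}")


def convert_state_dict(hf_state_dict):
  """Renames HF weight keys to JetStream-pt names, dropping unneeded weights."""
  converted = {}
  for hf_key, tensor in hf_state_dict.items():
    new_key = map_hf_key(hf_key)
    if new_key is not None:
      converted[new_key] = tensor
  return converted
```
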
### Step 5 Validate the converted checkpoint:

If a checkpoint in an already supported format is available, convert that checkpoint first and use it as golden data to compare against the checkpoint converted from the new format.

Write a small script, or reuse this [script](https://github.com/google/jetstream-pytorch/blob/main/scripts/validate_hf_ckpt_conversion.py), to compare the two converted checkpoints; a minimal comparison sketch is also shown below.

Fix any differences between the two converted checkpoints (this will be model and checkpoint format specific).

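A minimal comparison sketch, assuming both converted checkpoints end up as single safetensors files (the paths are placeholders; adjust the loading to however your converted checkpoints are actually stored):

```python
import torch
from safetensors.torch import load_file

# Placeholder paths: the checkpoint converted from the already supported format
# (golden) and the one converted from the new HF format (candidate).
golden = load_file("golden_converted/model.safetensors")
candidate = load_file("new_converted/model.safetensors")

assert golden.keys() == candidate.keys(), "Weight names differ between the two checkpoints"
for name in golden:
  g = golden[name].float()
  c = candidate[name].float()
  if not torch.allclose(g, c, atol=1e-3, rtol=1e-3):
    print(f"Mismatch in {name}: max abs diff {(g - c).abs().max()}")
```
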
### Step 6 End-to-end validation: From checkpoint conversion to serving

Example:

```
export input_ckpt_dir=/mnt/disks/lsiyuan/llama_weight/7B-FT-chat
export output_ckpt_dir=/mnt/disks/lsiyuan/llama_weight/hf_llama_2_7b_converted_bf16_2
export model_name="llama"
export from_hf=True
# Set quantize_weights / quantize_type only if you want a quantized checkpoint.
python -m convert_checkpoints --model_name=$model_name \
  --input_checkpoint_dir=$input_ckpt_dir \
  --output_checkpoint_dir=$output_ckpt_dir \
  --quantize_weights=$quantize_weights \
  --quantize_type=$quantize_type \
  --from_hf=$from_hf
```
