|
| 1 | +# Guide on adding HuggingFace checkpoint conversion support |
| 2 | + |
| 3 | +## Prerequisites: |
| 4 | +The model implementation has been added in JetStream-pt |
| 5 | +The checkpoint conversion from a certain format is already supported. (Or no conversion is needed for the checkpoint) |
| 6 | + |
| 7 | +Please check this [guide](https://github.com/google/jetstream-pytorch/blob/main/docs/add_a_new_model.md) for adding a new model. |
| 8 | + |
| 9 | +## Use case: |
| 10 | +The user has the checkpoint for the same model architecture in another format (e.g. HF format for LLaMA model). And want to have JetStream-pt support this checkpoint format. |
| 11 | + |
| 12 | +## Guide |
| 13 | + |
| 14 | +Converting a public checkpoint to JetStream-pt format is mostly about finding the weight key mapping between the public checkpoint and JetStream model implementation. Besides the name mapping, the layout of the weights might be different among different checkpoint formats (e.g. Weight interleaved differently due to difference in Rotary Embedding implementation). These differences are model and checkpoint format specific. |
| 15 | + |
| 16 | +**Note** The model code and checkpoint format can be different from model to model, the following guide demonstrate a general guide, specific models may require additional effort for the checkpoint conversion support. |
| 17 | + |
| 18 | +The checkpoint conversion logic in the checkpoint conversion script. |
| 19 | + |
| 20 | +### Step 1 Find the HuggingFace checkpoint you want to convert |
| 21 | +In this example, let’s use meta-llama/llama-2 7B as an example |
| 22 | + |
| 23 | +You can download the checkpoints to a local folder using |
| 24 | +huggingface-cli download meta-llama/Llama-2-7b-hf --local-dir Llama-2-7b-hf |
| 25 | + |
| 26 | + |
| 27 | +**Note** You may need to go to Huggingface website to sign an agreement to get the permission to download the model |
| 28 | + |
| 29 | +### Step 2 Inspect the weight names in the checkpoint: |
| 30 | + |
| 31 | +Usually there is a model.safetensors.index.json file in the checkpoint. [example](https://huggingface.co/meta-llama/Llama-2-7b-hf/blob/main/model.safetensors.index.json) |
| 32 | + |
| 33 | +Alternatively, you can load the weights locally and inspect the model key names(Usually it’s in safetensor format, and it’s sharded) |
| 34 | + |
| 35 | +Example script: |
| 36 | +```Python |
| 37 | +import glob |
| 38 | +import os |
| 39 | +import torch |
| 40 | +from safetensors import safe_open |
| 41 | + |
| 42 | +checkpoint_folder = "/mnt/disks/lsiyuan/llama_weight/Meta-Llama-3-8B-Instruct" |
| 43 | + |
| 44 | +safetensor_files = glob.glob(os.path.join(checkpoint_folder, "*.safetensors")) |
| 45 | + |
| 46 | +for st_f in safetensor_files: |
| 47 | + with safe_open(st_f, framework="pt", device="cpu") as f: |
| 48 | + for key in f.keys(): |
| 49 | + weight_tensor = f.get_tensor(key) |
| 50 | + print(f"Weight name {key}, Shape: {weight_tensor.shape}, dtype: {weight_tensor.dtype}") |
| 51 | +``` |
| 52 | + |
| 53 | +Got the following output: |
| 54 | + |
| 55 | +``` |
| 56 | +lm_head.weight torch.Size([32000, 4096]) x torch.float16 |
| 57 | +model.norm.weight torch.Size([4096]) x torch.float16 |
| 58 | +model.embed_tokens.weight torch.Size([32000, 4096]) x torch.float16 |
| 59 | +model.layers.0.input_layernorm.weight torch.Size([4096]) x torch.float16 |
| 60 | +model.layers.0.mlp.down_proj.weight torch.Size([4096, 11008]) x torch.float16 |
| 61 | +model.layers.0.mlp.gate_proj.weight torch.Size([11008, 4096]) x torch.float16 |
| 62 | +model.layers.0.mlp.up_proj.weight torch.Size([11008, 4096]) x torch.float16 |
| 63 | +model.layers.0.post_attention_layernorm.weight torch.Size([4096]) x torch.float16 |
| 64 | +model.layers.0.self_attn.k_proj.weight torch.Size([4096, 4096]) x torch.float16 |
| 65 | +model.layers.0.self_attn.o_proj.weight torch.Size([4096, 4096]) x torch.float16 |
| 66 | +model.layers.0.self_attn.q_proj.weight torch.Size([4096, 4096]) x torch.float16 |
| 67 | +model.layers.0.self_attn.rotary_emb.inv_freq torch.Size([64]) x torch.float32 |
| 68 | +model.layers.0.self_attn.v_proj.weight torch.Size([4096, 4096]) x torch.float16 |
| 69 | +… # Duplicated name for model.layers.x |
| 70 | +``` |
| 71 | + |
| 72 | +If it’s hard to tell which layer the weight is for, the HF model class can be checked in the checkpoint config file [example](https://huggingface.co/meta-llama/Llama-2-7b-hf/blob/main/config.json#L4). Then we can find the model code in the transformer repo by searching the model class name [model code](https://github.com/huggingface/transformers/blob/bdf36dcd48106a4a0278ed7f3cc26cd65ab7b066/src/transformers/models/llama/modeling_llama.py#L1084) |
| 73 | + |
| 74 | + |
| 75 | +### Step 3 Inspect the weight names in JetStream-pt model implementation: |
| 76 | + |
| 77 | +Run the model in JetStream using benchmarks/run_offline.py. The weight names, shape and dtype will be printed in the log (Omitting Layer N which are duplicated names) |
| 78 | + |
| 79 | +Example: |
| 80 | + |
| 81 | +``` |
| 82 | +Name: freqs_cis, shape: (2048, 64) x complex64 |
| 83 | +Name: tok_embeddings.weight, shape: (32000, 4096) x bfloat16 |
| 84 | +Name: layers.0.attention.wo.weight, shape: (4096, 4096) x bfloat16 |
| 85 | +Name: layers.0.attention.wq.weight, shape: (4096, 4096) x bfloat16 |
| 86 | +Name: layers.0.attention.wk.weight, shape: (4096, 4096) x bfloat16 |
| 87 | +Name: layers.0.attention.wv.weight, shape: (4096, 4096) x bfloat16 |
| 88 | +Name: layers.0.feed_forward.w1.weight, shape: (11008, 4096) x bfloat16 |
| 89 | +Name: layers.0.feed_forward.w2.weight, shape: (4096, 11008) x bfloat16 |
| 90 | +Name: layers.0.feed_forward.w3.weight, shape: (11008, 4096) x bfloat16 |
| 91 | +Name: layers.0.attention_norm.weight, shape: (4096,) x bfloat16 |
| 92 | +Name: layers.0.ffn_norm.weight, shape: (4096,) x bfloat16 |
| 93 | +Name: norm.weight, shape: (4096,) x bfloat16 |
| 94 | +Name: output.weight, shape: (32000, 4096) x bfloat16 |
| 95 | +``` |
| 96 | + |
| 97 | +If it’s hard to tell which layer the weight is for, you can find out the meaning of the weight, please check the model implementation under jetstream_pt/third_party. |
| 98 | + |
| 99 | +### Step 4 By comparing the weight names, or diving into the model code, we can find out the mapping: |
| 100 | + |
| 101 | + In this example: |
| 102 | + |
| 103 | +HF lm_head.weight -> JetStream-pt output.weight |
| 104 | +HF model.norm.weight -> JetStream-pt norm.weight |
| 105 | +HF model.embed_tokens.weight -> JetStream-pt tok_embeddings.weight |
| 106 | +HF model.layers.X.input_layernorm.weight -> layers.X.attention_norm.weight |
| 107 | +HF model.layers.0.post_attention_layernorm.weight -> layers.0.ffn_norm.weight |
| 108 | +HF model.layers.X.self_attn.{q/k/v/o}_proj.weight -> layers.X.attention.w{q/k/v/o}.weight |
| 109 | +HF model.layers.X.mlp.gate_proj.weight -> layers.X.feed_forward.w1.weight |
| 110 | +HF model.layers.X.mlp.down_proj.weight -> layers.X.feed_forward.w2.weight |
| 111 | +HF model.layers.X.mlp.up_proj.weight -> layers.X.feed_forward.w3.weight |
| 112 | +freqs_cis is a special case, in JetStream PyTorch, the weight is pre-computed during weight loading, so no need to map the Huggingface freq weight over. |
| 113 | + |
| 114 | +### Step 5 Validate the converted checkpoint: |
| 115 | + |
| 116 | +If there is a checkpoint in already supported format, convert the checkpoint in supported format first, as the golden data to compare with the converted checkpoint from the new format. |
| 117 | + |
| 118 | +Write a small script, or reuse the [script](https://github.com/google/jetstream-pytorch/blob/main/scripts/validate_hf_ckpt_conversion.py) to compare the 2 converted checkpoints. |
| 119 | + |
| 120 | +Fix the difference between 2 converted checkpoints if there is any. (This will be model and checkpoint format specific) |
| 121 | + |
| 122 | +### Step 6 End-to-end validation: From checkpoint conversion to serving |
| 123 | + |
| 124 | +Example |
| 125 | + |
| 126 | +``` |
| 127 | +export input_ckpt_dir=/mnt/disks/lsiyuan/llama_weight/7B-FT-chat |
| 128 | +export output_ckpt_dir=/mnt/disks/lsiyuan/llama_weight/hf_llama_2_7b_converted_bf16_2 |
| 129 | +export model_name="llama" |
| 130 | +export from_hf=True |
| 131 | +python -m convert_checkpoints --model_name=$model_name \ |
| 132 | + --input_checkpoint_dir=$input_ckpt_dir \ |
| 133 | + --output_checkpoint_dir=$output_ckpt_dir \ |
| 134 | + --quantize_weights=$quantize_weights \ |
| 135 | + --quantize_type=$quantize_type \ |
| 136 | + --from_hf=True |
| 137 | +``` |
0 commit comments