Commit 2859e47

Add qwen 2.5 (#8355)
1 parent 728c255 commit 2859e47

File tree

examples/models/llama/attention.py
examples/models/llama/model_args.py
examples/models/llama/rope.py
examples/models/llama/static_attention.py
examples/models/qwen2_5/1_5b_config.json
examples/models/qwen2_5/README.md
examples/models/qwen2_5/convert_weights.py

7 files changed: +189 −10 lines changed

examples/models/llama/attention.py

Lines changed: 10 additions & 3 deletions
@@ -175,9 +175,16 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
         self.max_batch_size = args.max_batch_size
         self.max_context_len = args.max_context_len
         self.dim = args.dim
-        self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False)
-        self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
-        self.wv = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
+        self.attention_qkv_bias = args.attention_qkv_bias
+        self.wq = nn.Linear(
+            self.dim, self.n_heads * self.head_dim, bias=self.attention_qkv_bias
+        )
+        self.wk = nn.Linear(
+            self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias
+        )
+        self.wv = nn.Linear(
+            self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias
+        )
         self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)

         self.layer_id = layer_id
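
The new `attention_qkv_bias` flag exists because Qwen 2.5, unlike Llama, trains bias terms on its Q/K/V projections. Below is a minimal standalone sketch of what the flag toggles, using the dimensions from the 1.5B params file added in this commit (dim=1536, n_heads=12, n_kv_heads=2, hence head_dim=128); it is an illustration, not the repository's code.

```python
# Minimal standalone sketch (not the repository's code): attention_qkv_bias adds
# bias parameters to the Q/K/V projections only; the output projection stays bias-free.
import torch.nn as nn

dim, n_heads, n_kv_heads = 1536, 12, 2   # values from the Qwen 2.5 1.5B params file
head_dim = dim // n_heads                # 128


def build_projections(attention_qkv_bias: bool):
    wq = nn.Linear(dim, n_heads * head_dim, bias=attention_qkv_bias)
    wk = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_qkv_bias)
    wv = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_qkv_bias)
    wo = nn.Linear(n_heads * head_dim, dim, bias=False)  # wo keeps bias=False, as in the diff
    return wq, wk, wv, wo


wq, wk, wv, wo = build_projections(attention_qkv_bias=True)
assert wq.bias is not None and wk.bias is not None and wv.bias is not None
assert wo.bias is None
```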

examples/models/llama/model_args.py

Lines changed: 1 addition & 0 deletions
@@ -21,6 +21,7 @@ class ModelArgs:
     num_experts: int = 8  # Number of experts
     num_activated_experts: int = 2  # Number of experts to activate
     attention_type: str = "mha"  # Attention type, registered in attention.py
+    attention_qkv_bias: bool = False
     use_kv_cache: bool = False  # Use key/value cache
     use_sdpa_with_kv_cache_op: bool = (
         False  # Use custom sdpa op that updates kv cache in-place

examples/models/llama/rope.py

Lines changed: 7 additions & 4 deletions
@@ -114,6 +114,7 @@ def apply_rotary_emb_to_k(
     return xk_out.type_as(xk)


+# Wrap apply_rotary_emb in a module to enable it to be module swapped out.
 class RotaryEmbedding(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -213,14 +214,20 @@ class Rope(torch.nn.Module):
     def __init__(self, params: ModelArgs):
         super().__init__()
         self.params = params
+
+        # Choose the appropriate RoPE implementation
         if self.params.use_hf_rope:
             self.precompute_freqs_cis = hf_precompute_freqs_cis
+            self.apply_rotary_emb = hf_apply_rotary_emb
         else:
             self.precompute_freqs_cis = partial(
                 precompute_freqs_cis,
                 use_scaled=self.params.use_scaled_rope,
                 scale_factor=self.params.rope_scale_factor,
             )
+            self.apply_rotary_emb = RotaryEmbedding()
+
+        # Precompute frequencies
         freqs_cos, freqs_sin = self.precompute_freqs_cis(
             self.params.head_dim,
             (
@@ -232,10 +239,6 @@ def __init__(self, params: ModelArgs):
         )
         self.register_buffer("freqs_cos", freqs_cos, persistent=False)
         self.register_buffer("freqs_sin", freqs_sin, persistent=False)
-        if self.params.use_hf_rope:
-            self.apply_rotary_emb = hf_apply_rotary_emb
-        else:
-            self.apply_rotary_emb = RotaryEmbedding()

     def forward(
         self,
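
The refactor above only consolidates the choice of `apply_rotary_emb` with the choice of `precompute_freqs_cis`, so both the HF-style and Meta-style RoPE variants are selected in one place before the frequency tables are built and registered. As a rough, simplified illustration of that precompute-and-register pattern (a toy module covering only the non-HF branch, not the repository's `Rope` class):

```python
# Toy illustration of the precompute-and-register pattern in Rope.__init__
# (an assumed simplification, not the repository's implementation).
import torch


class TinyRope(torch.nn.Module):
    def __init__(self, head_dim: int, max_seq_len: int, theta: float = 1_000_000.0):
        super().__init__()
        # Standard RoPE frequency table: one angle per (position, dimension pair).
        freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
        t = torch.arange(max_seq_len, dtype=torch.float32)
        angles = torch.outer(t, freqs)
        # Non-persistent buffers, as in the diff above: they move with the module
        # but are not serialized into checkpoints.
        self.register_buffer("freqs_cos", torch.cos(angles), persistent=False)
        self.register_buffer("freqs_sin", torch.sin(angles), persistent=False)


rope = TinyRope(head_dim=128, max_seq_len=2048)
print(rope.freqs_cos.shape)  # torch.Size([2048, 64])
```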

examples/models/llama/static_attention.py

Lines changed: 4 additions & 3 deletions
@@ -207,22 +207,23 @@ def __init__(self, config: ModelArgs, layer_id: int, rope: Rope):
         self.dim = config.dim
         self.head_dim = config.head_dim
         self.inv_scale = 1.0 / (float(self.head_dim) ** 0.5)
+        self.attention_qkv_bias = config.attention_qkv_bias

         self.wqs = nn.ModuleList(
             [
-                nn.Linear(self.dim, self.head_dim, bias=False)
+                nn.Linear(self.dim, self.head_dim, bias=self.attention_qkv_bias)
                 for _ in range(self.n_heads)
             ]
         )
         self.wks = nn.ModuleList(
             [
-                nn.Linear(self.dim, self.head_dim, bias=False)
+                nn.Linear(self.dim, self.head_dim, bias=self.attention_qkv_bias)
                 for _ in range(self.n_kv_heads)
             ]
         )
         self.wvs = nn.ModuleList(
             [
-                nn.Linear(self.dim, self.head_dim, bias=False)
+                nn.Linear(self.dim, self.head_dim, bias=self.attention_qkv_bias)
                 for _ in range(self.n_kv_heads)
             ]
         )
examples/models/qwen2_5/1_5b_config.json

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
+{
+  "dim": 1536,
+  "ffn_dim_multiplier": 1,
+  "hidden_dim": 8960,
+  "n_heads": 12,
+  "n_kv_heads": 2,
+  "n_layers": 28,
+  "norm_eps": 1e-06,
+  "rope_theta": 1000000.0,
+  "use_scaled_rope": false,
+  "vocab_size": 151936,
+  "use_hf_rope": true,
+  "attention_qkv_bias": true
+}
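
This params file maps directly onto `ModelArgs` fields, so the two architectural differences handled in this commit (HF-style RoPE and QKV bias) are switched on purely by configuration. A hedged sketch of loading it follows; the import path and keyword construction are assumptions for illustration, not the export script's exact code.

```python
# Hedged sketch: load the Qwen 2.5 params file into ModelArgs. The import path and
# ModelArgs(**params) construction are assumptions, not the export flow's exact code.
import json

from examples.models.llama.model_args import ModelArgs

with open("examples/models/qwen2_5/1_5b_config.json") as f:
    params = json.load(f)

args = ModelArgs(**params)
assert args.attention_qkv_bias is True  # new flag added in this commit
assert args.use_hf_rope is True         # selects the HF-style RoPE branch in Rope.__init__
```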

examples/models/qwen2_5/README.md

Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
+## Summary
+Qwen 2.5 is the latest iteration of the Qwen series of large language models (LLMs) developed by Alibaba. At the moment, only the 1.5B version is supported, with plans to add the 0.5B and 3B versions in the future.
+
+## Instructions
+
+Qwen 2.5 uses the same example code as Llama, while the checkpoint, model params, and tokenizer are different. Please see the [Llama README page](../llama/README.md) for details.
+
+All commands for exporting and running Llama on various backends should also be applicable to Qwen 2.5 by swapping in the following args:
+```
+--model qwen2_5
+--params examples/models/qwen2_5/1_5b_config.json
+--checkpoint <path-to-meta-checkpoint>
+```
+
+### Generate the Checkpoint
+The original checkpoint can be obtained from HuggingFace:
+```
+huggingface-cli download Qwen/Qwen2.5-1.5B
+```
+
+We then convert it to Meta's checkpoint format:
+```
+python examples/models/qwen2_5/convert_weights.py <path-to-checkpoint-dir> <output-path>
+```
+
+### Example export and run
+Here is a basic example of exporting and running Qwen 2.5; please refer to the [Llama README page](../llama/README.md) for more advanced usage.
+
+Export to XNNPACK, no quantization:
+```
+# No quantization
+# Set these paths to point to the downloaded files
+QWEN_CHECKPOINT=path/to/checkpoint.pth
+
+python -m examples.models.llama.export_llama \
+  --model "qwen2_5" \
+  --checkpoint "${QWEN_CHECKPOINT:?}" \
+  --params examples/models/qwen2_5/1_5b_config.json \
+  -kv \
+  --use_sdpa_with_kv_cache \
+  -d fp32 \
+  -X \
+  --metadata '{"get_bos_id":151643, "get_eos_ids":[151643]}' \
+  --output_name="qwen2_5-1_5b.pte" \
+  --verbose
+```
+
+Run using the executor runner:
+```
+# Currently a work in progress; we just need to enable the HuggingFace JSON tokenizer in C++.
+# In the meantime, you can run with the example Python runner via pybindings:
+
+python -m examples.models.llama.runner.native \
+  --model qwen2_5 \
+  --pte <path-to-pte> \
+  -kv \
+  --tokenizer <path-to-tokenizer>/tokenizer.json \
+  --tokenizer_config <path-to-tokenizer>/tokenizer_config.json \
+  --prompt "Who is the founder of Meta?" \
+  --params examples/models/qwen2_5/1_5b_config.json \
+  --max_len 64 \
+  --temperature 0
+```
examples/models/qwen2_5/convert_weights.py

Lines changed: 90 additions & 0 deletions
@@ -0,0 +1,90 @@
+import argparse
+from typing import Dict
+
+import torch
+
+from torchtune.models.convert_weights import get_mapped_key
+
+from torchtune.training import FullModelHFCheckpointer
+
+# Standard _FROM_META weight mapping of Meta weights to TorchTune + additional bias weight mappings.
+_QWEN_2_FROM_META = {
+    "tok_embeddings.weight": "tok_embeddings.weight",
+    "norm.weight": "norm.scale",
+    "layers.{}.attention.wk.weight": "layers.{}.attn.k_proj.weight",
+    "layers.{}.attention.wk.bias": "layers.{}.attn.k_proj.bias",
+    "layers.{}.attention.wq.weight": "layers.{}.attn.q_proj.weight",
+    "layers.{}.attention.wq.bias": "layers.{}.attn.q_proj.bias",
+    "layers.{}.attention.wv.weight": "layers.{}.attn.v_proj.weight",
+    "layers.{}.attention.wv.bias": "layers.{}.attn.v_proj.bias",
+    "layers.{}.attention.wo.weight": "layers.{}.attn.output_proj.weight",
+    "layers.{}.attention_norm.weight": "layers.{}.sa_norm.scale",
+    "layers.{}.ffn_norm.weight": "layers.{}.mlp_norm.scale",
+    "layers.{}.feed_forward.w1.weight": "layers.{}.mlp.w1.weight",
+    "layers.{}.feed_forward.w2.weight": "layers.{}.mlp.w2.weight",
+    "layers.{}.feed_forward.w3.weight": "layers.{}.mlp.w3.weight",
+}
+
+
+def qwen_2_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+    """
+    Convert a state dict from torchtune's format to Meta's format. This function
+    doesn't handle any sharding or splitting of state dicts. It follows the
+    state_dict IN -> state_dict OUT pattern.
+
+    Args:
+        state_dict (Dict[str, torch.Tensor]): State dict in torchtune's format.
+
+    Returns:
+        Dict[str, torch.Tensor]: State dict in Meta's format.
+    """
+    converted_state_dict = {}
+    inverted_mapping_dict = {v: k for k, v in _QWEN_2_FROM_META.items()}
+
+    for key, value in state_dict.items():
+        new_key = get_mapped_key(key, inverted_mapping_dict)
+        converted_state_dict[new_key] = value
+
+    # The 0.5B and 1.5B models share the same weights for tok_embeddings and the
+    # output embeddings, see https://github.com/QwenLM/Qwen2.5/issues/733.
+    converted_state_dict["output.weight"] = converted_state_dict[
+        "tok_embeddings.weight"
+    ]
+
+    return converted_state_dict
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Convert Qwen2 weights to Meta format."
+    )
+    parser.add_argument(
+        "input_dir",
+        type=str,
+        help="Path to directory containing checkpoint files",
+    )
+    parser.add_argument("output", type=str, help="Path to the output checkpoint")
+
+    args = parser.parse_args()
+
+    # We don't necessarily need to use the TorchTune checkpointer here; the
+    # checkpoint files could also be aggregated manually.
+    checkpointer = FullModelHFCheckpointer(
+        checkpoint_dir=args.input_dir,
+        checkpoint_files=["model.safetensors"],
+        output_dir=".",
+        model_type="QWEN2",
+    )
+
+    print("Loading checkpoint...")
+    sd = checkpointer.load_checkpoint()
+
+    print("Converting checkpoint...")
+    sd = qwen_2_tune_to_meta(sd["model"])
+
+    torch.save(sd, args.output)
+    print(f"Checkpoint saved to {args.output}")
+
+
+if __name__ == "__main__":
+    main()
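
For a concrete sense of what `get_mapped_key` does with the inverted mapping: it abstracts the layer index out of a torchtune-format key, looks up the abstract pattern, and fills the index back into the Meta-format key. A small standalone illustration (not part of the commit):

```python
# Standalone illustration (not part of the commit): how get_mapped_key converts
# torchtune-format keys back to Meta format while preserving the layer index.
from torchtune.models.convert_weights import get_mapped_key

# Trimmed slice of the _QWEN_2_FROM_META mapping above (Meta key -> torchtune key).
mapping_subset = {
    "layers.{}.attention.wq.weight": "layers.{}.attn.q_proj.weight",
    "layers.{}.attention.wq.bias": "layers.{}.attn.q_proj.bias",
}
# Invert it, exactly as qwen_2_tune_to_meta does, to go torchtune -> Meta.
inverted = {v: k for k, v in mapping_subset.items()}

print(get_mapped_key("layers.7.attn.q_proj.bias", inverted))
# Expected output: layers.7.attention.wq.bias
```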
