Commit 94b576c

Support HF LLaMA ckpt conversion (#118)

* support converting hf checkpoint

1 parent: 52ec00f

2 files changed (+136, -6 lines)

convert_checkpoints.py (93 additions, 6 deletions)
@@ -37,6 +37,7 @@
 from jetstream_pt.config import FLAGS
 from jetstream_pt.third_party.gemma import model as gemma_model
 from jetstream_pt.third_party.llama import model_exportable as llama_model
+from safetensors import safe_open
 from safetensors.torch import save_file
 
 _INPUT_CHECKPOINT_DIR = epath.DEFINE_path(
@@ -69,6 +70,12 @@
     "When set to true, save to HugginFace SafeTensors format",
 )
 
+_FROM_HF = flags.DEFINE_bool(
+    "from_hf",
+    False,
+    "Set to True if the input is a HuggingFace checkpoint.",
+)
+
 
 def _find_scale_name(name, map):
   for key, val in map.items():
@@ -252,7 +259,7 @@ def _load_from_gcs(input_ckpt_dir: epath.Path):
   return checkpoints, params
 
 
-def _load_from_local(input_ckpt_dir: epath.Path):
+def _load_orig_llama_weight(input_ckpt_dir: epath.Path):
   checkpoints = []
   params = json.loads((input_ckpt_dir / "params.json").read_text())
 
@@ -268,6 +275,84 @@ def _load_from_local(input_ckpt_dir: epath.Path):
   return checkpoints, params
 
 
+def _load_hf_llama_weight(input_ckpt_dir: epath.Path):
+  print(f"Loading checkpoint files from {input_ckpt_dir}.")
+  safetensors_files = list(input_ckpt_dir.glob("*.safetensors"))
+  if len(safetensors_files) == 0:
+    raise ValueError(
+        f"No *.safetensors found in the input dir {input_ckpt_dir}"
+    )
+  checkpoint = {}
+  for st_f in safetensors_files:
+    with safe_open(st_f, framework="pt", device="cpu") as f:
+      for key in f.keys():
+        if "inv_freq" in key:
+          # Don't include 'rotary_emb.inv_freq' in the converted
+          # checkpoint, because the JetStream implementation
+          # precomputes it during weight loading.
+          continue
+        new_key = key
+        # Remove the 'model.' prefix from all weight names.
+        prefix_to_remove = "model."
+        if key.startswith(prefix_to_remove):
+          new_key = new_key.removeprefix(prefix_to_remove)
+
+        # Weight-name substring mapping between HF and JetStream.
+        _load_hf_llama_weight.hf_to_jetstream_keys_mapping = {
+            "lm_head": "output",
+            "embed_tokens": "tok_embeddings",
+            "input_layernorm": "attention_norm",
+            "post_attention_layernorm": "ffn_norm",
+            "self_attn.q_proj": "attention.wq",
+            "self_attn.k_proj": "attention.wk",
+            "self_attn.v_proj": "attention.wv",
+            "self_attn.o_proj": "attention.wo",
+            "mlp.gate_proj": "feed_forward.w1",
+            "mlp.down_proj": "feed_forward.w2",
+            "mlp.up_proj": "feed_forward.w3",
+            "model.norm.weight": "norm.weight",
+        }
+        found_substitute = False
+        for (
+            hf_weight_key
+        ) in _load_hf_llama_weight.hf_to_jetstream_keys_mapping.keys():
+          if hf_weight_key in key:
+            jet_stream_key = _load_hf_llama_weight.hf_to_jetstream_keys_mapping[
+                hf_weight_key
+            ]
+            new_key = new_key.replace(hf_weight_key, jet_stream_key)
+            found_substitute = True
+            break
+        assert found_substitute, f"No substitute name found for {key}."
+        print(f"convert weight name {key} to {new_key}.")
+        weight_tensor = f.get_tensor(key)
+        if weight_tensor.dtype == torch.float16:
+          # JetStream expects bf16 weights, since activations are in bf16;
+          # float16 x bf16 would hit a mixed-precision assertion.
+          weight_tensor = weight_tensor.to(torch.bfloat16)
+          print(f"convert weight name {new_key} from float16 to bfloat16.")
+        if "wq" in new_key or "wk" in new_key:
+          # In HF checkpoints, wq and wk are interleaved differently.
+          weight_shape = weight_tensor.shape
+          weight_tensor = (
+              weight_tensor.reshape(-1, 2, 64, weight_shape[1])
+              .transpose(1, 2)
+              .reshape(weight_shape)
+          )
+        checkpoint[new_key] = weight_tensor
+  return [checkpoint], None
+
+
+def _load_from_local(input_ckpt_dir: epath.Path):
+  if not _FROM_HF.value:
+    return _load_orig_llama_weight(input_ckpt_dir)
+  else:
+    assert (
+        not FLAGS.quantize_weights
+    ), "Quantization not supported for HF checkpoint."
+    return _load_hf_llama_weight(input_ckpt_dir)
+
+
 def _export_to_gcs(output_ckpt_dir: epath.Path, params, state_dict):
   # pylint: disable-next=all
   bucket_name, output_ckpt = str(output_ckpt_dir).split("//")[-1].split("/", 1)
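The reshape/transpose applied to wq and wk above undoes the head-wise permutation that the HF-side LLaMA conversion uses for its half-split rotary layout. A minimal sketch, not part of the commit, of why the hard-coded 64 works: it assumes head_dim = 128 (so 64 = head_dim // 2), and the hf_permute helper below is my own stand-in for the HF-side permutation.

import torch

# Toy sizes; LLaMA attention heads have head_dim = 128, hence the 64 in the diff.
n_heads, head_dim, cols = 4, 128, 256
rows = n_heads * head_dim

def hf_permute(w):
  # Stand-in for the half-split rotary permutation assumed on the HF side.
  return w.view(n_heads, head_dim // 2, 2, cols).transpose(1, 2).reshape(rows, cols)

def jetstream_unpermute(w):
  # Same reshape/transpose as _load_hf_llama_weight; 64 == head_dim // 2.
  return w.reshape(-1, 2, 64, w.shape[1]).transpose(1, 2).reshape(w.shape)

w = torch.randn(rows, cols)
assert torch.equal(jetstream_unpermute(hf_permute(w)), w)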
@@ -276,11 +361,12 @@ def _export_to_gcs(output_ckpt_dir: epath.Path, params, state_dict):
   bucket = storage_client.bucket(bucket_name)
 
   ckpt_blob = bucket.blob(os.path.join(output_ckpt, "consolidated.00.pth"))
-  param_blob = bucket.blob(os.path.join(output_ckpt, "params.json"))
   checklist_blob = bucket.blob(os.path.join(output_ckpt, "checklist.chk"))
-  with param_blob.open("w") as f:
-    f.write(json.dumps(params))
-    f.close()
+  if params is not None:
+    param_blob = bucket.blob(os.path.join(output_ckpt, "params.json"))
+    with param_blob.open("w") as f:
+      f.write(json.dumps(params))
+      f.close()
   with ckpt_blob.open("w") as f:
     torch.save(state_dict, f)
     f.close()
@@ -291,7 +377,8 @@ def _export_to_gcs(output_ckpt_dir: epath.Path, params, state_dict):
 
 def _export_to_local(output_ckpt_dir: epath.Path, params, state_dict):
   output_ckpt_dir.mkdir(parents=True, exist_ok=True)
-  (output_ckpt_dir / "params.json").write_text(json.dumps(params))
+  if params is not None:
+    (output_ckpt_dir / "params.json").write_text(json.dumps(params))
   if _OUTPUT_SAFETENSORS.value:
     # safetensors.torch.save_file expects tensor to be contiguous.
     state_dict = pytree.tree_map_only(
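As a usage illustration of the renaming scheme in _load_hf_llama_weight, a standalone sketch, not part of the commit, that applies the same 'model.' prefix strip and substring substitution; the mapping is abridged to three entries here.

# Standalone sketch of the HF -> JetStream key renaming (mapping abridged).
HF_TO_JETSTREAM = {
    "embed_tokens": "tok_embeddings",
    "self_attn.q_proj": "attention.wq",
    "mlp.gate_proj": "feed_forward.w1",
}

def rename(hf_key: str) -> str:
  new_key = hf_key.removeprefix("model.")
  for hf_sub, js_sub in HF_TO_JETSTREAM.items():
    if hf_sub in hf_key:
      return new_key.replace(hf_sub, js_sub)
  raise ValueError(f"No substitute name found for {hf_key}.")

print(rename("model.layers.0.self_attn.q_proj.weight"))
# -> layers.0.attention.wq.weight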
New file, checkpoint comparison script (43 additions, 0 deletions)

@@ -0,0 +1,43 @@
+import torch
+from safetensors import safe_open
+
+"""
+Script to compare converted checkpoints for debugging purposes.
+"""
+
+converted_from_orig = (
+    "/mnt/disks/lsiyuan/llama_weight/7B-FT-chat-converted/model.safetensors"
+)
+
+converted_from_hf = "/mnt/disks/lsiyuan/llama_weight/hf_llama_2_7b_converted_bf16/model.safetensors"
+
+orig_state_dict = {}
+with safe_open(converted_from_orig, framework="pt", device="cpu") as f:
+  for key in f.keys():
+    orig_state_dict[key] = f.get_tensor(key)
+
+hf_state_dict = {}
+with safe_open(converted_from_hf, framework="pt", device="cpu") as f:
+  for key in f.keys():
+    hf_state_dict[key] = f.get_tensor(key)
+
+for key in orig_state_dict.keys():
+  if key != "rope.freqs":
+    assert key in hf_state_dict, f"{key} in orig but not in hf"
+  else:
+    print("rope.freqs skipped.")
+
+for key in hf_state_dict.keys():
+  assert key in orig_state_dict, f"{key} in hf but not in orig"
+
+
+def _calc_cosine_dist(x, y):
+  x = x.flatten().to(torch.float32)
+  y = y.flatten().to(torch.float32)
+  return (torch.dot(x, y) / (x.norm() * y.norm())).item()
+
+
+for key in hf_state_dict.keys():
+  orig_w = orig_state_dict[key]
+  hf_w = hf_state_dict[key]
+  print(f"weight diff {key} : {_calc_cosine_dist(orig_w, hf_w)}")
