-
Notifications
You must be signed in to change notification settings - Fork 17
Support HF LLaMA ckpt conversion #118
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -37,6 +37,7 @@ | |
from jetstream_pt.config import FLAGS | ||
from jetstream_pt.third_party.gemma import model as gemma_model | ||
from jetstream_pt.third_party.llama import model_exportable as llama_model | ||
from safetensors import safe_open | ||
from safetensors.torch import save_file | ||
|
||
_INPUT_CHECKPOINT_DIR = epath.DEFINE_path( | ||
|
@@ -69,6 +70,12 @@ | |
"When set to true, save to HugginFace SafeTensors format", | ||
) | ||
|
||
_FROM_HF = flags.DEFINE_bool( | ||
"from_hf", | ||
False, | ||
"Set to True if the input is a HuggingFace checkpoint.", | ||
) | ||
|
||
|
||
def _find_scale_name(name, map): | ||
for key, val in map.items(): | ||
|
@@ -252,7 +259,7 @@ def _load_from_gcs(input_ckpt_dir: epath.Path): | |
return checkpoints, params | ||
|
||
|
||
def _load_from_local(input_ckpt_dir: epath.Path): | ||
def _load_orig_llama_weight(input_ckpt_dir: epath.Path): | ||
checkpoints = [] | ||
params = json.loads((input_ckpt_dir / "params.json").read_text()) | ||
|
||
|
@@ -268,6 +275,84 @@ def _load_from_local(input_ckpt_dir: epath.Path): | |
return checkpoints, params | ||
|
||
|
||
def _load_hf_llama_weight(input_ckpt_dir: epath.Path): | ||
print(f"Loading checkpoint files from {input_ckpt_dir}.") | ||
safetensors_files = input_ckpt_dir.glob("*.safetensors") | ||
if len(list(safetensors_files)) == 0: | ||
raise ValueError( | ||
f"No *.safetensors found in the input dir {input_ckpt_dir}" | ||
) | ||
checkpoint = {} | ||
for st_f in safetensors_files: | ||
with safe_open(st_f, framework="pt", device="cpu") as f: | ||
for key in f.keys(): | ||
if "inv_freq" in key: | ||
# Don't include 'rotary_emb.inv_freq' in the converted | ||
# checkpoint, because in JetStream implementation we | ||
# precompute it during weight loading. | ||
continue | ||
new_key = key | ||
# Remove 'model.' prefix for all weights. | ||
prefix_to_remove = "model." | ||
if key.startswith(prefix_to_remove): | ||
new_key = new_key.removeprefix(prefix_to_remove) | ||
|
||
# Weight name substring mapping between hf and jetstream. | ||
_load_hf_llama_weight.hf_to_jetstream_keys_mapping = { | ||
"lm_head": "output", | ||
"embed_tokens": "tok_embeddings", | ||
"input_layernorm": "attention_norm", | ||
"post_attention_layernorm": "ffn_norm", | ||
"self_attn.q_proj": "attention.wq", | ||
"self_attn.k_proj": "attention.wk", | ||
"self_attn.v_proj": "attention.wv", | ||
"self_attn.o_proj": "attention.wo", | ||
"mlp.gate_proj": "feed_forward.w1", | ||
"mlp.down_proj": "feed_forward.w2", | ||
"mlp.up_proj": "feed_forward.w3", | ||
"model.norm.weight": "norm.weight", | ||
} | ||
found_substute = False | ||
for ( | ||
hf_weight_key | ||
) in _load_hf_llama_weight.hf_to_jetstream_keys_mapping.keys(): | ||
if hf_weight_key in key: | ||
jet_stream_key = _load_hf_llama_weight.hf_to_jetstream_keys_mapping[ | ||
hf_weight_key | ||
] | ||
new_key = new_key.replace(hf_weight_key, jet_stream_key) | ||
found_substute = True | ||
break | ||
assert found_substute, f"No substitute name found for {key}." | ||
print(f"convert weight name {key} to {new_key}.") | ||
weight_tensor = f.get_tensor(key) | ||
if weight_tensor.dtype == torch.float16: | ||
# JetStream expects bf16 weight, since activation is in bf16 | ||
# float16 x bf16 will hit mix precision assertion. | ||
weight_tensor = weight_tensor.to(torch.bfloat16) | ||
print(f"convert weight name {new_key} from float16 to bfloat16.") | ||
if "wq" in new_key or "wk" in new_key: | ||
# In HF weight, wq and wk are interleaved differently | ||
weight_shape = weight_tensor.shape | ||
weight_tensor = ( | ||
weight_tensor.reshape(-1, 2, 64, weight_shape[1]) | ||
.transpose(1, 2) | ||
.reshape(weight_shape) | ||
) | ||
checkpoint[new_key] = weight_tensor | ||
return [checkpoint], None | ||
|
||
|
||
def _load_from_local(input_ckpt_dir: epath.Path): | ||
if not _FROM_HF.value: | ||
return _load_orig_llama_weight(input_ckpt_dir) | ||
else: | ||
assert ( | ||
not FLAGS.quantize_weights | ||
), "Quantization not supported for HF checkpoint." | ||
return _load_hf_llama_weight(input_ckpt_dir) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Did you test the llama2-70B model? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No I didn't, since I haven't set up multi host yet. But I will do that later |
||
|
||
|
||
def _export_to_gcs(output_ckpt_dir: epath.Path, params, state_dict): | ||
# pylint: disable-next=all | ||
bucket_name, output_ckpt = str(output_ckpt_dir).split("//")[-1].split("/", 1) | ||
|
@@ -276,11 +361,12 @@ def _export_to_gcs(output_ckpt_dir: epath.Path, params, state_dict): | |
bucket = storage_client.bucket(bucket_name) | ||
|
||
ckpt_blob = bucket.blob(os.path.join(output_ckpt, "consolidated.00.pth")) | ||
param_blob = bucket.blob(os.path.join(output_ckpt, "params.json")) | ||
checklist_blob = bucket.blob(os.path.join(output_ckpt, "checklist.chk")) | ||
with param_blob.open("w") as f: | ||
f.write(json.dumps(params)) | ||
f.close() | ||
if params is not None: | ||
param_blob = bucket.blob(os.path.join(output_ckpt, "params.json")) | ||
with param_blob.open("w") as f: | ||
f.write(json.dumps(params)) | ||
f.close() | ||
with ckpt_blob.open("w") as f: | ||
torch.save(state_dict, f) | ||
f.close() | ||
|
@@ -291,7 +377,8 @@ def _export_to_gcs(output_ckpt_dir: epath.Path, params, state_dict): | |
|
||
def _export_to_local(output_ckpt_dir: epath.Path, params, state_dict): | ||
output_ckpt_dir.mkdir(parents=True, exist_ok=True) | ||
(output_ckpt_dir / "params.json").write_text(json.dumps(params)) | ||
if params is not None: | ||
(output_ckpt_dir / "params.json").write_text(json.dumps(params)) | ||
if _OUTPUT_SAFETENSORS.value: | ||
# safetensors.torch.save_file expects tensor to be contiguous. | ||
state_dict = pytree.tree_map_only( | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
import torch | ||
from safetensors import safe_open | ||
|
||
""" | ||
Script to compare converted checkpoint for debugging purpose. | ||
""" | ||
|
||
converted_from_orig = ( | ||
"/mnt/disks/lsiyuan/llama_weight/7B-FT-chat-converted/model.safetensors" | ||
) | ||
|
||
converted_from_hf = "/mnt/disks/lsiyuan/llama_weight/hf_llama_2_7b_converted_bf16/model.safetensors" | ||
|
||
orig_state_dict = {} | ||
with safe_open(converted_from_orig, framework="pt", device="cpu") as f: | ||
for key in f.keys(): | ||
orig_state_dict[key] = f.get_tensor(key) | ||
|
||
hf_state_dict = {} | ||
with safe_open(converted_from_hf, framework="pt", device="cpu") as f: | ||
for key in f.keys(): | ||
hf_state_dict[key] = f.get_tensor(key) | ||
|
||
for key in orig_state_dict.keys(): | ||
if key != "rope.freqs": | ||
assert key in hf_state_dict, f"{key} in orig but not in hf" | ||
else: | ||
print("rope.freqs skipped.") | ||
|
||
for key in hf_state_dict.keys(): | ||
assert key in orig_state_dict, f"{key} in hf but not in orig" | ||
|
||
|
||
def _calc_cosine_dist(x, y): | ||
x = x.flatten().to(torch.float32) | ||
y = y.flatten().to(torch.float32) | ||
return (torch.dot(x, y) / (x.norm() * y.norm())).item() | ||
|
||
|
||
for key in hf_state_dict.keys(): | ||
orig_w = orig_state_dict[key] | ||
hf_w = hf_state_dict[key] | ||
print(f"weight diff {key} : {_calc_cosine_dist(orig_w, hf_w)}") |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I feel [gate|down|up]_proj are more read friendly than w1, w2 and w3. @qihqi Shall we consider rename it to proj related name in default checkpoint convert?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, this makes sense. Also want to note that original llama weight is using w1/2/3 https://github.com/meta-llama/llama3/blob/main/llama/model.py#L219. If we change it we need to do the name mapping for the original llama weight.