
Support HF LLaMA ckpt conversion #118


Merged
4 commits merged into AI-Hypercomputer:main from lsiyuan/hf-ckpt on Jun 7, 2024

Conversation

@lsy323 (Collaborator) commented on Jun 7, 2024

Added a --from_hf option to convert_checkpoints.py for converting HF checkpoints. Only LLaMA is supported for now. Quantization is not supported when converting from an HF checkpoint.

Enable converting an HF LLaMA checkpoint with:

python -m convert_checkpoints --model_name=llama-2 \
    --input_checkpoint_dir=$input_ckpt_dir \
    --output_checkpoint_dir=$output_ckpt_dir \
    --from_hf=True

A guide on adding support for HF checkpoints will be added in a follow-up PR.

Only tested with the HF 7B model; 70B is not tested yet.

@FanhaiLu1 (Collaborator) left a comment

Thanks for supporting HF LLaMA ckpt conversion! Can you save the HF weight names as a file in the repo?

"self_attn.k_proj": "attention.wk",
"self_attn.v_proj": "attention.wv",
"self_attn.o_proj": "attention.wo",
"mlp.gate_proj": "feed_forward.w1",
Collaborator

I feel [gate|down|up]_proj is more readable than w1, w2, and w3. @qihqi Shall we consider renaming to proj-style names in the default checkpoint conversion?

Collaborator Author

Yeah, this makes sense. Also want to note that the original LLaMA weights use w1/2/3: https://github.com/meta-llama/llama3/blob/main/llama/model.py#L219. If we change it, we need to do the name mapping for the original LLaMA weights as well.
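
For illustration, renaming in the default (Meta) conversion path would amount to a reverse mapping along these lines; this is a hypothetical sketch, not part of this PR.

# Hypothetical rename map for the default (Meta) checkpoint conversion if the
# output standardized on proj-style names; not part of this PR.
META_TO_PROJ_STYLE = {
    "feed_forward.w1": "mlp.gate_proj",
    "feed_forward.w2": "mlp.down_proj",
    "feed_forward.w3": "mlp.up_proj",
}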

assert (
    not FLAGS.quantize_weights
), "Quantization not supported for HF checkpoint."
return _load_hf_llama_weight(input_ckpt_dir)
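
For context, a minimal sketch of what a loader like _load_hf_llama_weight could do, assuming the HF checkpoint is stored as safetensors shards; the actual loader in this PR may differ (e.g. read .bin shards or apply the name remapping here as well).

import glob
import os

from safetensors import safe_open


def _load_hf_llama_weight(input_ckpt_dir):
    """Sketch: load every tensor from HF safetensors shards into one dict.

    Assumes *.safetensors shards as written by save_pretrained(); the real
    implementation in this PR may differ.
    """
    state_dict = {}
    for shard in sorted(glob.glob(os.path.join(input_ckpt_dir, "*.safetensors"))):
        with safe_open(shard, framework="pt") as f:
            for key in f.keys():
                state_dict[key] = f.get_tensor(key)
    return state_dict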
Collaborator

Did you test the llama2-70B model?

Collaborator Author

No, I didn't, since I haven't set up multi-host yet. But I will do that later.

@qihqi merged commit 94b576c into AI-Hypercomputer:main on Jun 7, 2024
4 checks passed
@lsy323 deleted the lsiyuan/hf-ckpt branch on June 7, 2024 04:24
@lsy323 restored the lsiyuan/hf-ckpt branch on June 7, 2024 04:24