2121import torch .nn .functional as F
2222from torch import nn
2323
24- from nemo .collections .llm .gpt .model .base import GPTConfig , GPTModel
24+ from nemo .collections .llm .gpt .model .base import GPTConfig , GPTModel , torch_dtype_from_mcore_config
2525from nemo .collections .llm .utils import Config
2626from nemo .lightning import OptimizerModule , io , teardown
2727from nemo .lightning .pytorch .utils import dtype_from_hf
@@ -295,16 +295,16 @@ def make_vocab_size_divisible_by(vocab_size):
295295
296296@io .model_exporter (LlamaModel , "hf" )
297297class HFLlamaExporter (io .ModelConnector [LlamaModel , "LlamaForCausalLM" ]):
298- def init (self ) -> "LlamaForCausalLM" :
298+ def init (self , dtype = torch . bfloat16 ) -> "LlamaForCausalLM" :
299299 from transformers import AutoModelForCausalLM
300300 from transformers .modeling_utils import no_init_weights
301301
302302 with no_init_weights (True ):
303- return AutoModelForCausalLM .from_config (self .config )
303+ return AutoModelForCausalLM .from_config (self .config , torch_dtype = dtype )
304304
305305 def apply (self , output_path : Path ) -> Path :
306- target = self .init ()
307306 source , _ = self .nemo_load (str (self ))
307+ target = self .init (torch_dtype_from_mcore_config (source .config ))
308308 target = self .convert_state (source , target )
309309
310310 target = target .cpu ()
@@ -321,10 +321,9 @@ def convert_state(self, source, target):
321321 "decoder.layers.*.self_attention.linear_qkv.layer_norm_weight" : "model.layers.*.input_layernorm.weight" ,
322322 "decoder.layers.*.mlp.linear_fc1.layer_norm_weight" : "model.layers.*.post_attention_layernorm.weight" ,
323323 "decoder.final_layernorm.weight" : "model.norm.weight" ,
324- "output_layer.weight" : "lm_head.weight" ,
325324 }
326325
327- return io .apply_transforms (source , target , mapping = mapping , transforms = [_export_qkv , _export_linear_fc1 ])
326+ return io .apply_transforms (source , target , mapping = mapping , transforms = [_export_qkv , _export_linear_fc1 , _export_embedding , _export_head ])
328327
329328 @property
330329 def tokenizer (self ):
@@ -426,6 +425,26 @@ def _export_qkv(ctx: io.TransformCTX, linear_qkv):
426425 return q_proj , k_proj , v_proj
427426
428427
@io.state_transform(
    source_key="embedding.word_embeddings.weight",
    target_key="model.embed_tokens.weight",
)
def _export_embedding(ctx: io.TransformCTX, embedding):
    """Export the word-embedding table to HF, pruning vocab-padding rows.

    NeMo/Megatron pads the embedding table (see ``make_vocab_size_divisible_by``)
    so the row count can exceed the true vocabulary size; HF expects exactly
    ``vocab_size`` rows, so the extra padding rows are sliced off.
    """
    # NOTE(review): during export ``ctx.target`` is the HF model, so this is the
    # HF config's vocab_size — the original local name ``megatron_config`` was
    # misleading.
    hf_config = ctx.target.config
    return embedding[: hf_config.vocab_size, :]
436+
437+
@io.state_transform(
    source_key="output_layer.weight",
    target_key="lm_head.weight",
)
def _export_head(ctx: io.TransformCTX, embedding):
    """Export the output-projection (LM head) weight to HF, pruning padding rows.

    Like the embedding table, the NeMo/Megatron output layer may carry extra
    rows added for vocab-size padding; HF's ``lm_head`` expects exactly
    ``vocab_size`` rows.
    """
    # NOTE(review): during export ``ctx.target`` is the HF model, so this is the
    # HF config's vocab_size — the original local name ``megatron_config`` was
    # misleading.
    hf_config = ctx.target.config
    return embedding[: hf_config.vocab_size, :]
446+
447+
429448@io .state_transform (
430449 source_key = ("model.layers.*.mlp.gate_proj.weight" , "model.layers.*.mlp.up_proj.weight" ),
431450 target_key = "decoder.layers.*.mlp.linear_fc1.weight" ,
@@ -443,6 +462,15 @@ def _export_linear_fc1(linear_fc1):
443462
444463 return gate_proj , up_proj
445464
446474
447475def apply_rope_scaling (
448476 inv_freq ,