Skip to content

Commit 3ea99d4

Browse files
committed
fix hf model dtype & prune embedding size
Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
1 parent 08e5fa0 commit 3ea99d4

File tree

1 file changed

+34
-6
lines changed

1 file changed

+34
-6
lines changed

nemo/collections/llm/gpt/model/llama.py

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import torch.nn.functional as F
2222
from torch import nn
2323

24-
from nemo.collections.llm.gpt.model.base import GPTConfig, GPTModel
24+
from nemo.collections.llm.gpt.model.base import GPTConfig, GPTModel, torch_dtype_from_mcore_config
2525
from nemo.collections.llm.utils import Config
2626
from nemo.lightning import OptimizerModule, io, teardown
2727
from nemo.lightning.pytorch.utils import dtype_from_hf
@@ -295,16 +295,16 @@ def make_vocab_size_divisible_by(vocab_size):
295295

296296
@io.model_exporter(LlamaModel, "hf")
297297
class HFLlamaExporter(io.ModelConnector[LlamaModel, "LlamaForCausalLM"]):
298-
def init(self) -> "LlamaForCausalLM":
298+
def init(self, dtype=torch.bfloat16) -> "LlamaForCausalLM":
299299
from transformers import AutoModelForCausalLM
300300
from transformers.modeling_utils import no_init_weights
301301

302302
with no_init_weights(True):
303-
return AutoModelForCausalLM.from_config(self.config)
303+
return AutoModelForCausalLM.from_config(self.config, torch_dtype=dtype)
304304

305305
def apply(self, output_path: Path) -> Path:
306-
target = self.init()
307306
source, _ = self.nemo_load(str(self))
307+
target = self.init(torch_dtype_from_mcore_config(source.config))
308308
target = self.convert_state(source, target)
309309

310310
target = target.cpu()
@@ -321,10 +321,9 @@ def convert_state(self, source, target):
321321
"decoder.layers.*.self_attention.linear_qkv.layer_norm_weight": "model.layers.*.input_layernorm.weight",
322322
"decoder.layers.*.mlp.linear_fc1.layer_norm_weight": "model.layers.*.post_attention_layernorm.weight",
323323
"decoder.final_layernorm.weight": "model.norm.weight",
324-
"output_layer.weight": "lm_head.weight",
325324
}
326325

327-
return io.apply_transforms(source, target, mapping=mapping, transforms=[_export_qkv, _export_linear_fc1])
326+
return io.apply_transforms(source, target, mapping=mapping, transforms=[_export_qkv, _export_linear_fc1, _export_embedding, _export_head])
328327

329328
@property
330329
def tokenizer(self):
@@ -426,6 +425,26 @@ def _export_qkv(ctx: io.TransformCTX, linear_qkv):
426425
return q_proj, k_proj, v_proj
427426

428427

428+
@io.state_transform(
    source_key="embedding.word_embeddings.weight",
    target_key="model.embed_tokens.weight",
)
def _export_embedding(ctx: io.TransformCTX, embedding):
    """Copy the word-embedding table to the HF checkpoint, dropping pad rows.

    Megatron pads the embedding so the vocab size is divisible by a chosen
    factor; the exported HF model expects the unpadded table, so only the
    first ``vocab_size`` rows are kept.
    """
    # NOTE(review): ctx.target here is the model being written (HF side in the
    # exporter) — its config supplies the true, unpadded vocab size. Confirm
    # against io.apply_transforms if this transform is reused elsewhere.
    true_vocab = ctx.target.config.vocab_size
    return embedding[:true_vocab, :]
436+
437+
438+
@io.state_transform(
    source_key="output_layer.weight",
    target_key="lm_head.weight",
)
def _export_head(ctx: io.TransformCTX, embedding):
    """Copy the output-projection (LM head) weight, dropping pad rows.

    Mirrors ``_export_embedding``: Megatron pads the output layer's vocab
    dimension, while the exported HF ``lm_head`` must match the real
    vocabulary, so rows past ``vocab_size`` are pruned.
    """
    # NOTE(review): ctx.target is the model being written (HF side in the
    # exporter); its config holds the unpadded vocab size — verify if reused.
    true_vocab = ctx.target.config.vocab_size
    return embedding[:true_vocab, :]
446+
447+
429448
@io.state_transform(
430449
source_key=("model.layers.*.mlp.gate_proj.weight", "model.layers.*.mlp.up_proj.weight"),
431450
target_key="decoder.layers.*.mlp.linear_fc1.weight",
@@ -443,6 +462,15 @@ def _export_linear_fc1(linear_fc1):
443462

444463
return gate_proj, up_proj
445464

465+
# @io.state_transform(
466+
# source_key="decoder.layers.*.mlp.linear_fc1.weight",
467+
# target_key=("model.layers.*.mlp.gate_proj.weight", "model.layers.*.mlp.up_proj.weight"),
468+
# )
469+
# def _export_embedding(linear_fc1):
470+
# gate_proj, up_proj = torch.chunk(linear_fc1, 2, dim=0)
471+
472+
# return gate_proj, up_proj
473+
446474

447475
def apply_rope_scaling(
448476
inv_freq,

0 commit comments

Comments
 (0)