Skip to content

Commit 117e9bd

Browse files
meatybobbyXuesongYang
authored andcommitted
Change activation parsing in TRTLLM (NVIDIA-NeMo#11173)
* Fix squared relu * Fix openai-gelu
1 parent 1b88175 commit 117e9bd

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -395,7 +395,7 @@ def load_nemo_model(nemo_ckpt: Union[str, Path], nemo_export_dir: Union[str, Pat
395395
nemo_model_config[k] = v
396396
elif k == "activation_func":
397397
if isinstance(v, torch.jit.ScriptFunction):
398-
nemo_model_config["activation"] = v.name.replace("_", "-")
398+
nemo_model_config["activation"] = v.name
399399
else:
400400
nemo_model_config["activation"] = v.__name__
401401

@@ -405,7 +405,9 @@ def load_nemo_model(nemo_ckpt: Union[str, Path], nemo_export_dir: Union[str, Pat
405405
if nemo_model_config["activation"] == "silu":
406406
nemo_model_config["activation"] = "fast-swiglu"
407407
elif nemo_model_config["activation"] == "openai_gelu":
408-
nemo_model_config["activation"] = "geglu"
408+
nemo_model_config["activation"] = "openai-gelu"
409+
elif nemo_model_config["activation"] == "squared_relu":
410+
nemo_model_config["activation"] = "squared-relu"
409411

410412
nemo_model_config["mcore_gpt"] = True
411413
nemo_model_config["max_position_embeddings"] = nemo_model_config.get("seq_length", 4096)

0 commit comments

Comments
 (0)