Skip to content

Commit 08e5fa0

Browse files
committed
add torch_dtype_from_mcore_config
Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
1 parent b3502cf commit 08e5fa0

File tree

1 file changed

+8
-0
lines changed
  • nemo/collections/llm/gpt/model

1 file changed

+8
-0
lines changed

nemo/collections/llm/gpt/model/base.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,14 @@ def default_layer_spec(config: "GPTConfig") -> ModuleSpec:
126126
return local_layer_spec(config)
127127

128128

129+
def torch_dtype_from_mcore_config(config: TransformerConfig):
130+
if config.fp16:
131+
return torch.float16
132+
elif config.bf16:
133+
return torch.bfloat16
134+
else:
135+
return torch.float
136+
129137
@dataclass
130138
class GPTConfig(TransformerConfig, io.IOMixin):
131139
# From megatron.core.models.gpt.gpt_model.GPTModel

0 commit comments

Comments
 (0)