Skip to content

Commit 1343bee

Browse files
committed
add nemotron5 conversion
1 parent 76c91b7 commit 1343bee

File tree

3 files changed: +297 additions, −17 deletions

examples/nlp/language_modeling/conf/megatron_mamba_config.yaml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,8 @@ model:
7676
post_process: True # add pooler
7777
megatron_legacy: False
7878
persist_layer_norm: True
79-
79+
squared_relu_activation: True
80+
params_dtype: bf16
8081
tokenizer:
8182
library: 'huggingface'
8283
type: 'EleutherAI/gpt-neox-20b'
@@ -87,7 +88,7 @@ model:
8788
use_fast: True
8889

8990
# Distributed checkpoint setup
90-
dist_ckpt_format: 'zarr' # Set to 'torch_dist' to use PyTorch distributed checkpoint format.
91+
dist_ckpt_format: 'torch_dist' # Set to 'torch_dist' to use PyTorch distributed checkpoint format.
9192
dist_ckpt_load_on_device: True # whether to load checkpoint weights directly on GPU or to CPU
9293
dist_ckpt_parallel_save: False # if true, each worker will write its own part of the dist checkpoint
9394

nemo/collections/nlp/models/language_modeling/megatron_mamba_model.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import torch
16+
import torch.nn.functional as F
1617
from omegaconf.dictconfig import DictConfig
1718
from pytorch_lightning.trainer.trainer import Trainer
1819

@@ -38,6 +39,8 @@
3839

3940
HAVE_MEGATRON_CORE = False
4041

42+
def squared_relu(x):
43+
return torch.pow(F.relu(x), 2)
4144

4245
class MegatronMambaModel(MegatronGPTModel):
4346
"""
@@ -62,6 +65,15 @@ def model_provider_func(self, pre_process, post_process):
6265
self.transformer_config.add_bias_linear = self.cfg.get('add_bias_linear', False)
6366
self.transformer_config.gated_linear_unit = self.cfg.get('gated_linear_unit', False)
6467
self.transformer_config.layernorm_epsilon = self.cfg.get('layernorm_epsilon', 1e-5)
68+
if self.cfg.get('params_dtype'):
69+
self.transformer_config.params_dtype = torch.bfloat16
70+
else:
71+
self.transformer_config.params_dtype = torch.float32
72+
self.transformer_config.params_dtype=torch.bfloat16
73+
if self.cfg.get('kv_channels'):
74+
self.transformer_config.kv_channels = self.cfg.get('kv_channels')
75+
if self.cfg.get('squared_relu_activation'):
76+
self.transformer_config.activation_func = squared_relu
6577

6678
model = MambaModel(
6779
config=self.transformer_config,

0 commit comments

Comments (0)