|
16 | 16 | import re |
17 | 17 | from argparse import ArgumentParser |
18 | 18 | from collections import defaultdict |
19 | | - |
20 | 19 | import torch |
21 | 20 | import torch.distributed as dist |
22 | 21 | from megatron.core.dist_checkpointing.serialization import load_plain_tensors |
|
25 | 24 | from megatron.core.transformer.spec_utils import import_module |
26 | 25 | from megatron.training.arguments import core_transformer_config_from_args |
27 | 26 | from omegaconf.omegaconf import OmegaConf |
28 | | - |
| 27 | +from datetime import timedelta |
| 28 | +import megatron.core.parallel_state as ps |
| 29 | +from torch._C._distributed_c10d import PrefixStore |
| 30 | +from torch.distributed import rendezvous |
29 | 31 | from nemo.collections.nlp.models.language_modeling.megatron_mamba_model import MegatronMambaModel |
30 | 32 | from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronLMPPTrainerBuilder |
31 | 33 | from nemo.collections.nlp.parts.utils_funcs import torch_dtype_from_precision |
@@ -92,31 +94,6 @@ def get_args(): |
92 | 94 | args = parser.parse_args() |
93 | 95 | return args |
94 | 96 |
|
95 | | - |
96 | | -import os |
97 | | -from datetime import timedelta |
98 | | - |
99 | | -import megatron.core.parallel_state as ps |
100 | | -import torch |
101 | | -from torch._C._distributed_c10d import PrefixStore |
102 | | -from torch.distributed import rendezvous |
103 | | - |
104 | | - |
105 | | -class TestModel(torch.nn.Module): |
106 | | - def __init__( |
107 | | - self, |
108 | | - input_dim: int, |
109 | | - output_dim: int, |
110 | | - num_layers: int, |
111 | | - bias: bool, |
112 | | - shared_embedding: bool = False, |
113 | | - ): |
114 | | - super().__init__() |
115 | | - self.layers = torch.nn.ModuleList([torch.nn.Linear(input_dim, output_dim, bias) for _ in range(num_layers)]) |
116 | | - if shared_embedding: |
117 | | - self.layers[-1].weight.shared_embedding = True |
118 | | - |
119 | | - |
120 | 97 | try: |
121 | 98 |
|
122 | 99 | class Utils: |
|
0 commit comments