-
Notifications
You must be signed in to change notification settings - Fork 3.4k
Expand file tree
/
Copy pathmegatron_parallel.py
More file actions
1899 lines (1569 loc) · 74.1 KB
/
megatron_parallel.py
File metadata and controls
1899 lines (1569 loc) · 74.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=missing-class-docstring
# pylint: disable=missing-function-docstring
import abc
import collections.abc
import functools
import inspect
import itertools
import operator
import queue
import types
from collections import defaultdict
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Generic,
Iterable,
Iterator,
List,
Mapping,
Optional,
Protocol,
Sequence,
Tuple,
TypeVar,
Union,
cast,
runtime_checkable,
)
import torch
import torch.distributed
from lightning.pytorch.trainer.states import TrainerFn
from lightning.pytorch.utilities import move_data_to_device
from megatron.core import parallel_state
from megatron.core.distributed import DistributedDataParallel as McoreDDP
from megatron.core.distributed import DistributedDataParallelConfig
from megatron.core.optimizer import OptimizerConfig
from megatron.core.transformer.transformer_config import TransformerConfig
from torch import Tensor, nn
from typing_extensions import override
from nemo.utils.model_utils import check_lib_version
# Optional Megatron-core features are probed once at import time; the HAVE_*
# flags let the rest of this module branch without re-importing.
try:
    # Legacy "custom FSDP" implementation.
    from megatron.core.distributed.custom_fsdp import FullyShardedDataParallel

    HAVE_CUSTOM_FSDP = True
except ImportError:
    HAVE_CUSTOM_FSDP = False
try:
    # Newer Megatron FSDP. NOTE: when this import succeeds it rebinds the
    # `FullyShardedDataParallel` name bound above — the Megatron variant
    # deliberately takes precedence over the custom one.
    from megatron.core.distributed import FullyShardedDataParallel

    HAVE_MEGATRON_FSDP = True
except ImportError:
    HAVE_MEGATRON_FSDP = False
try:
    from megatron.core.full_cuda_graph import FullCudaGraphWrapper

    HAVE_FULL_CUDA_GRAPH = True
except ImportError:
    # Full-iteration CUDA graph support requires megatron.core >= 0.14.0;
    # keep the version-check message for diagnostics.
    _, mcore_import_msg = check_lib_version("megatron.core", "0.14.0", operator.ge)
    HAVE_FULL_CUDA_GRAPH = False

# DataT: the batch payload flowing through a step — a tensor, a dict of
# tensors, or a sequence of tensors. ModelT: any nn.Module subclass.
DataT = TypeVar("DataT", Tensor, Dict[str, Tensor], Sequence[Tensor])
ModelT = TypeVar("ModelT", bound=nn.Module)
T = TypeVar('T')
# A *_step may return a loss tensor, a mapping of outputs, or None.
STEP_OUTPUT = Optional[Union[Tensor, Mapping[str, Any]]]
if TYPE_CHECKING:
    import lightning.pytorch as pl
@runtime_checkable
class PrecisionPluginProtocol(Protocol[DataT]):
    """Structural interface for precision plugins used around the forward step.

    ``convert_input`` is applied to the batch on the first pipeline stage and
    ``convert_output`` to the output tensor on the last stage (see
    ``MegatronParallel.wrapped_forward_step``).
    """

    def convert_input(self, data: DataT) -> DataT: ...

    def convert_output(self, output: torch.Tensor) -> torch.Tensor: ...
def default_data_step(dataloader_iter: Iterator[DataT]) -> DataT:
    """Fetch the next batch and move it to the current CUDA device.

    The iterator may yield either the raw batch or a 3-tuple whose first
    element is the batch (the wrapping applied by NeMo's fabric strategy
    dataloader, see
    https://github.com/NVIDIA/NeMo/blob/main/nemo/lightning/fabric/strategies.py#L441);
    only the batch itself is kept.

    This default performs no context-parallel slicing, so it refuses to run
    when context parallelism is active — supply your own data step in that case.

    Args:
        dataloader_iter: Iterator over (possibly tuple-wrapped) batches.

    Returns:
        DataT: The next batch, moved to ``torch.cuda.current_device()``.

    Raises:
        ValueError: If the context-parallel world size is greater than one.
    """
    if parallel_state.get_context_parallel_world_size() > 1:
        raise ValueError(
            "Default data step is being used in a context parallel environment."
            "Please define your own data step that appropriately slices the data for context parallel."
        )
    raw = next(dataloader_iter)
    # Unwrap the 3-tuple produced by the fabric dataloader wrapper; the batch
    # is its first element.
    unwrapped = raw[0] if isinstance(raw, tuple) and len(raw) == 3 else raw
    return move_data_to_device(unwrapped, torch.cuda.current_device())
def default_forward_step(model: nn.Module, batch, *args, **kwargs) -> torch.Tensor:
    """Default forward step: invoke ``model`` on ``batch``, forwarding extras verbatim.

    Args:
        model: The module (or callable) to run.
        batch: First positional argument passed to ``model``.

    Returns:
        Whatever ``model`` returns (typically the output tensor).
    """
    return model(batch, *args, **kwargs)
def extract_ddp_funcs(ddp_config, pipeline):
    """Pull the optional no-sync / grad-sync hooks off DDP-wrapped model chunks.

    Hooks are only produced when ``ddp_config.overlap_grad_reduce`` is set;
    the grad-sync hook additionally requires ``align_grad_reduce``. With a
    single chunk the bare callable is returned instead of a one-element list.

    Args:
        ddp_config: Megatron DDP config (attributes read via ``getattr`` with
            ``False`` defaults, so any object works).
        pipeline: Iterable of model chunks exposing ``no_sync`` /
            ``start_grad_sync``.

    Returns:
        Tuple ``(no_sync_func, grad_sync_func)``; either element may be None.
    """
    no_sync_func = None
    grad_sync_func = None
    if getattr(ddp_config, "overlap_grad_reduce", False):
        no_sync_hooks = [chunk.no_sync for chunk in pipeline]
        no_sync_func = no_sync_hooks[0] if len(pipeline) == 1 else no_sync_hooks
        if getattr(ddp_config, "align_grad_reduce", False):
            grad_sync_hooks = [chunk.start_grad_sync for chunk in pipeline]
            grad_sync_func = grad_sync_hooks[0] if len(pipeline) == 1 else grad_sync_hooks
    return no_sync_func, grad_sync_func
class MegatronParallel(nn.ModuleList, Generic[ModelT]):
"""Implements distributed model parallelism that is based on Megatron-LM.
This supports various forms of parallelism:
- tensor-parallelism
- pipeline-parallelism
- virtual pipeline parallelism
- expert parallelism
- sequence parallelism
Attributes
----------
pipeline (Union[nn.Module, Iterable[nn.Module]]): The sequence of modules that
constitute the pipeline.
precision_plugin (Optional[PrecisionPluginProtocol]): An optional plugin for
managing precision-specific operations.
callbacks (CallbackConnector): A connector for managing and invoking callbacks.
data_step (Callable[[Iterator[DataT]], DataT]): A function that takes an iterator
over the data and returns the next batch.
forward_step (Callable[[nn.Module, DataT], Tensor]): A function that defines the
forward pass of a model.
loss_reduction (Optional[Callable[[nn.Module], MegatronLossReduction]]): An optional
function that defines how the loss is reduced.
vp_size (Optional[int]): Virtual pipeline parallel size.
ddp_config (Optional[DistributedDataParallelConfig]): An instance of Megatron core's
DistributedDataParallelConfig which controls the Megatron DDP configuration.
fsdp (Optional[str]): Whether model should run Torch FSDP2 instead of DDP, select from
["megatron", "torch"]. Defaults to None.
cpu (bool): Whether model should reside on CPU.
convert_module_fn (Optional[Callable[[ModelT], nn.Module]]): An optional function to
apply to the model parameters after initialization.
Examples
--------
>>> from torch import nn
>>> from megatron_ext.megatron_parallel import MegatronParallel
>>> model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 5))
>>> megatron_model = MegatronParallel(model)
>>> print(megatron_model)
MegatronParallel(
(0): Linear(in_features=10, out_features=10, bias=True)
(1): ReLU()
(2): Linear(in_features=10, out_features=5, bias=True)
)
References
----------
Shoeybi, M., Patwary, M., Puri, R., LeGresley, P., Casper, J., & Catanzaro, B. (2019).
Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM.
arXiv preprint arXiv:1909.08053.
"""
def __init__(
    self,
    pipeline: Union[ModelT, Iterable[ModelT]],
    precision_plugin: Optional[PrecisionPluginProtocol] = None,
    callbacks: Optional["CallbackConnector"] = None,
    data_step: Optional[Callable[[Iterator[DataT]], DataT]] = None,
    forward_step: Optional[Callable[[ModelT, DataT], Tensor]] = None,
    loss_reduction: Optional[Callable[[ModelT], "MegatronLossReduction"]] = None,
    vp_size: Optional[int] = None,
    ddp_config: Optional[DistributedDataParallelConfig] = None,
    fsdp: Optional[str] = None,
    cpu: bool = False,
    convert_module_fn: Optional[Callable[[ModelT], nn.Module]] = None,
) -> None:
    """Normalize ``pipeline`` into a list of model chunks and store configuration.

    See the class docstring for the meaning of each argument.
    """
    from megatron.core import parallel_state

    # Normalize the input into a plain list of modules.
    _pipeline: List[nn.Module]
    if isinstance(pipeline, nn.ModuleList):
        _pipeline = list(pipeline)
    elif isinstance(pipeline, nn.Module):
        _pipeline = [pipeline]
    else:
        _pipeline = pipeline
    if vp_size is not None:
        # Virtual pipeline parallelism with a single provided chunk: build
        # chunks for virtual stages 1..vp_size-1 by re-instantiating the
        # first chunk (io.reinit) and configuring each with its vp_stage.
        if len(_pipeline) == 1 and parallel_state.get_pipeline_model_parallel_world_size() > 1:
            from nemo.lightning import io

            for i in range(1, vp_size):
                _model = io.reinit(_pipeline[0])
                if hasattr(_model, "configure_model"):
                    _model.configure_model(vp_stage=i)
                _pipeline.append(_model)
    super().__init__(_pipeline)
    self.precision_plugin = precision_plugin
    self._cpu = cpu
    self.callbacks = callbacks or CallbackConnector()
    self.data_step = data_step or default_data_step
    self.forward_step = forward_step or default_forward_step
    # NOTE: despite the annotation this may be None or a factory callable;
    # `forward` resolves it per-call.
    self.loss_reduction: MegatronLossReduction = loss_reduction
    self.ddp_config = ddp_config
    self.fsdp = fsdp
    self.convert_module_fn = convert_module_fn
    self.vp_size = vp_size
def forward(
    self,
    data: Union[DataT, Iterator[DataT], List[Iterator[DataT]]],
    forward_only: bool = True,
    data_step: Optional[Callable[[Iterator[DataT]], DataT]] = None,
    forward_step: Optional[Callable[[ModelT, DataT], Tensor]] = None,
    loss_reduction: Optional["MegatronLossReduction[DataT, Any]"] = None,
    seq_length: Optional[int] = None,
    micro_batch_size: Optional[int] = None,
    num_microbatches: Optional[int] = None,
    step_i: Optional[int] = None,
    wrap_forward_step: bool = True,
) -> torch.Tensor:
    """The method performs the forward pass of the model.

    This method is responsible for executing the forward pass of the model; when
    `forward_only` is False the step is configured for training (the scheduling is
    handled by `MegatronStep`). During the execution, it invokes various callbacks
    at different stages of the operation. For more info about that see [CallbackConnector].

    Args:
        data (Union[DataT, Iterator[DataT], List[Iterator[DataT]]]): The input data for the model.
        forward_only (bool, optional): If True, only perform the forward pass. Defaults to True.
        data_step (Optional[Callable[[Iterator[DataT]], DataT]], optional): Function to process the data.
            Defaults to None.
        forward_step (Optional[Callable[[nn.Module, DataT], Tensor]], optional): Function to perform the
            forward pass. Defaults to None.
        loss_reduction (Optional[MegatronLossReduction[DataT, Any]], optional): Function to reduce the
            loss. Defaults to None.
        seq_length (Optional[int], optional): Sequence length for the model. Defaults to None.
        micro_batch_size (Optional[int], optional): Size of the micro batch. Defaults to None.
        num_microbatches (Optional[int], optional): Number of microbatches. Defaults to None.
        step_i (Optional[int], optional): Step index forwarded to `MegatronStep.infer`. Defaults to None.
        wrap_forward_step (bool, optional): If True, wrap the forward step function. Defaults to True.

    Returns
    -------
        torch.Tensor: The output tensor from the forward pass.
    """
    _forward_step = forward_step or self.forward_step
    _loss_reduction = loss_reduction or self.loss_reduction
    # Mutable context shared with the wrapped forward-step closure; its
    # "step" entry is populated below once the MegatronStep exists.
    _forward_context = {}
    if wrap_forward_step:
        _data_step = data_step or self.data_step
        forward_step_func = self.wrapped_forward_step(
            forward_step=_forward_step,
            data_step=_data_step,
            loss_reduction=_loss_reduction,
            context=_forward_context,
        )
    else:
        forward_step_func = _forward_step
    step = MegatronStep.infer(
        self,
        data,
        forward_step_func,
        forward_only=forward_only,
        micro_batch_size=micro_batch_size,
        num_microbatches=num_microbatches,
        seq_length=seq_length,
        step_i=step_i,
    )
    _forward_context["step"] = step
    # Callbacks may transform the step before the microbatch loop runs.
    step = self.callbacks.transform_event("on_megatron_step_start", step)
    self.callbacks.event("on_megatron_microbatches_start", step=step)
    microbatch_outputs = step()
    self.callbacks.event("on_megatron_microbatches_end", step=step, microbatch_outputs=microbatch_outputs)
    if microbatch_outputs:
        self.callbacks.event(
            "on_megatron_reduce_microbatches_start", step=step, microbatch_outputs=microbatch_outputs
        )
        # Resolve a deferred loss-reduction placeholder against the primary module.
        if isinstance(_loss_reduction, _ModuleStepFunction):
            _loss_reduction = _loss_reduction(self.module)
        reduced = _loss_reduction.reduce(microbatch_outputs)
        self.callbacks.event(
            "on_megatron_reduce_microbatches_end",
            step=step,
            loss_reduction=_loss_reduction,
            microbatch_outputs=microbatch_outputs,
            reduced=reduced,
        )
    else:
        # we're not on the last pipeline stage so no losses
        reduced = torch.tensor(0.0, device=torch.cuda.current_device())
    self.callbacks.event("on_megatron_step_end", step=step, microbatch_outputs=microbatch_outputs, reduced=reduced)
    return reduced
def training_step(
    self,
    data: DataT,
    data_step: Optional[Callable[[Iterator[DataT]], DataT]] = None,
    forward_step: Optional[Callable[[ModelT, DataT], Tensor]] = None,
    loss_reduction: Optional["MegatronLossReduction[DataT, Any]"] = None,
    seq_length: Optional[int] = None,
    micro_batch_size: Optional[int] = None,
    num_microbatches: Optional[int] = None,
    step_i: Optional[int] = None,
    **kwargs,
) -> STEP_OUTPUT:
    """Run one training step (``forward_only=False``) through the Megatron pipeline.

    Delegates to :meth:`_step`, which resolves the per-step hooks from the
    wrapped module unless explicit overrides are given.

    Args:
        data: The batch (or iterator of batches) for this step.
        data_step: Optional override for the module's data step.
        forward_step: Optional override for the module's forward step.
        loss_reduction: Optional override for the module's loss reduction.
        seq_length: Sequence length for the model, if fixed.
        micro_batch_size: Size of each micro batch.
        num_microbatches: Number of microbatches per step.
        step_i: Optional step index forwarded to the Megatron step. Made an
            explicit parameter for consistency with ``validation_step`` /
            ``test_step`` / ``predict_step``; previously it was only
            reachable through ``**kwargs``.
        **kwargs: Forwarded verbatim to :meth:`_step`.

    Returns:
        The reduced loss (or metric mapping) from :meth:`forward`.
    """
    return self._step(
        "training",
        data,
        data_step=data_step,
        forward_step=forward_step,
        loss_reduction=loss_reduction,
        seq_length=seq_length,
        micro_batch_size=micro_batch_size,
        num_microbatches=num_microbatches,
        step_i=step_i,
        forward_only=False,
        **kwargs,
    )
def validation_step(
    self,
    data: DataT,
    data_step: Optional[Callable[[Iterator[DataT]], DataT]] = None,
    forward_step: Optional[Callable[[ModelT, DataT], Tensor]] = None,
    loss_reduction: Optional["MegatronLossReduction[DataT, Any]"] = None,
    seq_length: Optional[int] = None,
    micro_batch_size: Optional[int] = None,
    num_microbatches: Optional[int] = None,
    step_i: Optional[int] = None,
    **kwargs,
) -> STEP_OUTPUT:
    """Run one forward-only validation step by delegating to :meth:`_step`.

    All hook arguments default to the module-level hooks resolved there.
    """
    overrides = dict(
        data_step=data_step,
        forward_step=forward_step,
        loss_reduction=loss_reduction,
        seq_length=seq_length,
        micro_batch_size=micro_batch_size,
        num_microbatches=num_microbatches,
        step_i=step_i,
    )
    return self._step("validation", data, forward_only=True, **overrides, **kwargs)
def test_step(
    self,
    data: DataT,
    data_step: Optional[Callable[[Iterator[DataT]], DataT]] = None,
    forward_step: Optional[Callable[[ModelT, DataT], Tensor]] = None,
    loss_reduction: Optional["MegatronLossReduction[DataT, Any]"] = None,
    seq_length: Optional[int] = None,
    micro_batch_size: Optional[int] = None,
    num_microbatches: Optional[int] = None,
    step_i: Optional[int] = None,
    **kwargs,
) -> STEP_OUTPUT:
    """Run one forward-only test step by delegating to :meth:`_step`.

    All hook arguments default to the module-level hooks resolved there.
    """
    overrides = dict(
        data_step=data_step,
        forward_step=forward_step,
        loss_reduction=loss_reduction,
        seq_length=seq_length,
        micro_batch_size=micro_batch_size,
        num_microbatches=num_microbatches,
        step_i=step_i,
    )
    return self._step("test", data, forward_only=True, **overrides, **kwargs)
def predict_step(
    self,
    data: DataT,
    data_step: Optional[Callable[[Iterator[DataT]], DataT]] = None,
    forward_step: Optional[Callable[[ModelT, DataT], Tensor]] = None,
    loss_reduction: Optional["MegatronLossReduction[DataT, Any]"] = None,
    seq_length: Optional[int] = None,
    micro_batch_size: Optional[int] = None,
    num_microbatches: Optional[int] = None,
    step_i: Optional[int] = None,
    **kwargs,
) -> STEP_OUTPUT:
    """Run one forward-only prediction step by delegating to :meth:`_step`.

    All hook arguments default to the module-level hooks resolved there.
    """
    overrides = dict(
        data_step=data_step,
        forward_step=forward_step,
        loss_reduction=loss_reduction,
        seq_length=seq_length,
        micro_batch_size=micro_batch_size,
        num_microbatches=num_microbatches,
        step_i=step_i,
    )
    return self._step("predict", data, forward_only=True, **overrides, **kwargs)
def _step(
    self,
    step_type: str,
    data: DataT,
    data_step: Optional[Callable[[Iterator[DataT]], DataT]] = None,
    forward_step: Optional[Callable[[ModelT, DataT], Tensor]] = None,
    loss_reduction: Optional["MegatronLossReduction[DataT, Any]"] = None,
    seq_length: Optional[int] = None,
    micro_batch_size: Optional[int] = None,
    num_microbatches: Optional[int] = None,
    forward_only: bool = True,
    step_i: Optional[int] = None,
    **kwargs,
) -> STEP_OUTPUT:
    """Shared implementation behind the public ``*_step`` methods.

    Resolves the data / forward / loss-reduction hooks for ``step_type`` from
    the wrapped LightningModule (unless explicit overrides were passed) and
    then runs :meth:`forward`.

    Raises:
        AttributeError: If the wrapped module lacks a ``<step_type>_step`` method.
    """
    if not hasattr(self.module, f"{step_type}_step"):
        raise AttributeError(f"self.module must have a `{step_type}_step` method")

    resolved_data_step = data_step or _ModuleStepFunction.from_data_step(self.module, step_type)
    resolved_forward_step = forward_step or _ModuleStepFunction.from_forward_step(self.module, step_type)
    resolved_loss_reduction = loss_reduction or _ModuleStepFunction.from_loss_reduction(self.module, step_type)

    return self.forward(
        data=data,
        data_step=resolved_data_step,
        forward_step=resolved_forward_step,
        loss_reduction=resolved_loss_reduction,
        seq_length=seq_length,
        micro_batch_size=micro_batch_size,
        num_microbatches=num_microbatches,
        forward_only=forward_only,
        step_i=step_i,
        **kwargs,
    )
def wrapped_forward_step(
    self, forward_step, loss_reduction, data_step, context
) -> Callable[[nn.Module, DataT], Tuple[torch.Tensor, "MegatronCallbackProtocol"]]:
    """The method wraps the forward step function and returns a callable.

    The output is a forward_step function in the form of:
    https://github.com/NVIDIA/Megatron-LM/blob/main/pretrain_gpt.py#L129

    Args:
        forward_step (Callable): The forward step function to be wrapped.
        loss_reduction (Callable): The loss reduction function.
        context (Dict): Mutable dict shared with :meth:`forward`; its "step"
            entry is read on every microbatch.
        data_step (Callable): The data step function.

    Returns
    -------
        Callable: The wrapped forward step function.
    """
    from megatron.core import parallel_state

    @functools.wraps(forward_step)
    def wrapped_forward_step_func(dataloader_iter, model):
        # Each of data_step / loss_reduction / forward_step may be a deferred
        # _ModuleStepFunction placeholder that must be resolved against the
        # live model chunk before use.
        if isinstance(data_step, _ModuleStepFunction):
            _data_step = data_step(model)
        else:
            _data_step = data_step
        batch = _data_step(dataloader_iter)
        step = context["step"]
        if isinstance(loss_reduction, _ModuleStepFunction):
            forward_callback = loss_reduction(model)
        else:
            forward_callback = loss_reduction
        if isinstance(forward_step, _ModuleStepFunction):
            _forward_step = forward_step(model)
        else:
            _forward_step = forward_step
        self.callbacks.event(
            "on_megatron_microbatch_start",
            step=step,
            batch=batch,
            forward_callback=forward_callback,
        )
        # Precision conversion only happens at the pipeline edges: inputs on
        # the first stage, outputs on the last stage (virtual stages included).
        if self.precision_plugin and parallel_state.is_pipeline_first_stage(
            ignore_virtual=False, vp_stage=getattr(model.module, 'vp_stage', None)
        ):
            batch = self.precision_plugin.convert_input(batch)
        output_tensor = _forward_step(model, batch)
        # callback
        self._setup_module(
            forward_callback,
            batch=batch,
            model=self,
            forward_module=model,
            tensor=output_tensor,
        )
        if self.precision_plugin and parallel_state.is_pipeline_last_stage(
            ignore_virtual=False, vp_stage=getattr(model.module, 'vp_stage', None)
        ):
            output_tensor = self.precision_plugin.convert_output(output_tensor)
        self.callbacks.event(
            "on_megatron_microbatch_end",
            step=step,
            batch=batch,
            output=output_tensor,
            forward_callback=forward_callback,
        )
        # Returned as the (output_tensor, loss_function) pair Megatron's
        # forward-step contract expects (see the pretrain_gpt link above).
        return output_tensor, forward_callback

    return wrapped_forward_step_func
def init_model_parallel(self):
    """Prepare each chunk for model-parallel execution.

    Moves chunks to the current CUDA device (unless on CPU or Megatron FSDP
    will shard them), sets tensor-parallel attributes on parameters, surfaces
    ``set_input_tensor``, logs parameter counts on DP rank 0, applies
    ``convert_module_fn`` if given, and finally wraps chunks via
    :meth:`init_ddp` (skipped when the trainer is in the TESTING state).
    """
    from megatron.core import parallel_state
    from megatron.core.tensor_parallel.layers import set_defaults_if_not_set_tensor_model_parallel_attributes

    for model_module in self:
        if not self._cpu and ((not HAVE_MEGATRON_FSDP and not HAVE_CUSTOM_FSDP) or self.fsdp != "megatron"):
            # If Megatron custom FSDP is enabled, we don't need to move the model to GPU here to avoid GPU OOM.
            model_module.cuda(torch.cuda.current_device())
        for param in model_module.parameters():
            set_defaults_if_not_set_tensor_model_parallel_attributes(param)
        if hasattr(model_module, "configure_model"):
            # Lazily-configured modules keep the real model under `.module`;
            # expose its set_input_tensor for the pipeline schedule.
            if not hasattr(model_module, "set_input_tensor"):
                if hasattr(model_module.module, "set_input_tensor"):
                    model_module.set_input_tensor = model_module.module.set_input_tensor
                else:
                    # TODO: What to do here?
                    pass

    # Print number of parameters.
    if parallel_state.model_parallel_is_initialized() and parallel_state.get_data_parallel_rank() == 0:
        from nemo.utils import logging

        num_params = _calc_number_of_params(list(self))
        num_trainable_params = _calc_number_of_trainable_params(list(self))
        msg = (
            f" > number of parameters on (tensor, pipeline) model parallel rank "
            f"({parallel_state.get_tensor_model_parallel_rank()} ,"
            f"{parallel_state.get_pipeline_model_parallel_rank()}): "
            f"{num_params}"
        )
        logging.info(msg)
        if num_params != num_trainable_params:
            logging.info(
                f" > number of trainable parameters: {num_trainable_params} "
                f"({num_trainable_params / num_params:.2%} of total)"
            )
    if self.convert_module_fn:
        self.apply_convert_module_fn()

    # Skip init_ddp for inference i.e testing as it can lead to OOM.
    try:
        if not self.trainer.state.fn == TrainerFn.TESTING:
            # DDP initialization is required to be on side-stream to for full iteration CUDA graph.
            with torch.cuda.stream(torch.cuda.Stream()):
                self.init_ddp()
    except RuntimeError as e:
        # Don't fail if trainer is not attached, re-raise any other RuntimeError
        if "is not attached to a `Trainer`" not in str(e):
            raise e
def apply_convert_module_fn(self):
    """Replace every pipeline chunk with the result of ``convert_module_fn`` applied to it."""
    n_chunks = len(self)
    for idx in range(n_chunks):
        self[idx] = self.convert_module_fn(self[idx])
def init_ddp(self):
    """Wrap each model chunk in Megatron DDP or FSDP according to ``ddp_config`` / ``fsdp``.

    No-op unless ``ddp_config`` is a ``DistributedDataParallelConfig``. Also
    patches each chunk class's ``__getattr__`` with :func:`getattr_proxy`
    (undone by :meth:`teardown_ddp`) and installs the no-sync / grad-sync
    hooks produced by :func:`extract_ddp_funcs` on every chunk config.
    """
    if not isinstance(self.ddp_config, DistributedDataParallelConfig):
        return

    from megatron.core import parallel_state
    from megatron.core.transformer.module import Float16Module
    from nemo.utils.model_utils import unwrap_model

    for model_chunk_idx, model_chunk in enumerate(self):
        module = model_chunk.module

        # Mcore DistributedDataParallel has to be called with grad. Normally this call is redundant, but for
        # PEFT with num_sanity_val_steps > 0 this is necessary.
        init_ddp_context = nullcontext if all(x.requires_grad for x in module.parameters()) else torch.enable_grad

        # Turn off bucketing for model_chunk 2 onwards, since communication for these
        # model chunks is overlapped with compute anyway, or if using VP and overlapping
        # data parallel param gather with optimizer
        overlap_param_gather_with_optimizer_step = False
        if hasattr(self, "optim") and isinstance(self.optim.config, OptimizerConfig):
            overlap_param_gather_with_optimizer_step = self.optim.config.overlap_param_gather_with_optimizer_step
        disable_bucketing = (model_chunk_idx > 0) or overlap_param_gather_with_optimizer_step

        with init_ddp_context():
            # Avoid rewrapping the module if it's already wrapped with FSDP
            unwrapped_module = unwrap_model(module, Float16Module)
            if (
                (HAVE_MEGATRON_FSDP or HAVE_CUSTOM_FSDP)
                and self.fsdp == "megatron"
                and not isinstance(unwrapped_module, FullyShardedDataParallel)
            ):
                from nemo.utils import logging

                # Force the config flags FSDP requires, warning so users know
                # their settings were overridden.
                if not getattr(module.config, "use_megatron_fsdp", False):
                    setattr(module.config, "use_megatron_fsdp", True)
                    logging.warning("Setting module.config.use_megatron_fsdp to True for MCore FSDP.")

                if not getattr(module.config, "use_custom_fsdp", False):
                    setattr(module.config, "use_custom_fsdp", True)
                    logging.warning("Setting module.config.use_custom_fsdp to True for MCore FSDP.")

                if getattr(module.config, "gradient_accumulation_fusion", True):
                    setattr(module.config, "gradient_accumulation_fusion", False)
                    logging.warning("Setting module.config.gradient_accumulation_fusion to False for MCore FSDP.")

                if HAVE_MEGATRON_FSDP:
                    assert module.config.use_megatron_fsdp, "MCore FSDP is not enabled in module.config."
                    assert self.ddp_config.use_megatron_fsdp, "MCore FSDP is not enabled in ddp_config."
                elif HAVE_CUSTOM_FSDP:
                    assert module.config.use_custom_fsdp, "MCore FSDP is not enabled in module.config."
                    assert self.ddp_config.use_custom_fsdp, "MCore FSDP is not enabled in ddp_config."
                    logging.warning(
                        "Deprecation Notice: `use_custom_fsdp` will be deprecated in M-Core 0.14. "
                        "Please use `use_megatron_fsdp` instead."
                    )

                dist_module = FullyShardedDataParallel(
                    module.config,
                    self.ddp_config,
                    module,
                    disable_bucketing=disable_bucketing,
                )
                if HAVE_MEGATRON_FSDP:
                    # Adapt the Megatron-FSDP wrapper to the McoreDDP-like
                    # interface the rest of this class relies on.
                    dist_module.buffers = [dist_module.param_and_grad_buffer]
                    dist_module.config = module.config
                    dist_module.sharded_state_dict = lambda *args, **kwargs: dist_module.state_dict()
            elif not isinstance(unwrapped_module, DDP):
                dist_module = DDP(
                    module.config,
                    self.ddp_config,
                    module,
                    data_parallel_group=parallel_state.get_data_parallel_group(with_context_parallel=True),
                    expert_data_parallel_group=parallel_state.get_data_modulo_expert_parallel_group(),
                    disable_bucketing=disable_bucketing,
                )
            else:
                # Already DDP-wrapped: reuse as-is.
                dist_module = unwrapped_module

            model_chunk.module = dist_module
            model_chunk.buffers = (
                dist_module.buffers
            )  # We need to do this explicitly since this is a attr pytorch uses

        # save a reference to the original getattr function
        # so we can restore the class' getattr during teardown
        original_getattr = types.FunctionType(
            model_chunk.__getattr__.__code__,
            model_chunk.__getattr__.__globals__,
            model_chunk.__getattr__.__name__,
            model_chunk.__getattr__.__defaults__,
            model_chunk.__getattr__.__closure__,
        )
        model_chunk.original_getattr = original_getattr
        model_chunk.original_getattr.__dict__.update(model_chunk.__getattr__.__dict__)
        model_chunk.__class__.__getattr__ = getattr_proxy  # type: ignore

    # param_sync_func is set in nemo.lightning.pytorch.optim.megatron
    no_sync_func, grad_sync_func = extract_ddp_funcs(self.ddp_config, self)
    for module in self:
        module.config.no_sync_func = no_sync_func
        module.config.grad_sync_func = grad_sync_func
def teardown_ddp(self):
    """Undo the class-level ``__getattr__`` patch installed by :meth:`init_ddp`.

    Restores each chunk class's ``__getattr__`` from the snapshot saved in
    ``original_getattr``; chunks that were never patched are left untouched.
    """
    for chunk in self:
        if hasattr(chunk, "original_getattr"):
            chunk.__class__.__getattr__ = chunk.original_getattr  # type: ignore
def _setup_module(self, function, **kwargs) -> None:
    """Call ``function.setup`` (if it exists) with only the kwargs its signature accepts."""
    if not hasattr(function, "setup"):
        return
    accepted = inspect.getfullargspec(function.setup).args
    function.setup(**{key: val for key, val in kwargs.items() if key in accepted})
def _call_module(self, function, *args, **kwargs) -> torch.Tensor:
    """Run ``function`` after its optional ``setup``, forwarding only the kwargs it accepts.

    Positional ``args`` are passed through unchanged; keyword arguments are
    filtered against ``function``'s signature.
    """
    self._setup_module(function, **kwargs)
    accepted = inspect.getfullargspec(function).args
    filtered = {key: val for key, val in kwargs.items() if key in accepted}
    return function(*args, **filtered)
def sharded_state_dict(self, prefix: str = "", metadata: Optional[dict] = None) -> Dict[str, Any]:
    """Build the sharded state dict used by dist_checkpoint to save/load sharded tensors.

    With virtual pipeline parallelism each chunk is stored under a
    ``model_<i>`` key; otherwise the single chunk's entries are merged into
    one flat dict. When no ``metadata`` is passed, the trainer strategy's
    checkpoint-save metadata is used. The per-tensor sharding itself is
    defined by the wrapped mcore model (e.g. GPTModel).
    """
    from nemo.utils import logging

    if metadata is None:
        metadata = self.trainer.strategy.sharded_state_dict_metadata
        logging.debug(
            f'No sharded_state_dict metadata passed for the model,'
            f' using metadata for checkpoint save: {metadata}'
        )
    else:
        logging.debug(f'Using passed sharded_state_dict metadata in the model: {metadata}')

    result: Dict[str, Any] = {}
    for index, module in enumerate(self):
        chunk_state = self._module_sharded_state_dict(module, metadata=metadata)
        if self.vp_size is not None:
            # One namespaced entry per virtual-pipeline chunk.
            result[f"model_{index}"] = chunk_state
        else:
            result.update(chunk_state)
    return result
def _module_sharded_state_dict(self, module, *args, **kwargs) -> Dict[str, Any]:
    """Locate ``sharded_state_dict`` on ``module`` or, for lazily-configured
    wrappers, recurse into ``module.module`` while extending the key prefix.

    Raises:
        ValueError: If no ``sharded_state_dict`` is found anywhere in the chain.
    """
    if hasattr(module, "sharded_state_dict"):
        return module.sharded_state_dict(*args, **kwargs)
    if hasattr(module, "configure_model"):
        # Descend one wrapper level, accumulating "module." onto the prefix.
        inner_prefix = kwargs.pop("prefix", "") + "module."
        return self._module_sharded_state_dict(module.module, *args, prefix=inner_prefix, **kwargs)
    raise ValueError("Could not find sharded state dict")
def enable_forward_pre_hook(self):
    """Enable the Mcore DDP/FSDP forward pre-hook on every wrapped chunk."""
    for wrapper in self:
        chunk = wrapper.module
        # Keep short-circuit `or`: a tuple isinstance would NameError when
        # FullyShardedDataParallel was never importable.
        assert isinstance(chunk, DDP) or isinstance(chunk, FullyShardedDataParallel)
        chunk.enable_forward_pre_hook()
def disable_forward_pre_hook(self):
    """Disable the Mcore DDP/FSDP forward pre-hook on every wrapped chunk."""
    for wrapper in self:
        chunk = wrapper.module
        # Keep short-circuit `or`: a tuple isinstance would NameError when
        # FullyShardedDataParallel was never importable.
        assert isinstance(chunk, DDP) or isinstance(chunk, FullyShardedDataParallel)
        chunk.disable_forward_pre_hook()
def force_param_sync(self):
    """Trigger an immediate parameter sync (``start_param_sync(force_sync=True)``) on each chunk."""
    for wrapper in self:
        chunk = wrapper.module
        # Keep short-circuit `or`: a tuple isinstance would NameError when
        # FullyShardedDataParallel was never importable.
        assert isinstance(chunk, DDP) or isinstance(chunk, FullyShardedDataParallel)
        chunk.start_param_sync(force_sync=True)
@property
def pipeline(self) -> Union[ModelT, List[ModelT]]:
    """The wrapped chunk(s): the single module when there is exactly one, else a list of all."""
    return self[0] if len(self) == 1 else list(self)
@property
def module(self) -> ModelT:
    """The first model chunk — the primary module used for attribute fallback and hook lookup."""
    return self[0]
@override
def __getattr__(self, item: Any) -> Any:
    """Attribute fallback: the container first, then the first wrapped module.

    ``nn.ModuleList.__getattr__`` only resolves registered submodules,
    parameters, and buffers; anything else is delegated to the first chunk so
    this container stays transparent to its module's attributes.

    Raises:
        AttributeError: If neither this object nor the first module has
            ``item``, or if the list contains no modules at all.
    """
    try:
        # First, try to get the attribute from the superclass (nn.ModuleList)
        return super().__getattr__(item)
    except AttributeError:
        # If not found in superclass, check if we have any modules
        if len(self) == 0:
            raise AttributeError(
                f"'{self.__class__.__name__}' object has no attribute '{item}' and contains no modules"
            )
        # Try to get it from the first module
        try:
            return getattr(self._modules[self._get_abs_string_index(0)], item)
        except AttributeError:
            raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{item}'")
class _ModuleStepFunction:
    """
    This class acts as a bridge between Megatron core's lower-level functional API and PTL's object-oriented API,
    making it possible to use PTL-compatible functions in Megatron core.

    An instance records only the *name* of a hook on the LightningModule;
    calling the instance with a concrete module resolves that name into a
    callable (or, for property-style hooks, the attribute's value).
    """

    def __init__(self, name: str, is_property: bool = False, includes_self: bool = False):
        # name: attribute to resolve on the module.
        # is_property: resolve to the attribute's value rather than calling it.
        # includes_self: the resolved bound method already carries the module,
        #   so the returned wrapper must swallow the extra `model` argument
        #   that Megatron's forward-step contract passes.
        self.name = name
        self.is_property = is_property
        self.includes_self = includes_self

    @classmethod
    def from_data_step(cls, module: "pl.LightningModule", step_type: str) -> Optional["_ModuleStepFunction"]:
        """Pick ``<step_type>_data_step`` or the generic ``data_step``, if defined."""
        for fn_name in [f"{step_type}_data_step", "data_step"]:
            if hasattr(module, fn_name):
                # Use `cls` (not the hard-coded class name) so subclasses work.
                return cls(fn_name)
        return None

    @classmethod
    def from_forward_step(cls, module: "pl.LightningModule", step_type: str) -> Optional["_ModuleStepFunction"]:
        """On the last pipeline stage use ``<step_type>_step``; otherwise a forward-step hook."""
        from megatron.core import parallel_state

        if parallel_state.is_pipeline_last_stage(ignore_virtual=False, vp_stage=getattr(module, 'vp_stage', None)):
            if not hasattr(module, f"{step_type}_step"):
                raise ValueError(f"LightningModule does not have {step_type}_step method")
            return cls(f"{step_type}_step", includes_self=True)
        for fn_name in [f"{step_type}_forward_step", "forward_step"]:
            if hasattr(module, fn_name):
                return cls(fn_name, includes_self=True)
        return None

    @classmethod
    def from_loss_reduction(cls, module: "pl.LightningModule", step_type: str) -> Optional["_ModuleStepFunction"]:
        """Pick ``<step_type>_loss_reduction`` or the generic ``loss_reduction``, if defined."""
        for fn_name in [f"{step_type}_loss_reduction", "loss_reduction"]:
            if hasattr(module, fn_name):
                return cls(fn_name, is_property=True)
        return None

    def __call__(self, module: nn.Module):
        """Resolve the recorded hook name against ``module``."""
        attr = getattr(module, self.name)

        if self.is_property:
            # A real @property has already produced the value; a plain
            # zero-argument method still needs to be called.
            if isinstance(getattr(type(module), self.name), property):
                return attr
            return attr()

        if self.includes_self:
            # Megatron calls forward_step(model, batch); `attr` is a bound
            # method that already carries the module, so drop the redundant
            # first argument. (Renamed from `self`, which shadowed the
            # instance name misleadingly.)
            def wrapped(_model, *args):
                return attr(*args)

            return wrapped

        return attr
def getattr_proxy(self, item: Any) -> Any:
    """Attribute-lookup fallback shared by Megatron wrapper classes (e.g. ``DDP``).

    Tries the wrapper's normal ``nn.Module`` attribute lookup first and, on
    failure, forwards the lookup to the wrapped ``self.module``.

    Raises:
        AttributeError: if neither the wrapper nor the wrapped module has
            ``item``.
    """
    try:
        # super(self.__class__, self) re-enters the standard nn.Module
        # __getattr__ for the concrete wrapper class.
        return super(self.__class__, self).__getattr__(item)
    except AttributeError as e:
        if item == 'module':  # hacky WAR: 'module' itself is missing, so the fallback below would re-enter this proxy via self.module and recurse; re-raise instead (may cause misleading error messages)
            raise e
        try:
            return getattr(self.module, item)
        except AttributeError:
            raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{item}'")
class DDP(McoreDDP):
    """Thin wrapper around Megatron core's ``DistributedDataParallel``.

    Provides forward compatibility with newer mcore versions by filtering
    constructor kwargs against the current ``McoreDDP.__init__`` signature,
    and delegates ``state_dict``/attribute lookup to the wrapped module.
    """

    def __init__(
        self,
        config: TransformerConfig,
        ddp_config: DistributedDataParallelConfig,
        module: torch.nn.Module,
        disable_bucketing: bool = False,
        **kwargs,
    ):
        init_parameters = inspect.signature(McoreDDP.__init__).parameters
        # Updates to the McoreDDP class have removed some parameters, so we need to
        # filter out any kwargs that are not part of the updated signature, if a new
        # version of mcore is being used.
        filtered_kwargs = {k: v for k, v in kwargs.items() if k in init_parameters}
        super().__init__(
            config=config,
            ddp_config=ddp_config,
            module=module,
            disable_bucketing=disable_bucketing,
            **filtered_kwargs,
        )

    def state_dict(self, prefix='', keep_vars=False, **kwargs):
        """Return the wrapped module's state dict, bypassing the DDP wrapper.

        Bug fix: the result was previously dropped (no ``return``), so this
        method always yielded ``None`` and broke checkpointing through the
        wrapper.
        """
        return self.module.state_dict(prefix=prefix, keep_vars=keep_vars, **kwargs)

    def __getattr__(self, item: Any) -> Any:
        # Fall back to the wrapped module for attributes the wrapper lacks.
        return getattr_proxy(self, item)
class CallbackConnector:
"""
A connector for managing and invoking callbacks.
The CallbackConnector class in the MegatronParallel module
is used to manage and invoke callbacks during the execution of the model.
Callbacks are functions that are called at specific stages of the model
execution, allowing you to hook into the model's operation for logging, debugging, or other purposes.
The CallbackMethods class defines the names of the callback methods that can be used.
These methods are:
- `on_megatron_step_start`
- `on_megatron_microbatch_start`
- `on_megatron_microbatch_callback`
- `on_megatron_microbatch_end`
- `on_megatron_reduce_microbatches_start`
- `on_megatron_reduce_microbatches_end`
- `on_megatron_log_step_end`
- `on_megatron_step_end`
Each of these methods corresponds to a specific stage in the model's operation.
You can define these methods in your callback functions to perform specific actions at these stages.
There is no need for the class to be a subclass of a specific parent class.
As long as the class contains the methods outlined above, it can be used as a callback.
"""
def __init__(self, callbacks=None) -> None:
    """Create the connector, registering any provided callbacks immediately."""
    # Maps event-method name -> list of callback objects handling that event.
    self.callbacks = defaultdict(list)
    if not callbacks:
        return
    self.add(*callbacks)
def add(self, *callbacks) -> "CallbackConnector":
    """
    Adds callback functions to the connector.

    A plain object is scanned for the Megatron-specific ``on_*`` hook methods
    it implements; another ``CallbackConnector`` has its registrations merged
    in wholesale.

    Parameters
    ----------
    *callbacks : CallbackT
        One or more callback functions to add.

    Returns
    -------
    CallbackConnector
        The CallbackConnector instance to allow method chaining.
    """
    _pl_callback = None
    try:
        import lightning.pytorch as pl

        _pl_callback = pl.Callback
    except ImportError:
        pass

    # Megatron-specific hooks: every "on*" method of CallbackMethods that is
    # not already part of PTL's Callback interface (when PTL is importable).
    megatron_methods = {m for m in dir(CallbackMethods) if m.startswith("on") and not hasattr(_pl_callback, m)}

    for cb in callbacks:
        if isinstance(cb, CallbackConnector):
            # Handle CallbackConnector instance: merge its callbacks.
            for event_name, handlers in cb.callbacks.items():
                self.callbacks[event_name].extend(handlers)
            continue
        # Register the object under every megatron hook it implements.
        for hook in megatron_methods:
            if callable(getattr(cb, hook, None)):
                self.callbacks[hook].append(cb)

    return self
def event(self, name: str, *args, **kwargs) -> None:
"""
Triggers an event and calls all associated callbacks.