
Commit 3e0a4d8

marcromeyn authored and akoumpa committed
Adding support for LightningDataModule inside Fabric-API (#10879)
* Make FabricMegatronMixedPrecision match MegatronMixedPrecision
  Signed-off-by: Marc Romeijn <mromeijn@nvidia.com>
* Apply isort and black reformatting
  Signed-off-by: marcromeyn <marcromeyn@users.noreply.github.com>
* Supporting DataModule in fabric-API
  Signed-off-by: Marc Romeijn <mromeijn@nvidia.com>
* Adding support for LightningDataModule inside Fabric-API
  Signed-off-by: Marc Romeijn <mromeijn@nvidia.com>
* Apply isort and black reformatting
  Signed-off-by: marcromeyn <marcromeyn@users.noreply.github.com>
* Remove import in mock.py
  Signed-off-by: Marc Romeijn <mromeijn@nvidia.com>

---------

Signed-off-by: Marc Romeijn <mromeijn@nvidia.com>
Signed-off-by: marcromeyn <marcromeyn@users.noreply.github.com>
Co-authored-by: marcromeyn <marcromeyn@users.noreply.github.com>
Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
1 parent b22664b commit 3e0a4d8

File tree

7 files changed: +91 -33 lines changed

nemo/collections/llm/t5/model/t5.py

Lines changed: 2 additions & 2 deletions
@@ -11,8 +11,6 @@
 from torch import nn

 from nemo.collections.llm import fn
-from nemo.collections.nlp.modules.common.megatron.token_level_encoder_decoder import AttnMaskType
-from nemo.collections.nlp.modules.common.megatron.utils import build_attention_mask_3d
 from nemo.lightning import get_vocab_size, io
 from nemo.lightning.megatron_parallel import MaskedTokenLossReduction
 from nemo.lightning.pytorch.optim import MegatronOptimizerModule, OptimizerModule
@@ -31,6 +29,8 @@

 def t5_data_step(dataloader_iter) -> Dict[str, torch.Tensor]:
     from megatron.core import parallel_state
+    from nemo.collections.nlp.modules.common.megatron.token_level_encoder_decoder import AttnMaskType
+    from nemo.collections.nlp.modules.common.megatron.utils import build_attention_mask_3d

     batch = next(dataloader_iter)
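
The change above moves the AttnMaskType and build_attention_mask_3d imports from module scope into t5_data_step, so they are only resolved when the data step actually runs. A minimal sketch of that deferred-import pattern (the function and dependency below are illustrative, not part of this commit):

def transform_batch(batch):
    # Imported lazily: the cost (and the dependency) is only paid when the
    # function is called, not when the defining module is imported.
    import numpy as np

    return {key: np.asarray(value) for key, value in batch.items()}


if __name__ == "__main__":
    print(transform_batch({"tokens": [1, 2, 3]}))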

nemo/core/optim/mcore_optim.py

Lines changed: 1 addition & 1 deletion
@@ -61,7 +61,7 @@ def sharded_state_dict(
             model_sharded_state_dict, is_loading=is_loading, sharding_type=sharding_type
         )

-    def step(self, closure):
+    def step(self, closure=None):
         """Clip gradients (if needed) and step the base optimizer.
         Always return successful since there is no overflow."""
         # Apply closure
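
Defaulting closure to None matters because Fabric-style and plain PyTorch loops call optimizer.step() with no arguments, whereas Lightning's automatic optimization passes a closure. A hedged sketch of the calling pattern (the wrapper class below is illustrative, not the actual NeMo optimizer wrapper):

import torch


class OptimizerWrapper:
    """Illustrative wrapper whose step() accepts an optional closure."""

    def __init__(self, optim: torch.optim.Optimizer) -> None:
        self.optim = optim

    def step(self, closure=None):
        # Run the closure first if one was given (Lightning-style call),
        # otherwise just step (Fabric / raw PyTorch-style call).
        if closure is not None:
            closure()
        self.optim.step()


if __name__ == "__main__":
    param = torch.nn.Parameter(torch.zeros(2))
    opt = OptimizerWrapper(torch.optim.SGD([param], lr=0.1))
    param.grad = torch.ones(2)
    opt.step()  # works without a closure, as in a Fabric training loop
    print(param)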

nemo/lightning/fabric/fabric.py

Lines changed: 10 additions & 1 deletion
@@ -4,8 +4,9 @@

 import fiddle as fdl
 import lightning_fabric as lb
+import pytorch_lightning as pl
 from torch import nn
-from torch.optim import Optimizer
+
 from typing_extensions import Self, override

 from nemo.lightning.io.mixin import IOMixin, serialization, track_io
@@ -130,6 +131,14 @@ def setup_module(self, module: nn.Module, move_to_device: bool = True, _reapply_

         return out

+    def setup_datamodule(self, datamodule: pl.LightningDataModule, stage: str = "") -> pl.LightningDataModule:
+        datamodule.setup(stage)
+
+        if hasattr(self.strategy, "process_datamodule"):
+            datamodule = self.strategy.process_datamodule(datamodule)
+
+        return datamodule
+

 @runtime_checkable
 class DistributedModel(Protocol[ModelT]):
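
The new setup_datamodule hook calls datamodule.setup(stage) and then gives the strategy a chance to post-process the datamodule if it defines process_datamodule. A hedged usage sketch; the datamodule class and the construction of the Fabric object are assumptions for illustration, not code from this commit:

import pytorch_lightning as pl
from torch.utils.data import DataLoader


class ToyDataModule(pl.LightningDataModule):
    """Stand-in for a datamodule that may also carry a `data_sampler`."""

    def setup(self, stage: str = "") -> None:
        self.train_ds = list(range(16))

    def train_dataloader(self) -> DataLoader:
        return DataLoader(self.train_ds, batch_size=4)


# Hypothetical wiring with the Fabric API touched by this commit:
# fabric = <a nemo.lightning Fabric configured with FabricMegatronStrategy>
# datamodule = fabric.setup_datamodule(ToyDataModule())
# train_loader = fabric.setup_dataloaders(datamodule.train_dataloader())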

nemo/lightning/fabric/plugins.py

Lines changed: 3 additions & 2 deletions
@@ -112,7 +112,6 @@ def convert_config(self, config: ConfigT) -> ConfigT:
         """Convert the config to the precision type this plugin handles.

         This is optional and depends on the precision limitations during optimization.
-
         """
         return update_config_with_dtype_overrides(self.dtype_config, config)

@@ -122,6 +121,9 @@ def convert_module(self, module: nn.Module) -> nn.Module:
         This is optional and depends on the precision limitations during optimization.

         """
+        if not hasattr(module, "module"):
+            return module
+
         from megatron.core.transformer.module import Float16Module
         from megatron.core.utils import get_model_config

@@ -141,7 +143,6 @@ def convert_optimizer(self, optimizer: Optimizer) -> Optimizer:
         """Convert the optimizer parameters to the precision type this plugin handles.

         This is optional and depends on the precision limitations during optimization.
-
         """
         for optim_config in get_optim_config(optimizer):
             assert optim_config.bf16 == self.dtype_config.bf16, "BF16 model/optim config mismatch"
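
Both the plugin and the strategy rely on duck-typed hooks: convert_module now returns early for unwrapped modules (no .module attribute), and callers only invoke convert_config / convert_module when the precision plugin defines them. A small sketch of that hasattr-based dispatch, with illustrative class names rather than the real plugins:

from torch import nn


class NoOpPrecision:
    """Defines no conversion hooks; the caller leaves the module untouched."""


class HalfPrecision:
    """Defines convert_module, so the caller applies it."""

    def convert_module(self, module: nn.Module) -> nn.Module:
        return module.half()


def apply_precision(precision, module: nn.Module) -> nn.Module:
    # Same pattern as the plugin/strategy code above: call the hook only if
    # the plugin actually provides it.
    if hasattr(precision, "convert_module"):
        module = precision.convert_module(module)
    return module


if __name__ == "__main__":
    print(apply_precision(NoOpPrecision(), nn.Linear(4, 4)).weight.dtype)  # float32
    print(apply_precision(HalfPrecision(), nn.Linear(4, 4)).weight.dtype)  # float16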

nemo/lightning/fabric/strategies.py

Lines changed: 37 additions & 9 deletions
@@ -26,6 +26,7 @@
 from lightning_fabric.utilities.types import _PATH, _Stateful
 from megatron.core.distributed import DistributedDataParallelConfig
 from megatron.core.optimizer import OptimizerConfig
+from pytorch_lightning import LightningDataModule
 from pytorch_lightning.loops.fetchers import _DataFetcher
 from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO
 from pytorch_lightning.utilities.combined_loader import CombinedLoader
@@ -106,6 +107,7 @@ def __init__(
         if megatron_callbacks:
             self.megatron_callbacks.add(megatron_callbacks)
         self.output_data_idx = output_data_idx
+        self.data_sampler: Optional["DataSampler"] = data_sampler

         # used in NVIDIA NGC PyTorch containers
         _strategy_lib.enable_nvidia_optimizations()
@@ -141,13 +143,25 @@ def _setup_distributed(self) -> None:
         # _strategy_lib.initialize_data(self.cluster_environment.global_rank(), self.data_config)
         _strategy_lib.init_model_parallel()

+    def process_datamodule(self, datamodule: LightningDataModule) -> LightningDataModule:
+        datamodule.setup()
+
+        if not self.data_sampler and hasattr(datamodule, "data_sampler"):
+            self.data_sampler = datamodule.data_sampler
+
+        if self.data_sampler:
+            self.data_sampler.setup(self.cluster_environment.global_rank())
+
+        return datamodule
+
     @override
     def process_dataloader(self, dataloader: DataLoader) -> Iterator:
-        loader = _strategy_lib.process_dataloader(dataloader, self.data_config)
+        if self.data_sampler:
+            dataloader = self.data_sampler.transform_dataloader(dataloader)

         # Code taken from: https://github.com/Lightning-AI/pytorch-lightning/blob/6cbe9ceb560d798892bdae9186291acf9bf5d2e3/src/lightning/pytorch/loops/fit_loop.py#L258-L260
-        output = _MegatronDataLoaderIterDataFetcher(self.data_config, output_data_idx=self.output_data_idx)
-        output.setup(CombinedLoader(loader, "max_size_cycle"))
+        output = _MegatronDataLoaderIterDataFetcher(output_data_idx=self.output_data_idx)
+        output.setup(CombinedLoader(dataloader, "max_size_cycle"))
         iter(output)

         return output
@@ -160,6 +174,11 @@ def setup_megatron_optimizer(
         scale_lr_cond: Optional[Callable] = None,
         lr_mult: float = 1.0,
     ) -> Optimizer:
+        if hasattr(self.precision, "convert_config"):
+            optimizer_config = self.precision.convert_config(optimizer_config)
+
+        assert optimizer_config.lr is not None, "Learning rate must be set in optimizer config"
+
         return _strategy_lib.setup_megatron_optimizer(
             model,
             optimizer_config,
@@ -180,16 +199,23 @@ def setup_optimizer(self, optimizer: Optimizer) -> Optimizer:

     @override
     def setup_module(self, module: Module) -> MegatronParallel:
-        _strategy_lib.set_model_parallel_attributes(module, self.parallelism)
+        from megatron.core.utils import get_model_config

-        # Call configure_model if it's overridden (relevant for LightningModules with lazy initialization)
-        if hasattr(module, "configure_model"):
-            module.configure_model()
+        _strategy_lib.set_model_parallel_attributes(module, self.parallelism)

         convert_module_fn = None
         if hasattr(self.precision, "convert_module"):
             convert_module_fn = self.precision.convert_module

+        if hasattr(self.precision, "convert_config"):
+            self.precision.convert_config(get_model_config(module))
+            if self.ddp_config:
+                self.precision.convert_config(self.ddp_config)
+
+        # Call configure_model if it's overridden (relevant for LightningModules with lazy initialization)
+        if hasattr(module, "configure_model"):
+            module.configure_model()
+
         megatron_parallel = MegatronParallel(
             module,
             precision_plugin=self.precision,
@@ -202,6 +228,9 @@ def setup_module(self, module: Module) -> MegatronParallel:
         if self._init_model_parallel:
             megatron_parallel.init_model_parallel()

+        if self.data_sampler:
+            megatron_parallel.callbacks.add(self.data_sampler)
+
         if not self.ddp_config:
             from megatron.core import mpu

@@ -364,9 +393,8 @@ def parallelism(self):

 # TODO: Fix this
 class _MegatronDataLoaderIterDataFetcher(_DataFetcher):
-    def __init__(self, data_config, *args: Any, output_data_idx: bool = False, **kwargs: Any) -> None:
+    def __init__(self, *args: Any, output_data_idx: bool = False, **kwargs: Any) -> None:
         super().__init__(*args, **kwargs)
-        self.data_config = data_config
         self.output_data_idx = output_data_idx
         self._batch: Any = None
         self._batch_idx: int = 0
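
process_datamodule adopts the datamodule's data_sampler when the strategy was not given one and sets it up with the global rank; process_dataloader then routes every dataloader through that sampler before wrapping it in the Megatron data fetcher. A hedged sketch of a datamodule exposing such a data_sampler attribute (class names are illustrative; in NeMo this role is normally played by a MegatronDataSampler):

import pytorch_lightning as pl
from torch.utils.data import DataLoader


class ToySampler:
    """Mimics only the two calls the strategy makes on a data sampler."""

    def setup(self, global_rank: int) -> None:
        self.global_rank = global_rank

    def transform_dataloader(self, dataloader: DataLoader) -> DataLoader:
        # A real sampler would re-wrap the loader with a Megatron-aware
        # batch sampler; this sketch returns it unchanged.
        return dataloader


class ToyDataModule(pl.LightningDataModule):
    def __init__(self) -> None:
        super().__init__()
        self.data_sampler = ToySampler()  # picked up by process_datamodule

    def setup(self, stage: str = "") -> None:
        self.train_ds = list(range(32))

    def train_dataloader(self) -> DataLoader:
        return DataLoader(self.train_ds, batch_size=8)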

nemo/lightning/megatron_parallel.py

Lines changed: 17 additions & 1 deletion
@@ -222,6 +222,7 @@ def forward(
         seq_length: Optional[int] = None,
         micro_batch_size: Optional[int] = None,
         num_microbatches: Optional[int] = None,
+        step_i: Optional[int] = None,
         wrap_forward_step: bool = True,
     ) -> torch.Tensor:
         """The method performs the forward pass of the model.
@@ -269,6 +270,7 @@ def forward(
             micro_batch_size=micro_batch_size,
             num_microbatches=num_microbatches,
             seq_length=seq_length,
+            step_i=step_i,
         )
         _forward_context["step"] = step
         step = self.callbacks.transform_event("on_megatron_step_start", step)
@@ -334,6 +336,7 @@ def validation_step(
         seq_length: Optional[int] = None,
         micro_batch_size: Optional[int] = None,
         num_microbatches: Optional[int] = None,
+        step_i: Optional[int] = None,
         **kwargs,
     ) -> STEP_OUTPUT:
         return self._step(
@@ -345,6 +348,7 @@ def validation_step(
             seq_length=seq_length,
             micro_batch_size=micro_batch_size,
             num_microbatches=num_microbatches,
+            step_i=step_i,
             forward_only=True,
             **kwargs,
         )
@@ -358,6 +362,7 @@ def test_step(
         seq_length: Optional[int] = None,
         micro_batch_size: Optional[int] = None,
         num_microbatches: Optional[int] = None,
+        step_i: Optional[int] = None,
         **kwargs,
     ) -> STEP_OUTPUT:
         return self._step(
@@ -369,6 +374,7 @@
             seq_length=seq_length,
             micro_batch_size=micro_batch_size,
             num_microbatches=num_microbatches,
+            step_i=step_i,
             forward_only=True,
             **kwargs,
         )
@@ -382,6 +388,7 @@ def predict_step(
         seq_length: Optional[int] = None,
         micro_batch_size: Optional[int] = None,
         num_microbatches: Optional[int] = None,
+        step_i: Optional[int] = None,
         **kwargs,
     ) -> STEP_OUTPUT:
         return self._step(
@@ -393,6 +400,7 @@
             seq_length=seq_length,
             micro_batch_size=micro_batch_size,
             num_microbatches=num_microbatches,
+            step_i=step_i,
             forward_only=True,
             **kwargs,
         )
@@ -408,6 +416,7 @@ def _step(
         micro_batch_size: Optional[int] = None,
         num_microbatches: Optional[int] = None,
         forward_only: bool = True,
+        step_i: Optional[int] = None,
         **kwargs,
     ) -> STEP_OUTPUT:
         if not hasattr(self.module, f"{step_type}_step"):
@@ -426,6 +435,7 @@ def _step(
             micro_batch_size=micro_batch_size,
             num_microbatches=num_microbatches,
             forward_only=forward_only,
+            step_i=step_i,
             **kwargs,
         )

@@ -1043,6 +1053,7 @@ class MegatronStep(Generic[ModelT, DataT]):
     micro_batch_size: Optional[int] = None
     seq_length: Optional[int] = None
     num_microbatches: Optional[int] = None
+    step_i: Optional[int] = None

     @classmethod
     def infer(
@@ -1054,6 +1065,7 @@ def infer(
         micro_batch_size: Optional[int] = None,
         seq_length: Optional[int] = None,
         num_microbatches: Optional[int] = None,
+        step_i: Optional[int] = None,
     ) -> "MegatronStep[ModelT, DataT]":
         """
         Creates a MegatronStep instance, inferring missing parameters if possible.
@@ -1069,10 +1081,13 @@
             micro_batch_size (Optional[int]): Size of each micro-batch.
             seq_length (Optional[int]): Sequence length for the current step.
             num_microbatches (Optional[int]): Number of micro-batches in this step.
-
+            step_i (Optional[int]): Step index for the current step.
         Returns:
             MegatronStep[ModelT, DataT]: An instance of MegatronStep with inferred parameters.
         """
+        if step_i is None and pipeline.trainer:
+            step_i = pipeline.trainer.global_step
+
         return cls(
             pipeline=pipeline,
             data=data,
@@ -1081,6 +1096,7 @@
             micro_batch_size=micro_batch_size or cls.infer_micro_batch_size(data),
             seq_length=seq_length or cls.infer_seq_length(data),
             num_microbatches=num_microbatches or cls.infer_num_microbatches(data),
+            step_i=step_i,
         )

     def __call__(self) -> List[Any]:
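
step_i is threaded from the public step methods through _step and forward into MegatronStep, and MegatronStep.infer falls back to pipeline.trainer.global_step when it is not supplied. Callbacks can therefore read the step index from the step object instead of reaching for a trainer, which is what lets the trainer-less Fabric path work. A hedged sketch of a callback consuming step_i; only the step_i attribute is taken from this commit, the callback class is illustrative:

from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeStep:
    """Stand-in carrying just the step_i field added to MegatronStep."""

    step_i: Optional[int] = None


class StepLogger:
    def on_megatron_step_end(self, step: FakeStep) -> None:
        if step.step_i is None:
            return  # e.g. Fabric usage with no trainer and no explicit step index
        print(f"finished megatron step {step.step_i}")


if __name__ == "__main__":
    logger = StepLogger()
    logger.on_megatron_step_end(FakeStep(step_i=41))
    logger.on_megatron_step_end(FakeStep())  # silently skipped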

nemo/lightning/pytorch/plugins/data_sampler.py

Lines changed: 21 additions & 17 deletions
@@ -82,7 +82,7 @@ def compute_consumed_samples(self, steps_since_resume=0) -> int:
         from nemo.lightning.pytorch.strategies import MegatronStrategy
         from nemo.utils import AppState

-        if not isinstance(self.trainer.strategy, MegatronStrategy):
+        if not hasattr(self, "trainer") or not isinstance(self.trainer.strategy, MegatronStrategy):
             return 0

         app_state = AppState()
@@ -107,6 +107,9 @@ def on_megatron_step_start(self, step: MegatronStep) -> MegatronStep:
         )

     def on_megatron_microbatches_start(self, step: MegatronStep) -> None:
+        if not step.trainer:
+            return
+
         # do validation and save the checkpoint when gbs is changed
         if (
             self.rampup_batch_size is not None
@@ -128,23 +131,24 @@ def on_megatron_step_end(self, step: MegatronStep) -> None:

         self.prev_global_batch_size = self.current_global_batch_size

-        consumed_samples = self.compute_consumed_samples(trainer.global_step + 1 - self.init_global_step)
-        if self.output_log and self.trainer.training:
-            # You may need to turn off logging, for example when doing trainer.predict(model, data)
-            pl_module.log(
-                'consumed_samples',
-                consumed_samples,
-                prog_bar=True,
-                batch_size=1,
+        if step.step_i:
+            consumed_samples = self.compute_consumed_samples(step.step_i + 1 - self.init_global_step)
+            if self.output_log and trainer and getattr(trainer, "training", False):
+                # You may need to turn off logging, for example when doing trainer.predict(model, data)
+                pl_module.log(
+                    'consumed_samples',
+                    consumed_samples,
+                    prog_bar=True,
+                    batch_size=1,
+                )
+
+            self.prev_consumed_samples = consumed_samples
+
+            update_num_microbatches(
+                consumed_samples=consumed_samples,
+                consistency_check=False,
             )
-
-        self.prev_consumed_samples = consumed_samples
-
-        update_num_microbatches(
-            consumed_samples=consumed_samples,
-            consistency_check=False,
-        )
-        if self.output_log:
+        if self.output_log and trainer:
             # You may need to turn off logging, for example when doing trainer.predict(model, data)
             pl_module.log(
                 "global_batch_size",
