4 changes: 2 additions & 2 deletions nemo/collections/llm/t5/model/t5.py
@@ -11,8 +11,6 @@
from torch import nn

from nemo.collections.llm import fn
from nemo.collections.nlp.modules.common.megatron.token_level_encoder_decoder import AttnMaskType
from nemo.collections.nlp.modules.common.megatron.utils import build_attention_mask_3d
from nemo.lightning import get_vocab_size, io
from nemo.lightning.megatron_parallel import MaskedTokenLossReduction
from nemo.lightning.pytorch.optim import MegatronOptimizerModule, OptimizerModule
@@ -31,6 +29,8 @@

def t5_data_step(dataloader_iter) -> Dict[str, torch.Tensor]:
from megatron.core import parallel_state
from nemo.collections.nlp.modules.common.megatron.token_level_encoder_decoder import AttnMaskType
from nemo.collections.nlp.modules.common.megatron.utils import build_attention_mask_3d

batch = next(dataloader_iter)

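The hunk above defers two `collections.nlp` imports into `t5_data_step`. A minimal sketch of that lazy-import pattern, with a simplified function body that is not the real `t5_data_step`:

```python
def data_step(dataloader_iter):
    # Lazy import, as in the hunk above: the dependency is only needed (and only
    # imported) when the step actually runs, keeping module import light.
    from megatron.core import parallel_state

    batch = next(dataloader_iter)
    return batch, parallel_state.get_tensor_model_parallel_rank()
```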
2 changes: 1 addition & 1 deletion nemo/core/optim/mcore_optim.py
@@ -61,7 +61,7 @@ def sharded_state_dict(
model_sharded_state_dict, is_loading=is_loading, sharding_type=sharding_type
)

def step(self, closure):
def step(self, closure=None):
"""Clip gradients (if needed) and step the base optimizer.
Always return successful since there is no overflow."""
# Apply closure
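A hedged sketch of why the signature change above matters: `torch.optim.Optimizer.step` takes an optional closure, and callers such as Lightning/Fabric step the optimizer without one, so the wrapper's `step` should accept that call style too. The wrapper class below is illustrative, not the NeMo implementation:

```python
from typing import Callable, Optional

import torch


class WrappedOptim:
    """Illustrative wrapper whose step() mirrors torch.optim.Optimizer.step()."""

    def __init__(self, optim: torch.optim.Optimizer):
        self.optim = optim

    def step(self, closure: Optional[Callable] = None):
        # Evaluate the closure if one is given, exactly as torch optimizers do.
        loss = closure() if closure is not None else None
        self.optim.step()
        return loss


params = [torch.nn.Parameter(torch.zeros(2))]
wrapped = WrappedOptim(torch.optim.SGD(params, lr=0.1))
wrapped.step()                      # Lightning/Fabric style: no closure
wrapped.step(closure=lambda: 0.0)   # classic closure style still works
```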
11 changes: 10 additions & 1 deletion nemo/lightning/fabric/fabric.py
@@ -4,8 +4,9 @@

import fiddle as fdl
import lightning_fabric as lb
import pytorch_lightning as pl
from torch import nn
from torch.optim import Optimizer

from typing_extensions import Self, override

from nemo.lightning.io.mixin import IOMixin, serialization, track_io
@@ -130,6 +131,14 @@ def setup_module(self, module: nn.Module, move_to_device: bool = True, _reapply_

return out

def setup_datamodule(self, datamodule: pl.LightningDataModule, stage: str = "") -> pl.LightningDataModule:
datamodule.setup(stage)

if hasattr(self.strategy, "process_datamodule"):
datamodule = self.strategy.process_datamodule(datamodule)

return datamodule


@runtime_checkable
class DistributedModel(Protocol[ModelT]):
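A minimal usage sketch for the new `setup_datamodule` hook; the datamodule below and the commented-out Fabric/strategy construction are assumptions for illustration, not code from this PR:

```python
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, TensorDataset


class ToyDataModule(pl.LightningDataModule):
    # Hypothetical datamodule; a NeMo datamodule would typically also expose
    # a `data_sampler` attribute for the strategy to pick up.
    def setup(self, stage: str = "") -> None:
        self.ds = TensorDataset(torch.zeros(16, 4))

    def train_dataloader(self) -> DataLoader:
        return DataLoader(self.ds, batch_size=4)


# fabric = Fabric(strategy=FabricMegatronStrategy(...))   # assumed construction
# dm = fabric.setup_datamodule(ToyDataModule())
# setup_datamodule() runs dm.setup(stage) and, when the strategy implements
# process_datamodule(), hands the datamodule over (e.g. to wire a data sampler).
```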
5 changes: 3 additions & 2 deletions nemo/lightning/fabric/plugins.py
@@ -112,7 +112,6 @@ def convert_config(self, config: ConfigT) -> ConfigT:
"""Convert the config to the precision type this plugin handles.

This is optional and depends on the precision limitations during optimization.

"""
return update_config_with_dtype_overrides(self.dtype_config, config)

@@ -122,6 +121,9 @@ def convert_module(self, module: nn.Module) -> nn.Module:
This is optional and depends on the precision limitations during optimization.

"""
if not hasattr(module, "module"):
return module

from megatron.core.transformer.module import Float16Module
from megatron.core.utils import get_model_config

@@ -141,7 +143,6 @@ def convert_optimizer(self, optimizer: Optimizer) -> Optimizer:
"""Convert the optimizer parameters to the precision type this plugin handles.

This is optional and depends on the precision limitations during optimization.

"""
for optim_config in get_optim_config(optimizer):
assert optim_config.bf16 == self.dtype_config.bf16, "BF16 model/optim config mismatch"
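A hedged sketch of the effect of the new guard in `convert_module`: modules without an inner `.module` attribute (i.e. not Megatron-wrapped) are returned untouched instead of entering the Float16 conversion path. The classes below are stand-ins, not NeMo types:

```python
import torch.nn as nn


class Wrapped(nn.Module):
    # Stand-in for a Megatron-wrapped module that exposes an inner `.module`.
    def __init__(self, inner: nn.Module):
        super().__init__()
        self.module = inner


def convert_module(module: nn.Module) -> nn.Module:
    if not hasattr(module, "module"):
        # Plain modules pass through untouched, as in the guard added above.
        return module
    # Only wrapped modules would reach the Float16Module conversion path here.
    return module


plain = nn.Linear(4, 4)
assert convert_module(plain) is plain            # unchanged
assert convert_module(Wrapped(plain)) is not None
```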
46 changes: 37 additions & 9 deletions nemo/lightning/fabric/strategies.py
@@ -26,6 +26,7 @@
from lightning_fabric.utilities.types import _PATH, _Stateful
from megatron.core.distributed import DistributedDataParallelConfig
from megatron.core.optimizer import OptimizerConfig
from pytorch_lightning import LightningDataModule
from pytorch_lightning.loops.fetchers import _DataFetcher
from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO
from pytorch_lightning.utilities.combined_loader import CombinedLoader
@@ -106,6 +107,7 @@ def __init__(
if megatron_callbacks:
self.megatron_callbacks.add(megatron_callbacks)
self.output_data_idx = output_data_idx
self.data_sampler: Optional["DataSampler"] = data_sampler

# used in NVIDIA NGC PyTorch containers
_strategy_lib.enable_nvidia_optimizations()
@@ -141,13 +143,25 @@ def _setup_distributed(self) -> None:
# _strategy_lib.initialize_data(self.cluster_environment.global_rank(), self.data_config)
_strategy_lib.init_model_parallel()

def process_datamodule(self, datamodule: LightningDataModule) -> LightningDataModule:
datamodule.setup()

if not self.data_sampler and hasattr(datamodule, "data_sampler"):
self.data_sampler = datamodule.data_sampler

if self.data_sampler:
self.data_sampler.setup(self.cluster_environment.global_rank())

return datamodule

@override
def process_dataloader(self, dataloader: DataLoader) -> Iterator:
loader = _strategy_lib.process_dataloader(dataloader, self.data_config)
if self.data_sampler:
dataloader = self.data_sampler.transform_dataloader(dataloader)

# Code taken from: https://github.com/Lightning-AI/pytorch-lightning/blob/6cbe9ceb560d798892bdae9186291acf9bf5d2e3/src/lightning/pytorch/loops/fit_loop.py#L258-L260
output = _MegatronDataLoaderIterDataFetcher(self.data_config, output_data_idx=self.output_data_idx)
output.setup(CombinedLoader(loader, "max_size_cycle"))
output = _MegatronDataLoaderIterDataFetcher(output_data_idx=self.output_data_idx)
output.setup(CombinedLoader(dataloader, "max_size_cycle"))
iter(output)

return output
@@ -160,6 +174,11 @@ def setup_megatron_optimizer(
scale_lr_cond: Optional[Callable] = None,
lr_mult: float = 1.0,
) -> Optimizer:
if hasattr(self.precision, "convert_config"):
optimizer_config = self.precision.convert_config(optimizer_config)

assert optimizer_config.lr is not None, "Learning rate must be set in optimizer config"

return _strategy_lib.setup_megatron_optimizer(
model,
optimizer_config,
@@ -180,16 +199,23 @@ def setup_optimizer(self, optimizer: Optimizer) -> Optimizer:

@override
def setup_module(self, module: Module) -> MegatronParallel:
_strategy_lib.set_model_parallel_attributes(module, self.parallelism)
from megatron.core.utils import get_model_config

# Call configure_model if it's overridden (relevant for LightningModules with lazy initialization)
if hasattr(module, "configure_model"):
module.configure_model()
_strategy_lib.set_model_parallel_attributes(module, self.parallelism)

convert_module_fn = None
if hasattr(self.precision, "convert_module"):
convert_module_fn = self.precision.convert_module

if hasattr(self.precision, "convert_config"):
self.precision.convert_config(get_model_config(module))
if self.ddp_config:
self.precision.convert_config(self.ddp_config)

# Call configure_model if it's overridden (relevant for LightningModules with lazy initialization)
if hasattr(module, "configure_model"):
module.configure_model()

megatron_parallel = MegatronParallel(
module,
precision_plugin=self.precision,
@@ -202,6 +228,9 @@ def setup_module(self, module: Module) -> MegatronParallel:
if self._init_model_parallel:
megatron_parallel.init_model_parallel()

if self.data_sampler:
megatron_parallel.callbacks.add(self.data_sampler)

if not self.ddp_config:
from megatron.core import mpu

@@ -364,9 +393,8 @@ def parallelism(self):

# TODO: Fix this
class _MegatronDataLoaderIterDataFetcher(_DataFetcher):
def __init__(self, data_config, *args: Any, output_data_idx: bool = False, **kwargs: Any) -> None:
def __init__(self, *args: Any, output_data_idx: bool = False, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.data_config = data_config
self.output_data_idx = output_data_idx
self._batch: Any = None
self._batch_idx: int = 0
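A hedged, self-contained sketch of the sampler adoption that the new `process_datamodule` performs, and that `setup_module` later registers as a Megatron callback; all `Toy*` classes below are placeholders, not NeMo types. The related `process_dataloader` change then calls `data_sampler.transform_dataloader(dataloader)` before wrapping the result in `_MegatronDataLoaderIterDataFetcher`.

```python
from typing import Optional


class ToySampler:
    # Hypothetical stand-in for a NeMo data sampler.
    def setup(self, global_rank: int) -> None:
        self.global_rank = global_rank

    def transform_dataloader(self, dataloader):
        return dataloader  # a real sampler re-batches / re-shards here


class ToyDataModule:
    # Hypothetical datamodule exposing a `data_sampler`, as NeMo datamodules do.
    data_sampler = ToySampler()

    def setup(self, stage: str = "") -> None:
        pass


class ToyStrategy:
    """Mirrors the adoption logic added in process_datamodule()."""

    def __init__(self, data_sampler: Optional[ToySampler] = None):
        self.data_sampler = data_sampler

    def process_datamodule(self, datamodule):
        datamodule.setup()
        # Adopt the datamodule's sampler only when none was passed to the strategy.
        if not self.data_sampler and hasattr(datamodule, "data_sampler"):
            self.data_sampler = datamodule.data_sampler
        if self.data_sampler:
            self.data_sampler.setup(global_rank=0)  # real code uses the cluster environment's rank
        return datamodule


strategy = ToyStrategy()
strategy.process_datamodule(ToyDataModule())
assert strategy.data_sampler is ToyDataModule.data_sampler
```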
18 changes: 17 additions & 1 deletion nemo/lightning/megatron_parallel.py
@@ -222,6 +222,7 @@ def forward(
seq_length: Optional[int] = None,
micro_batch_size: Optional[int] = None,
num_microbatches: Optional[int] = None,
step_i: Optional[int] = None,
wrap_forward_step: bool = True,
) -> torch.Tensor:
"""The method performs the forward pass of the model.
@@ -269,6 +270,7 @@
micro_batch_size=micro_batch_size,
num_microbatches=num_microbatches,
seq_length=seq_length,
step_i=step_i,
)
_forward_context["step"] = step
step = self.callbacks.transform_event("on_megatron_step_start", step)
@@ -334,6 +336,7 @@ def validation_step(
seq_length: Optional[int] = None,
micro_batch_size: Optional[int] = None,
num_microbatches: Optional[int] = None,
step_i: Optional[int] = None,
**kwargs,
) -> STEP_OUTPUT:
return self._step(
@@ -345,6 +348,7 @@
seq_length=seq_length,
micro_batch_size=micro_batch_size,
num_microbatches=num_microbatches,
step_i=step_i,
forward_only=True,
**kwargs,
)
@@ -358,6 +362,7 @@ def test_step(
seq_length: Optional[int] = None,
micro_batch_size: Optional[int] = None,
num_microbatches: Optional[int] = None,
step_i: Optional[int] = None,
**kwargs,
) -> STEP_OUTPUT:
return self._step(
@@ -369,6 +374,7 @@
seq_length=seq_length,
micro_batch_size=micro_batch_size,
num_microbatches=num_microbatches,
step_i=step_i,
forward_only=True,
**kwargs,
)
@@ -382,6 +388,7 @@ def predict_step(
seq_length: Optional[int] = None,
micro_batch_size: Optional[int] = None,
num_microbatches: Optional[int] = None,
step_i: Optional[int] = None,
**kwargs,
) -> STEP_OUTPUT:
return self._step(
@@ -393,6 +400,7 @@
seq_length=seq_length,
micro_batch_size=micro_batch_size,
num_microbatches=num_microbatches,
step_i=step_i,
forward_only=True,
**kwargs,
)
@@ -408,6 +416,7 @@ def _step(
micro_batch_size: Optional[int] = None,
num_microbatches: Optional[int] = None,
forward_only: bool = True,
step_i: Optional[int] = None,
**kwargs,
) -> STEP_OUTPUT:
if not hasattr(self.module, f"{step_type}_step"):
@@ -426,6 +435,7 @@
micro_batch_size=micro_batch_size,
num_microbatches=num_microbatches,
forward_only=forward_only,
step_i=step_i,
**kwargs,
)

@@ -1043,6 +1053,7 @@ class MegatronStep(Generic[ModelT, DataT]):
micro_batch_size: Optional[int] = None
seq_length: Optional[int] = None
num_microbatches: Optional[int] = None
step_i: Optional[int] = None

@classmethod
def infer(
@@ -1054,6 +1065,7 @@ def infer(
micro_batch_size: Optional[int] = None,
seq_length: Optional[int] = None,
num_microbatches: Optional[int] = None,
step_i: Optional[int] = None,
) -> "MegatronStep[ModelT, DataT]":
"""
Creates a MegatronStep instance, inferring missing parameters if possible.
@@ -1069,10 +1081,13 @@
micro_batch_size (Optional[int]): Size of each micro-batch.
seq_length (Optional[int]): Sequence length for the current step.
num_microbatches (Optional[int]): Number of micro-batches in this step.

step_i (Optional[int]): Step index for the current step.
Returns:
MegatronStep[ModelT, DataT]: An instance of MegatronStep with inferred parameters.
"""
if step_i is None and pipeline.trainer:
step_i = pipeline.trainer.global_step

return cls(
pipeline=pipeline,
data=data,
@@ -1081,6 +1096,7 @@
micro_batch_size=micro_batch_size or cls.infer_micro_batch_size(data),
seq_length=seq_length or cls.infer_seq_length(data),
num_microbatches=num_microbatches or cls.infer_num_microbatches(data),
step_i=step_i,
)

def __call__(self) -> List[Any]:
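A hedged sketch of how the new `step_i` argument flows: each `*_step` entry point forwards it to `MegatronStep.infer`, which falls back to the trainer's global step when it is omitted. Simplified stand-in code, not the full implementation:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Step:
    step_i: Optional[int] = None

    @classmethod
    def infer(cls, trainer_global_step: Optional[int], step_i: Optional[int] = None) -> "Step":
        # Mirrors the diff: an explicit step_i wins; otherwise fall back to the
        # trainer's global step counter.
        if step_i is None and trainer_global_step is not None:
            step_i = trainer_global_step
        return cls(step_i=step_i)


assert Step.infer(trainer_global_step=7).step_i == 7
assert Step.infer(trainer_global_step=7, step_i=3).step_i == 3
```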
38 changes: 21 additions & 17 deletions nemo/lightning/pytorch/plugins/data_sampler.py
@@ -82,7 +82,7 @@ def compute_consumed_samples(self, steps_since_resume=0) -> int:
from nemo.lightning.pytorch.strategies import MegatronStrategy
from nemo.utils import AppState

if not isinstance(self.trainer.strategy, MegatronStrategy):
if not hasattr(self, "trainer") or not isinstance(self.trainer.strategy, MegatronStrategy):
return 0

app_state = AppState()
@@ -107,6 +107,9 @@ def on_megatron_step_start(self, step: MegatronStep) -> MegatronStep:
)

def on_megatron_microbatches_start(self, step: MegatronStep) -> None:
if not step.trainer:
return

# do validation and save the checkpoint when gbs is changed
if (
self.rampup_batch_size is not None
@@ -128,23 +131,24 @@ def on_megatron_step_end(self, step: MegatronStep) -> None:

self.prev_global_batch_size = self.current_global_batch_size

consumed_samples = self.compute_consumed_samples(trainer.global_step + 1 - self.init_global_step)
if self.output_log and self.trainer.training:
# You may need to turn off logging, for example when doing trainer.predict(model, data)
pl_module.log(
'consumed_samples',
consumed_samples,
prog_bar=True,
batch_size=1,
if step.step_i:
consumed_samples = self.compute_consumed_samples(step.step_i + 1 - self.init_global_step)
if self.output_log and trainer and getattr(trainer, "training", False):
# You may need to turn off logging, for example when doing trainer.predict(model, data)
pl_module.log(
'consumed_samples',
consumed_samples,
prog_bar=True,
batch_size=1,
)

self.prev_consumed_samples = consumed_samples

update_num_microbatches(
consumed_samples=consumed_samples,
consistency_check=False,
)

self.prev_consumed_samples = consumed_samples

update_num_microbatches(
consumed_samples=consumed_samples,
consistency_check=False,
)
if self.output_log:
if self.output_log and trainer:
# You may need to turn off logging, for example when doing trainer.predict(model, data)
pl_module.log(
"global_batch_size",
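A hedged arithmetic sketch of the consumed-samples bookkeeping that `on_megatron_step_end` now derives from `step.step_i`; the formula is an assumption inferred from the call `compute_consumed_samples(step.step_i + 1 - self.init_global_step)`, not the method's actual body:

```python
# Assumed bookkeeping: consumed samples grow by one global batch per completed step.
init_global_step = 100        # trainer step counter when training resumed
init_consumed_samples = 800   # samples already consumed at resume
global_batch_size = 8

step_i = 104                  # current MegatronStep.step_i
steps_since_resume = step_i + 1 - init_global_step              # -> 5
consumed_samples = init_consumed_samples + steps_since_resume * global_batch_size
assert consumed_samples == 840
```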