diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py
index 8e119e24438db..c443587c4ce9f 100644
--- a/src/lightning/pytorch/trainer/trainer.py
+++ b/src/lightning/pytorch/trainer/trainer.py
@@ -36,6 +36,7 @@
 from lightning.fabric.utilities.apply_func import convert_tensors_to_scalars
 from lightning.fabric.utilities.cloud_io import _is_local_file_protocol
 from lightning.fabric.utilities.types import _PATH
+from lightning.fabric.wrappers import _to_compiled, _unwrap_compiled
 from lightning.pytorch.accelerators import Accelerator
 from lightning.pytorch.callbacks import Callback, Checkpoint, EarlyStopping, ProgressBar
 from lightning.pytorch.core.datamodule import LightningDataModule
@@ -536,19 +537,25 @@ def fit(
             For more information about multiple dataloaders, see this :ref:`section `.

         """
-        model = _maybe_unwrap_optimized(model)
+        # if a compiled model was passed, unwrap it here; compilation is re-applied after the strategy is set up
+        model, compile_kwargs = (
+            _unwrap_compiled(model)
+            if isinstance(model, torch._dynamo.OptimizedModule)
+            else (_maybe_unwrap_optimized(model), None)
+        )
         self.strategy._lightning_module = model
         _verify_strategy_supports_compile(model, self.strategy)
         self.state.fn = TrainerFn.FITTING
         self.state.status = TrainerStatus.RUNNING
         self.training = True
         call._call_and_handle_interrupt(
-            self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
+            self, self._fit_impl, model, compile_kwargs, train_dataloaders, val_dataloaders, datamodule, ckpt_path
         )

     def _fit_impl(
         self,
         model: "pl.LightningModule",
+        compile_kwargs: Optional[dict[str, Any]] = None,
         train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None,
         val_dataloaders: Optional[EVAL_DATALOADERS] = None,
         datamodule: Optional[LightningDataModule] = None,
@@ -578,7 +585,7 @@ def _fit_impl(
             model_provided=True,
             model_connected=self.lightning_module is not None,
         )
-        self._run(model, ckpt_path=ckpt_path)
+        self._run(model, compile_kwargs, ckpt_path=ckpt_path)

         assert self.state.stopped
         self.training = False
@@ -909,7 +916,10 @@ def _predict_impl(
         return results

     def _run(
-        self, model: "pl.LightningModule", ckpt_path: Optional[_PATH] = None
+        self,
+        model: "pl.LightningModule",
+        compile_kwargs: Optional[dict[str, Any]] = None,
+        ckpt_path: Optional[_PATH] = None,
     ) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
         if self.state.fn == TrainerFn.FITTING:
             min_epochs, max_epochs = _parse_loop_limits(
@@ -963,6 +973,10 @@ def _run(
         # strategy will configure model and move it to the device
         self.strategy.setup(self)

+        # the compiled model was unwrapped in `fit()`; re-apply `torch.compile` now that the strategy has wrapped it
+        if compile_kwargs is not None:
+            self.strategy.model = _to_compiled(self.strategy.model, compile_kwargs)
+
         # hook
         if self.state.fn == TrainerFn.FITTING:
             call._call_callback_hooks(self, "on_fit_start")
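For readers following the control flow above: `fit()` now detaches the user's `torch.compile` wrapper and remembers its kwargs, the strategy then wraps the raw `LightningModule`, and `_run()` compiles again over that wrapper. The sketch below illustrates the same unwrap/re-compile round trip with plain PyTorch; it deliberately does not use Lightning's private `_unwrap_compiled`/`_to_compiled` helpers, and the `SmallNet` module and the `backend="eager"` kwargs are illustrative assumptions.

```python
import torch
from torch import nn
from torch._dynamo import OptimizedModule


class SmallNet(nn.Module):  # illustrative stand-in for a LightningModule
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)


compile_kwargs = {"backend": "eager"}  # stand-in for whatever the user passed to torch.compile
compiled = torch.compile(SmallNet(), **compile_kwargs)
assert isinstance(compiled, OptimizedModule)

# 1) unwrap: recover the original module from the OptimizedModule wrapper
original = compiled._orig_mod

# 2) the strategy wraps the *raw* module (DDP/FSDP inside the Trainer; identity here)
wrapped = original

# 3) re-apply torch.compile with the remembered kwargs, so dynamo traces the wrapped module
recompiled = torch.compile(wrapped, **compile_kwargs)
out = recompiled(torch.randn(4, 32))
```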
diff --git a/tests/tests_pytorch/strategies/test_ddp_integration.py b/tests/tests_pytorch/strategies/test_ddp_integration.py
index 048403366ebc7..bd57674782674 100644
--- a/tests/tests_pytorch/strategies/test_ddp_integration.py
+++ b/tests/tests_pytorch/strategies/test_ddp_integration.py
@@ -17,6 +17,7 @@
 import pytest
 import torch
+from torch._dynamo import OptimizedModule
 from torch.distributed.optim import ZeroRedundancyOptimizer
 from torch.multiprocessing import ProcessRaisedException
 from torch.nn.parallel.distributed import DistributedDataParallel
@@ -448,3 +449,30 @@ def creates_processes_externally(self):
         RuntimeError, match="Lightning attempted to launch new distributed processes with `local_rank > 0`."
     ):
         trainer.fit(model)
+
+
+@RunIf(min_cuda_gpus=2, standalone=True, dynamo=True)
+@mock.patch("lightning.fabric.wrappers.torch.compile", Mock(wraps=torch.compile))
+@mock.patch.dict(os.environ, {})
+def test_reapply_compile():
+    """Test that Trainer can rewrap a compiled module such that compilation happens over the DDP-wrapper."""
+    trainer = Trainer(accelerator="gpu", devices=2, strategy="ddp", max_steps=2, logger=False)
+
+    model = BoringModel()
+    compile_kwargs = {"mode": "reduce-overhead"}
+    compiled_model = torch.compile(model, **compile_kwargs)
+    torch.compile.reset_mock()
+
+    trainer.fit(compiled_model)
+    trainer_model = trainer.strategy.model
+
+    assert isinstance(trainer_model, OptimizedModule)
+    assert isinstance(trainer_model._orig_mod, DistributedDataParallel)
+    # Assert we called compile again with the same arguments, but on the DDP-wrapped module
+    torch.compile.assert_called_with(trainer_model._orig_mod, **compile_kwargs)
+
+    assert trainer_model._orig_mod.module == model
+
+    # Smoke-testing forward to ensure we don't get compilation errors
+    for _ in range(3):
+        trainer_model(torch.randn(2, 32, device="cuda")).sum().backward()
diff --git a/tests/tests_pytorch/strategies/test_fsdp.py b/tests/tests_pytorch/strategies/test_fsdp.py
index f3e88ca356764..14a2e76b55ef7 100644
--- a/tests/tests_pytorch/strategies/test_fsdp.py
+++ b/tests/tests_pytorch/strategies/test_fsdp.py
@@ -12,6 +12,7 @@
 import pytest
 import torch
 import torch.nn as nn
+from torch._dynamo import OptimizedModule
 from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision
 from torch.distributed.fsdp.wrap import ModuleWrapPolicy, always_wrap_policy, size_based_auto_wrap_policy, wrap
 from torchmetrics import Accuracy
@@ -971,3 +972,30 @@ def configure_optimizers(self):
         max_steps=4,
     )
     trainer.fit(model, ckpt_path=checkpoint_path_full)
+
+
+@RunIf(min_cuda_gpus=2, standalone=True, dynamo=True)
+@mock.patch("lightning.fabric.wrappers.torch.compile", Mock(wraps=torch.compile))
+@mock.patch.dict(os.environ, {})
+def test_reapply_compile():
+    """Test that Trainer can rewrap a compiled module such that compilation happens over the FSDP-wrapper."""
+    trainer = Trainer(accelerator="gpu", devices=2, strategy="fsdp", max_steps=2, logger=False)
+
+    model = BoringModel()
+    compile_kwargs = {"mode": "reduce-overhead"}
+    compiled_model = torch.compile(model, **compile_kwargs)
+    torch.compile.reset_mock()
+
+    trainer.fit(compiled_model)
+    trainer_model = trainer.strategy.model
+
+    assert isinstance(trainer_model, OptimizedModule)
+    assert isinstance(trainer_model._orig_mod, FullyShardedDataParallel)
+    # Assert we called compile again with the same arguments, but on the FSDP-wrapped module
+    torch.compile.assert_called_with(trainer_model._orig_mod, **compile_kwargs)
+
+    assert trainer_model._orig_mod.module == model
+
+    # Smoke-testing forward to ensure we don't get compilation errors
+    for _ in range(3):
+        trainer_model(torch.randn(2, 32, device="cuda")).sum().backward()
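Both tests above spy on `torch.compile` via `unittest.mock.Mock(wraps=...)`: the patched attribute still forwards to the real function, but every call and its arguments are recorded, which is what makes `assert_called_with` on the re-compiled, strategy-wrapped module possible. A minimal, Lightning-free sketch of that pattern (the `add` function is a throwaway example, not anything from this PR):

```python
from unittest.mock import Mock


def add(a, b):  # throwaway function standing in for torch.compile
    return a + b


spy = Mock(wraps=add)         # calls pass through to the real function ...
assert spy(2, 3) == 5         # ... so behavior is unchanged
spy.assert_called_with(2, 3)  # ... but every call and its arguments are recorded
spy.reset_mock()              # clear the history, as the tests do right after the first compile
assert spy.call_count == 0
```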
diff --git a/tests/tests_pytorch/utilities/test_compile.py b/tests/tests_pytorch/utilities/test_compile.py
index a053c847dfd6c..995484f13f30d 100644
--- a/tests/tests_pytorch/utilities/test_compile.py
+++ b/tests/tests_pytorch/utilities/test_compile.py
@@ -46,18 +46,14 @@ def test_trainer_compiled_model(_, tmp_path, monkeypatch, mps_count_0):
     model = BoringModel()
     compiled_model = torch.compile(model)
-    assert model._compiler_ctx is compiled_model._compiler_ctx  # shared reference

     # can train with compiled model
     trainer = Trainer(**trainer_kwargs)
     trainer.fit(compiled_model)
-    assert trainer.model._compiler_ctx["compiler"] == "dynamo"
+    assert isinstance(trainer.strategy.model, torch._dynamo.OptimizedModule)

     # the compiled model can be uncompiled
     to_uncompiled_model = to_uncompiled(compiled_model)
-    assert model._compiler_ctx is None
-    assert compiled_model._compiler_ctx is None
-    assert to_uncompiled_model._compiler_ctx is None

     # the compiled model needs to be passed
     with pytest.raises(ValueError, match="required to be a compiled LightningModule"):
@@ -66,7 +62,7 @@ def test_trainer_compiled_model(_, tmp_path, monkeypatch, mps_count_0):
     # the uncompiled model can be fitted
     trainer = Trainer(**trainer_kwargs)
     trainer.fit(model)
-    assert trainer.model._compiler_ctx is None
+    assert not isinstance(trainer.strategy.model, torch._dynamo.OptimizedModule)

     # some strategies do not support it
     if RequirementCache("deepspeed"):
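End to end, the behavior these changes enable looks roughly like the snippet below: a user hands a compiled `LightningModule` straight to `Trainer.fit`, and the Trainer re-applies the compilation on top of the strategy wrapper. This is a hedged sketch, not a test from this PR; it assumes a machine with two CUDA GPUs and uses `BoringModel` only as a convenient demo module.

```python
import torch
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel

model = BoringModel()
compiled_model = torch.compile(model)  # user compiles up front

trainer = Trainer(accelerator="gpu", devices=2, strategy="ddp", max_steps=2, logger=False)
trainer.fit(compiled_model)  # fit() unwraps, lets DDP wrap the raw module, then compiles again

# After setup, the strategy holds an OptimizedModule over the DDP-wrapped LightningModule.
assert isinstance(trainer.strategy.model, torch._dynamo.OptimizedModule)
```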