BatchSizeFinder leaves model in the train state if used with trainer.validate #18813

Closed
BoringDonut opened this issue Oct 17, 2023 · 2 comments

BoringDonut commented Oct 17, 2023

Bug description

If BatchSizeFinder is used with trainer.validate or trainer.test, it leaves the model in train mode, which keeps Dropout and other sources of model randomness enabled. This influences the model's predictions and produces unreliable validation metrics.
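The underlying mechanism can be shown in plain PyTorch: torch.nn.Dropout is stochastic in train mode and an identity in eval mode (a minimal standalone illustration, not Lightning code):

```python
import torch

torch.manual_seed(0)
dropout = torch.nn.Dropout(p=0.5)
x = torch.ones(8)

dropout.train()   # train mode: entries are zeroed with p=0.5, survivors scaled by 1/(1-p)
train_out = dropout(x)

dropout.eval()    # eval mode: Dropout is the identity function
eval_out = dropout(x)

print(train_out)  # a mix of 0.0 and 2.0 values
print(eval_out)   # identical to x
```

So if the tuner leaves the model in train mode, every validation forward pass goes through the stochastic branch.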

Steps to reproduce:

  1. Add BatchSizeFinder to a trainer
  2. Run validation (without fit)
  3. See how the model output differs from a run without BatchSizeFinder (the output can also be perturbed by changing the random seed)

While this does not matter for trainer.fit (or the LightningCLI fit subcommand, for that matter), it creates undesired randomness when a user wants to re-evaluate a model they have already trained.

What version are you seeing the problem on?

v1.8, v2.0

How to reproduce the bug

Minimal example with 3 validation outputs:

  1. Without BatchSizeFinder
  2. With default BatchSizeFinder
  3. With a BatchSizeFinder subclass that calls trainer.model.eval()

Note that something like Dropout needs to be present in the model to reproduce this behavior.

import os

import torch
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.callbacks import BatchSizeFinder
from torch.utils.data import DataLoader, Dataset


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 128)
        self.dropout = torch.nn.Dropout(0.5)
        self.batch_size = 1

    def forward(self, x):
        return self.dropout(self.layer(x))

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


class CustomBatchSizeFinder(BatchSizeFinder):
    def on_validation_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        super().on_validation_start(trainer, pl_module)
        trainer.model.eval()


def run():
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        limit_test_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        enable_model_summary=False,
    )
    print("Without BatchSizeFinder:")
    trainer.validate(model, dataloaders=val_data)

    callbacks = [BatchSizeFinder()]
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        limit_test_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        enable_model_summary=False,
        callbacks=callbacks
    )
    print("With BatchSizeFinder:")
    trainer.validate(model, dataloaders=val_data)

    callbacks = [CustomBatchSizeFinder()]
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        limit_test_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        enable_model_summary=False,
        callbacks=callbacks
    )
    print("With BatchSizeFinder that calls to `trainer.model.eval()`:")
    trainer.validate(model, dataloaders=val_data)


if __name__ == "__main__":
    torch.manual_seed(1)
    run()

Error messages and logs

Without BatchSizeFinder:
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     Validate metric           DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       valid_loss           15.337783813476562
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
With BatchSizeFinder:
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     Validate metric           DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       valid_loss           26.495624542236328
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
With BatchSizeFinder that calls to `trainer.model.eval()`:
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     Validate metric           DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       valid_loss           15.337783813476562
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────

As you can see, the loss with the default BatchSizeFinder differs from the other two runs.
The difference can be fixed either by calling trainer.model.eval() or by removing the source of randomness (here, Dropout) from the model.
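A more robust fix than unconditionally calling eval() would be to capture the model's mode before the tuner runs and restore it afterwards. A minimal sketch of that save/restore pattern in plain PyTorch (the helper name restore_mode_after is illustrative, not part of Lightning's API):

```python
import torch

def restore_mode_after(module: torch.nn.Module, fn):
    """Run `fn`, then put `module` back into whatever train/eval mode it started in."""
    was_training = module.training  # nn.Module records its current mode in `.training`
    try:
        return fn()
    finally:
        module.train(was_training)  # train(True) -> train mode, train(False) -> eval mode

model = torch.nn.Sequential(torch.nn.Linear(32, 128), torch.nn.Dropout(0.5))
model.eval()                                      # validation should run in eval mode
restore_mode_after(model, lambda: model.train())  # fn flips the model to train mode...
print(model.training)                             # ...but the original eval mode is restored: False
```

This way the tuner is mode-neutral: it does not matter whether it runs during fit (train mode) or validate/test (eval mode).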

Environment

Current environment
  • CUDA:
    - GPU:
    - NVIDIA GeForce RTX 3050 Laptop GPU
    - available: True
    - version: 12.1
  • Lightning:
    - lightning: 2.1.0
    - lightning-utilities: 0.9.0
    - pytorch-lightning: 2.1.0
    - torch: 2.1.0
    - torchmetrics: 1.2.0
  • Packages:
    - aiohttp: 3.8.6
    - aiosignal: 1.3.1
    - async-timeout: 4.0.3
    - attrs: 23.1.0
    - certifi: 2023.7.22
    - charset-normalizer: 3.3.0
    - filelock: 3.12.4
    - frozenlist: 1.4.0
    - fsspec: 2023.9.2
    - idna: 3.4
    - jinja2: 3.1.2
    - lightning: 2.1.0
    - lightning-utilities: 0.9.0
    - markupsafe: 2.1.3
    - mpmath: 1.3.0
    - multidict: 6.0.4
    - networkx: 3.1
    - numpy: 1.24.4
    - nvidia-cublas-cu12: 12.1.3.1
    - nvidia-cuda-cupti-cu12: 12.1.105
    - nvidia-cuda-nvrtc-cu12: 12.1.105
    - nvidia-cuda-runtime-cu12: 12.1.105
    - nvidia-cudnn-cu12: 8.9.2.26
    - nvidia-cufft-cu12: 11.0.2.54
    - nvidia-curand-cu12: 10.3.2.106
    - nvidia-cusolver-cu12: 11.4.5.107
    - nvidia-cusparse-cu12: 12.1.0.106
    - nvidia-nccl-cu12: 2.18.1
    - nvidia-nvjitlink-cu12: 12.2.140
    - nvidia-nvtx-cu12: 12.1.105
    - packaging: 23.2
    - pip: 23.2.1
    - pytorch-lightning: 2.1.0
    - pyyaml: 6.0.1
    - requests: 2.31.0
    - setuptools: 68.1.2
    - sympy: 1.12
    - torch: 2.1.0
    - torchmetrics: 1.2.0
    - tqdm: 4.66.1
    - triton: 2.1.0
    - typing-extensions: 4.8.0
    - urllib3: 2.0.7
    - wheel: 0.41.2
    - yarl: 1.9.2
  • System:
    - OS: Linux
    - architecture:
    - 64bit
    - ELF
    - processor: x86_64
    - python: 3.8.18
    - release: 5.15.0-83-generic
    - version: #92-Ubuntu SMP Mon Aug 14 09:30:42 UTC 2023

More info

No response

@BoringDonut BoringDonut added bug Something isn't working needs triage Waiting to be triaged by maintainers labels Oct 17, 2023
@tanaymeh
Contributor

I would like to fix this bug @BoringDonut


BoringDonut commented Oct 18, 2023

I would like to fix this bug @BoringDonut

Thanks @tanaymeh!
Feel free to contact me if you need any additional testing or details.

@awaelchli awaelchli added tuner and removed needs triage Waiting to be triaged by maintainers labels Oct 21, 2023