System Info
latest main
Who can help?
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
`num_examples` is the number of examples in the dataloader, not the number of batches in the dataloader, and not the number of batches for one process, so the computed `remainder` is incorrect.
When `steps_in_epoch % args.gradient_accumulation_steps == 0` and `args.gradient_accumulation_steps > 1`, `total_updates` is one more than expected.
transformers/src/transformers/trainer.py
Lines 2497 to 2503 in 7bb619d
```python
remainder = num_examples % args.gradient_accumulation_steps
if remainder == 0:
    remainder = args.gradient_accumulation_steps
update_step = -1
total_updates = steps_in_epoch // args.gradient_accumulation_steps + 1
if args.gradient_accumulation_steps == 1:
    total_updates -= 1
```
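A minimal sketch of the arithmetic (with hypothetical values `steps_in_epoch = 8` and `args.gradient_accumulation_steps = 4`, chosen so the epoch divides evenly) shows the extra update:

```python
# Hypothetical values chosen so that steps_in_epoch % gradient_accumulation_steps == 0.
steps_in_epoch = 8               # batches per process in one epoch
gradient_accumulation_steps = 4

# Current logic for gradient_accumulation_steps > 1 (no decrement applied):
total_updates = steps_in_epoch // gradient_accumulation_steps + 1
print(total_updates)  # 3, but only 8 / 4 = 2 optimizer updates actually happen
```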
`num_examples` comes from:
transformers/src/transformers/trainer.py
Lines 1756 to 1768 in 7bb619d
```python
def num_examples(self, dataloader: DataLoader) -> int:
    """
    Helper to get number of samples in a [`~torch.utils.data.DataLoader`] by accessing its dataset. When
    dataloader.dataset does not exist or has no length, estimates as best it can
    """
    try:
        dataset = dataloader.dataset
        # Special case for IterableDatasetShard, we need to dig deeper
        if isinstance(dataset, IterableDatasetShard):
            return len(dataloader.dataset.dataset)
        return len(dataloader.dataset)
    except (NameError, AttributeError, TypeError):  # no dataset or length, estimate by length of dataloader
        return len(dataloader) * self.args.per_device_train_batch_size
```
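For illustration (hypothetical numbers): with a dataset of 100 examples and `per_device_train_batch_size = 8`, `num_examples` and `steps_in_epoch` disagree, so the remainder computed from examples bears no relation to the number of leftover batches:

```python
# Hypothetical numbers to show the mismatch between examples and batches.
num_examples = 100                # len(dataloader.dataset)
per_device_train_batch_size = 8
steps_in_epoch = -(-num_examples // per_device_train_batch_size)  # ceil division -> 13 batches
gradient_accumulation_steps = 4

print(num_examples % gradient_accumulation_steps)    # 0 -> current code treats the last update as full
print(steps_in_epoch % gradient_accumulation_steps)  # 1 -> the last update really accumulates 1 batch
```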
Expected behavior
The correct code is:
```python
remainder = steps_in_epoch % args.gradient_accumulation_steps
if remainder == 0:
    remainder = args.gradient_accumulation_steps
update_step = -1
total_updates = steps_in_epoch // args.gradient_accumulation_steps + int(remainder < args.gradient_accumulation_steps)
```
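As a sanity check, a small sketch (a stand-in for the Trainer loop, exercised on hand-picked hypothetical cases) confirms the proposed formula counts optimizer updates correctly at both boundaries:

```python
def proposed_total_updates(steps_in_epoch: int, gradient_accumulation_steps: int) -> int:
    # Proposed fix: the remainder is counted in batches, not examples.
    remainder = steps_in_epoch % gradient_accumulation_steps
    if remainder == 0:
        remainder = gradient_accumulation_steps
    return steps_in_epoch // gradient_accumulation_steps + int(
        remainder < gradient_accumulation_steps
    )

# (steps_in_epoch, gradient_accumulation_steps, expected optimizer updates)
for steps, gas, expected in [(8, 4, 2), (9, 4, 3), (8, 1, 8), (3, 4, 1)]:
    assert proposed_total_updates(steps, gas) == expected
print("all cases pass")
```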