Commit 050d600

Bound number of workers by number of datasets (#157)
*Issue #, if available:* Fixes #154

*Description of changes:* Prior to this fix, some workers had no dataset to consume when `dataloader_num_workers > len(training_data_paths)`. The number of data loader workers is now capped at the number of training datasets.

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.
1 parent 9d59057 commit 050d600
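
The problem described in the commit message can be illustrated with a small sketch. It assumes datasets are handed to workers round-robin by worker id; that sharding scheme, the helper `shard_for_worker`, and the file names are hypothetical and are not shown in this commit.

```python
from itertools import islice


def shard_for_worker(paths, worker_id, num_workers):
    """Hypothetical round-robin split: worker i takes paths i, i+n, i+2n, ..."""
    return list(islice(paths, worker_id, None, num_workers))


# Made-up inputs: two datasets, four data loader workers.
training_data_paths = ["a.arrow", "b.arrow"]
dataloader_num_workers = 4

# Compute the shard each worker would receive under the assumed scheme.
shards = [
    shard_for_worker(training_data_paths, worker_id, dataloader_num_workers)
    for worker_id in range(dataloader_num_workers)
]

# Workers whose shard is empty have nothing to consume.
idle_workers = [worker_id for worker_id, shard in enumerate(shards) if not shard]
```

Under this assumption, any worker whose id is at or beyond `len(training_data_paths)` receives an empty shard, which is the situation the commit avoids by capping the worker count.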

1 file changed, +10 -0 lines changed

scripts/training/train.py

Lines changed: 10 additions & 0 deletions
@@ -569,6 +569,16 @@ def main(
         probability = [1.0 / len(training_data_paths)] * len(training_data_paths)
     assert isinstance(probability, list)
 
+    assert len(training_data_paths) == len(probability)
+
+    if dataloader_num_workers > len(training_data_paths):
+        log_on_main(
+            f"Setting the number of data loader workers to {len(training_data_paths)}, "
+            f"instead of {dataloader_num_workers}.",
+            logger,
+        )
+        dataloader_num_workers = len(training_data_paths)
+
     if isinstance(tokenizer_kwargs, str):
         tokenizer_kwargs = ast.literal_eval(tokenizer_kwargs)
     assert isinstance(tokenizer_kwargs, dict)
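
The bound added above can be read in isolation as a small function. The following is a minimal sketch, not the repository's code: `bound_num_workers` is a hypothetical name, and the standard-library `logging` module stands in for the repository's `log_on_main` helper, whose implementation is not shown in this diff.

```python
import logging

logger = logging.getLogger(__name__)


def bound_num_workers(num_workers: int, num_datasets: int) -> int:
    """Return num_workers, capped at num_datasets (hypothetical helper)."""
    if num_workers > num_datasets:
        # Mirror the message from the diff, using stdlib logging as a stand-in.
        logger.info(
            "Setting the number of data loader workers to %d, instead of %d.",
            num_datasets,
            num_workers,
        )
        return num_datasets
    return num_workers


# Example: four workers requested, two datasets available.
capped = bound_num_workers(4, 2)
```

The cap is equivalent to `min(num_workers, num_datasets)`; the explicit branch is kept here only so the log message is emitted when the value actually changes, as in the diff.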
