Commit 4edc688

Only load data on main process (huggingface#1255)
* fix: only load data on main process
* define is_main_process once
* avoid re-initializing PartialState on train dataset check
* avoid re-initializing PartialState on eval dataset check
* process dataset on main first to take advantage of caching
* fix typo in docs
* use decorator to manage state
* Revert "fix typo in docs"
  This reverts commit 0880a188812a698f7106853245ce1ba96a036831.
* Revert "Revert "fix typo in docs""
  This reverts commit ff7ee33fbeedcd0032b728d86a17cfcb10e43f9b.
* Revert "use decorator to manage state"
  This reverts commit 7ac7a45949f621941fedc522f0d2ca7b29367c3a.
* use is_local_main_process instead of is_main_process
* fix: use context manager instead of attribute
* Update trl/trainer/sft_trainer.py

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
1 parent 29d439a commit 4edc688
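The commit wraps `SFTTrainer`'s dataset preparation in Accelerate's `local_main_process_first()` context manager. On each node, the local main process runs the tokenization and packing first and populates the `datasets` Arrow cache; the remaining ranks enter the block only afterwards, so their identical `map` calls are answered from the cache instead of being recomputed once per process. Below is a minimal sketch of the pattern outside of TRL; the example dataset and the toy `map` transform are illustrative assumptions, not part of this commit.

from accelerate.state import PartialState
from datasets import load_dataset

state = PartialState()  # shared process state; cheap to re-instantiate

# The local main process enters the block first; the other ranks wait at a
# barrier, then run the same body and reload the results from the on-disk
# datasets cache rather than recomputing them.
with state.local_main_process_first():
    ds = load_dataset("imdb", split="train")
    ds = ds.map(lambda ex: {"n_chars": len(ex["text"])})  # cached by fingerprint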

File tree

2 files changed: +27 −24 lines

docs/source/sft_trainer.mdx
Lines changed: 1 addition & 1 deletion

@@ -251,7 +251,7 @@ trainer = SFTTrainer(
 
 trainer.train()
 ```
-To preperly format your input make sure to process all the examples by looping over them and returning a list of processed text. Check out a full example on how to use SFTTrainer on alpaca dataset [here](https://github.com/huggingface/trl/pull/444#issue-1760952763)
+To properly format your input make sure to process all the examples by looping over them and returning a list of processed text. Check out a full example on how to use SFTTrainer on alpaca dataset [here](https://github.com/huggingface/trl/pull/444#issue-1760952763)
 
 ### Packing dataset ([`ConstantLengthDataset`])
 

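The sentence corrected above refers to batched formatting: with a non-packed dataset, the formatting function receives a batch of examples and must return a list with one processed string per example. A minimal sketch of such a function for an alpaca-style dataset follows; the column names and prompt template are illustrative assumptions.

def formatting_prompts_func(examples):
    # `examples` is a batch: a dict mapping column names to lists of values.
    output_texts = []
    for i in range(len(examples["instruction"])):
        text = f"### Question: {examples['instruction'][i]}\n### Answer: {examples['output'][i]}"
        output_texts.append(text)
    return output_texts  # one formatted string per input example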
trl/trainer/sft_trainer.py
Lines changed: 26 additions & 23 deletions

@@ -19,6 +19,7 @@
 
 import torch
 import torch.nn as nn
+from accelerate.state import PartialState
 from datasets import Dataset
 from datasets.arrow_writer import SchemaInferenceError
 from datasets.builder import DatasetGenerationError
@@ -252,27 +253,13 @@ def make_inputs_require_grad(module, input, output):
         if data_collator is None:
             data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
 
-        if dataset_kwargs is None:
-            dataset_kwargs = {}
-        if train_dataset is not None:
-            train_dataset = self._prepare_dataset(
-                train_dataset,
-                tokenizer,
-                packing,
-                dataset_text_field,
-                max_seq_length,
-                formatting_func,
-                num_of_sequences,
-                chars_per_token,
-                remove_unused_columns=args.remove_unused_columns if args is not None else True,
-                **dataset_kwargs,
-            )
-        if eval_dataset is not None:
-            _multiple = isinstance(eval_dataset, dict)
-            _eval_datasets = eval_dataset if _multiple else {"singleton": eval_dataset}
-            for _eval_dataset_name, _eval_dataset in _eval_datasets.items():
-                _eval_datasets[_eval_dataset_name] = self._prepare_dataset(
-                    _eval_dataset,
+        # Pre-process the datasets only once per node. The remaining processes will use the cache.
+        with PartialState().local_main_process_first():
+            if dataset_kwargs is None:
+                dataset_kwargs = {}
+            if train_dataset is not None:
+                train_dataset = self._prepare_dataset(
+                    train_dataset,
                     tokenizer,
                     packing,
                     dataset_text_field,
@@ -283,8 +270,24 @@ def make_inputs_require_grad(module, input, output):
                     remove_unused_columns=args.remove_unused_columns if args is not None else True,
                     **dataset_kwargs,
                 )
-            if not _multiple:
-                eval_dataset = _eval_datasets["singleton"]
+            if eval_dataset is not None:
+                _multiple = isinstance(eval_dataset, dict)
+                _eval_datasets = eval_dataset if _multiple else {"singleton": eval_dataset}
+                for _eval_dataset_name, _eval_dataset in _eval_datasets.items():
+                    _eval_datasets[_eval_dataset_name] = self._prepare_dataset(
+                        _eval_dataset,
+                        tokenizer,
+                        packing,
+                        dataset_text_field,
+                        max_seq_length,
+                        formatting_func,
+                        num_of_sequences,
+                        chars_per_token,
+                        remove_unused_columns=args.remove_unused_columns if args is not None else True,
+                        **dataset_kwargs,
+                    )
+                if not _multiple:
+                    eval_dataset = _eval_datasets["singleton"]
 
         if tokenizer.padding_side is not None and tokenizer.padding_side != "right":
             warnings.warn(
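Correctness here depends on every rank still executing the preparation block, just in staggered order: `local_main_process_first()` is essentially a pair of conditional barrier calls. A rough sketch of its semantics, for intuition only and not Accelerate's actual implementation:

from contextlib import contextmanager

from accelerate.state import PartialState

@contextmanager
def local_main_process_first_sketch(state: PartialState):
    # Non-main ranks block at the barrier until the local main process,
    # having already run the body, reaches its matching call below.
    if not state.is_local_main_process:
        state.wait_for_everyone()
    yield  # the body runs here: main rank first, remaining ranks afterwards
    # The main rank joins the barrier only after finishing the body,
    # releasing the ranks waiting above.
    if state.is_local_main_process:
        state.wait_for_everyone()

Because the non-main ranks call `_prepare_dataset` only after the main rank has finished, their `datasets.map` calls resolve against the fingerprint cache the main rank just wrote, so each node pre-processes the data once.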
