
Commit 9e2ac74

Provide safeguards during training (#168)
* fix: add safeguards during data processing

  Signed-off-by: Oleg S <[email protected]>

* fix: add a safeguard for max_batch_len & max_seq_len in training

  We currently have certain values that need to be validated against others, but no logic in place to ensure that this validation happens. This commit provides a pre-training check that errors out if the value of max_batch_len is smaller than max_seq_len, since this breaks our ability to generate training batches.

  Signed-off-by: Oleg S <[email protected]>

* fix: add fallback logic to use the distributed sampler

  When we use the multipack sampler, it requires the dataset to have a certain shape relative to the number of GPUs so that all of the samples can be distributed across the different nodes. When that requirement is not met, the train loader becomes empty, which prevents us from being able to train. This commit resolves the issue by falling back to the distributed sampler when multipack fails.

  Signed-off-by: Oleg S <[email protected]>

---------

Signed-off-by: Oleg S <[email protected]>
1 parent b1a54c1 commit 9e2ac74

File tree

2 files changed: +40 -1 lines changed


src/instructlab/training/data_process.py

Lines changed: 16 additions & 1 deletion
```diff
@@ -204,7 +204,18 @@ def main(args: DataProcessArgs):
         {"additional_special_tokens": ["<|pretrain|>", "<|/pretrain|>"]}
     )
 
-    data = load_dataset("json", data_files=args.data_path, split="train")
+    try:
+        data = load_dataset("json", data_files=args.data_path, split="train")
+    except:
+        # pylint: disable=raise-missing-from,broad-exception-raised
+        raise Exception(
+            "Malformed or missing data, please ensure that your dataset is not empty and correctly formatted"
+        )
+
+    if data.num_rows == 0:
+        raise ValueError(
+            "The provided dataset is empty, please make sure that your dataset contains samples and try again."
+        )
 
     print(f"\033[92mtokenizing the dataset with {args.model_path} tokenizer...\033[0m")
     data_with_input_ids = data.map(
@@ -230,6 +241,10 @@ def main(args: DataProcessArgs):
         f"\033[36mat {args.max_seq_len} max sequence length, the number of samples to be dropped is {num_dropped_samples}\033[0m"
     )
     print(f"\033[36m({((num_dropped_samples / len(lens)) * 100):.2f}% of total)\033[0m")
+    if num_dropped_samples == len(data):
+        raise RuntimeError(
+            f"Dataset does not contain any samples containing less than {args.max_seq_len=} tokens.\nPlease consider increasing your `max_seq_len` value, or adding more samples."
+        )
 
     lowest_10_percent = np.quantile(lens, (0 + np.arange(11)) / 100.0)
     for i, q in enumerate(lowest_10_percent):
```
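
For reference, a self-contained sketch of the reasoning behind the all-samples-dropped guard, under the assumption that samples whose token count exceeds `max_seq_len` are the ones dropped (`lens` stands in for the per-sample token counts computed during tokenization; the helper name is illustrative, not project code):

```python
# Illustrative only; mirrors the guard above on a toy list of token counts.
import numpy as np

def check_samples_fit(lens: list[int], max_seq_len: int) -> None:
    lens = np.asarray(lens)
    # assumption: a sample is dropped when its token count exceeds max_seq_len
    num_dropped = int((lens > max_seq_len).sum())
    print(f"{num_dropped} of {len(lens)} samples "
          f"({num_dropped / len(lens) * 100:.2f}%) exceed {max_seq_len} tokens")
    if num_dropped == len(lens):
        raise RuntimeError(
            f"No samples fit within {max_seq_len} tokens; increase "
            "`max_seq_len` or add shorter samples."
        )

check_samples_fit([120, 340, 9_000], max_seq_len=4_096)  # 1 dropped, no error
```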

src/instructlab/training/main_ds.py

Lines changed: 24 additions & 0 deletions
```diff
@@ -555,6 +555,25 @@ def main(args):
         sampler=args.sampler,
         seed=args.seed,
     )
+    if len(train_loader) == 0:
+        # this happens sometimes when we have more GPUs than data to process. In this case
+        # we should either alert the user to switch samplers, or do it automatically and
+        # warn them about it happening
+        print(
+            "\033[93mThe dataset is too small for multipack to distribute all of the samples across GPUs. Falling back to the distributed sampler!\033[0m"
+        )
+        args.sampler = "distributed"
+        train_loader = setup_dataloader(
+            dataset,
+            tokenizer.pad_token_id,
+            num_workers=8,
+            is_granite=args.is_granite,
+            max_batch_len=args.max_batch_len,
+            packing_max_batch_len=packing_max_batch_len,
+            samples_per_gpu=args.samples_per_gpu,
+            sampler=args.sampler,
+            seed=args.seed,
+        )
 
     if args.local_rank == 0:
         metric_logger.log_sync(
```
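
For context on why the fallback works, here is a minimal standalone illustration (not project code) of the plain distributed sampler's behavior on a tiny dataset: `torch.utils.data.distributed.DistributedSampler` pads by repeating indices, so every rank still receives at least one sample even when there are more ranks than samples, whereas a packing sampler can legitimately come up empty.

```python
# Standalone illustration (not project code): DistributedSampler pads by
# repeating indices, so each of the 8 simulated ranks still gets one sample
# even though the dataset only holds 3.
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

tiny = TensorDataset(torch.arange(3))  # only 3 samples
sampler = DistributedSampler(tiny, num_replicas=8, rank=0, shuffle=False)
loader = DataLoader(tiny, batch_size=1, sampler=sampler)
print(len(loader))  # 1 -- never empty, unlike a packing sampler that is short on data
```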
```diff
@@ -585,6 +604,11 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
     """
     Wrapper around the main training job that calls torchrun.
     """
+    # early validation logic here
+    if train_args.max_batch_len < train_args.max_seq_len:
+        raise ValueError(
+            f"the `max_batch_len` cannot be less than `max_seq_len`: {train_args.max_batch_len=} < {train_args.max_seq_len=}"
+        )
 
     # process the training data
     if not os.path.exists(train_args.data_output_dir):
```
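
And a minimal sketch of the reasoning behind the new pre-training check, under the assumption (consistent with the commit message) that `max_batch_len` is the per-batch token budget used for packing: a single sample may be up to `max_seq_len` tokens long, so it has to fit inside one batch or no batch can ever be generated. The helper name below is illustrative, not part of the codebase.

```python
# Minimal sketch of the invariant checked above.
def validate_lengths(max_batch_len: int, max_seq_len: int) -> None:
    # a packed batch holds at most max_batch_len tokens, so the longest
    # allowed sample must fit into a single batch
    if max_batch_len < max_seq_len:
        raise ValueError(
            f"`max_batch_len` ({max_batch_len}) cannot be less than "
            f"`max_seq_len` ({max_seq_len})"
        )

validate_lengths(max_batch_len=60_000, max_seq_len=4_096)    # passes
# validate_lengths(max_batch_len=2_048, max_seq_len=4_096)   # raises ValueError
```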
