Commit ee76c17

Merge pull request #52 from ChanderG/bugfix-large-datapoints
bugfix: remove large samples from the multi pack batch sampler
2 parents 01cbcd7 + 72fde51 commit ee76c17

1 file changed, +5 -0 lines changed

src/instructlab/training/multipack_sampler.py

Lines changed: 5 additions & 0 deletions
@@ -437,6 +437,11 @@ def generate_batches(self, set_stats=False):
             len(self.lengths)
         )
 
+        # remove indices where the entries are longer than batch max length
+        indices = indices[self.lengths[indices] <= self.batch_max_length]
+        if len(indices) < len(self.lengths):
+            print(f"\033[33mDropping {len(self.lengths) - len(indices)} samples longer than batch_max_length. Ensure that the right max_batch_length is used during data processing.\033[0m")
+
         lengths = self.lengths[indices]
         lengths_cumsum = np.cumsum(lengths)
