Skip to content

Fix BucketBatchSampler cache alignment in DreamBooth scripts#13353

Open
azolotenkov wants to merge 2 commits intohuggingface:mainfrom
azolotenkov:fix-bucket-batch-sampler-cache-alignment
Open

Fix BucketBatchSampler cache alignment in DreamBooth scripts#13353
azolotenkov wants to merge 2 commits intohuggingface:mainfrom
azolotenkov:fix-bucket-batch-sampler-cache-alignment

Conversation

@azolotenkov
Copy link
Copy Markdown
Contributor

@azolotenkov azolotenkov commented Mar 27, 2026

What does this PR do?

This PR fixes a bug where BucketBatchSampler reshuffled precomputed batches on each __iter__() call.

DreamBooth training scripts precompute latents and/or prompt embeddings and later consume them by dataloader step index, so changing batch order between the caching pass and the training pass can misalign the cached tensors with the current batch.

This PR removes the per-iteration random.shuffle(self.batches) call from BucketBatchSampler.__iter__() and instead shuffles the precomputed batch list once at sampler construction time, keeping the batch order fixed across epochs.

A future rework may still be needed to support true epoch-wise reshuffling, since batch membership is currently fixed once at sampler construction time.

Applied to:

  • examples/dreambooth/train_dreambooth_lora_flux2.py
  • examples/dreambooth/train_dreambooth_lora_flux2_img2img.py
  • examples/dreambooth/train_dreambooth_lora_flux2_klein.py
  • examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py
  • examples/dreambooth/train_dreambooth_lora_z_image.py

Before submitting

Who can review?

Training examples: @sayakpaul

Copilot AI review requested due to automatic review settings March 27, 2026 11:21
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Fixes cache misalignment in DreamBooth training examples by making BucketBatchSampler yield batches in a stable, precomputed order across __iter__() calls, preventing reshuffles that break step-indexed latent/prompt-embedding caches.

Changes:

  • Removed random.shuffle(self.batches) from BucketBatchSampler.__iter__() in multiple DreamBooth example scripts.
  • Added an explanatory comment describing why batch order must remain stable for cache alignment.

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 5 comments.

Show a summary per file
File Description
examples/dreambooth/train_dreambooth_lora_flux2.py Keep precomputed batch order stable in BucketBatchSampler.__iter__() to avoid cache misalignment.
examples/dreambooth/train_dreambooth_lora_flux2_img2img.py Same sampler iteration stabilization for cache alignment.
examples/dreambooth/train_dreambooth_lora_flux2_klein.py Same sampler iteration stabilization for cache alignment.
examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py Same sampler iteration stabilization for cache alignment.
examples/dreambooth/train_dreambooth_lora_z_image.py Same sampler iteration stabilization for cache alignment.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

@azolotenkov
Copy link
Copy Markdown
Contributor Author

Updated based on review: the sampler now shuffles precomputed batches once at construction time, while keeping iteration order fixed across epochs for cache alignment.

@sayakpaul sayakpaul requested a review from linoytsaban March 27, 2026 11:58
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants