
Bugfix: batch_size_warmup_scheduler was taking too long #205


Open · wants to merge 1 commit into base: main

Conversation


@onurgu onurgu commented Feb 24, 2025

BatchSizeWarmupScheduler was taking too long, or was effectively unusable, for real-world max_batch_size values.

When running the training script as follows:

conda run -n bert24 composer main.py yamls/modernbert/modernbert-base-pretrain.yaml

the script produced no output for a very long time, so I started reading the code. I saw that it used the sum(range(x, y)) idiom to sum the values over a range, which is inefficient for large y and effectively never finishes when y is on the order of 50B.
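For illustration, here is a minimal sketch of the difference; the numbers and variable names are hypothetical, not taken from the repository:

```python
# sum(range(x, y)) iterates over every integer in [x, y), which is
# O(y - x) work: fine for small ranges, but effectively never finishes
# when y is on the order of 50 billion.
x, y = 16, 50_000_000_000

# slow_total = sum(range(x, y))  # do not run: ~50B additions

# The arithmetic-series closed form gives the identical result in O(1):
n = y - x                        # number of terms in range(x, y)
fast_total = n * (x + (y - 1)) // 2
print(fast_total)
```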

Changes

Simplify BatchSizeWarmupScheduler Implementation

Summary

This PR simplifies the batch size warmup scheduling logic by replacing the step-based threshold calculation with a more straightforward token-based approach. The new implementation provides a more intuitive and mathematically precise way to handle batch size warmup during training.

Changes

  • Replaced _calculate_step_thresholds() with _calculate_tokens_per_batch_size()
  • Changed the scheduler to use token counts instead of steps for determining batch sizes
  • Simplified the batch size calculation using a direct mathematical formula
  • Updated method signatures to better reflect the token-based approach (current_step → current_token_count)

Technical Details

The new implementation (a sketch follows the list):

  1. Calculates total batch sizes using the arithmetic sequence sum formula: (n(a₁ + aₙ))/2
  2. Determines tokens per batch size unit by dividing warmup tokens by total batch sizes
  3. Uses integer division to determine how many batch size increments to apply
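A minimal sketch of that logic (the class layout, signatures, and the assumption that the ramp steps through every integer batch size from min to max are mine for illustration and may differ from the actual code in this PR):

```python
class BatchSizeWarmupScheduler:
    """Sketch: ramp batch size from min to max over a fixed warmup token budget."""

    def __init__(self, min_batch_size: int, max_batch_size: int, warmup_tokens: int):
        self.min_batch_size = min_batch_size
        self.max_batch_size = max_batch_size
        # Number of distinct batch sizes in the ramp: min, min+1, ..., max.
        n = max_batch_size - min_batch_size + 1
        # Arithmetic-series sum (n * (a1 + an)) / 2, computed in O(1)
        # instead of iterating with sum(range(min, max + 1)).
        total_batch_sizes = n * (min_batch_size + max_batch_size) // 2
        # Tokens spent per unit of batch size during warmup.
        self._tokens_per_batch_size = warmup_tokens / total_batch_sizes

    def __call__(self, current_token_count: int) -> int:
        # Integer division: how many +1 increments the token count has earned.
        increments = int(current_token_count // self._tokens_per_batch_size)
        return min(self.min_batch_size + increments, self.max_batch_size)
```

With this formulation the scheduler does O(1) work per query and stores no threshold array, which is where the startup-time and memory savings described below come from.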

Benefits

  • More precise control over batch size progression
  • Simpler, more maintainable code with fewer loops and conditionals
  • Direct relationship between token count and batch size
  • Reduced memory footprint by eliminating the need to store threshold arrays

Discussions

If any, please include references to the relevant issues/previous PR/discord discussions around these changes.

Tests

  • Is the new feature tested? (Not always necessary for all changes -- just adding to the checklist to keep track)
  • Have you run all the tests?
  • Do the tests all pass?
  • If not, have you included an explanation of which tests this PR breaks and/or why (below this checklist)?


jihobak commented Mar 19, 2025

I have a question regarding your statement that using the 'sum(range(x, y))' idiom to sum values in a range is inefficient for large y – to the point of being impractical when y is around 50B, for example.

My understanding is that x and y are derived from batch size variables and are not related to the number of tokens. Could you clarify why you consider a scenario where y equals 50B?
