basic torchao fp8 mixed precision training #2926
Conversation
## Walkthrough
This change introduces support and documentation for FP8 mixed precision training in the Axolotl framework. It adds new documentation, configuration options, and internal logic to enable and validate FP8 and related FSDP options. New end-to-end tests are included for both single-GPU and multi-GPU (FSDP2) FP8 training scenarios, along with an example configuration.
## Changes
| Files/Paths | Change Summary |
|-----------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------|
| docs/mixed_precision.qmd, docs/gradient_accumulation.qmd, _quarto.yml | Added documentation for mixed precision and gradient accumulation; updated sidebar configuration to include new docs. |
| examples/llama-3/3b-fp8-fsdp2.yaml | Added example YAML config for LLaMA-3 3B FP8 training with FSDP2 and Liger plugin. |
| src/axolotl/core/trainers/base.py | Updated `additional_accelerator_args` method to support FP8 and FSDP float8 all-gather; changed signature and logic. |
| src/axolotl/monkeypatch/trainer_accelerator_args.py | Modified monkeypatch function to accept and propagate the FSDP float8 all-gather flag in patched code. |
| src/axolotl/loaders/patch_manager.py | Passed new config flag to FP8 patching function within patch manager. |
| src/axolotl/utils/schemas/config.py | Added `fp8_enable_fsdp_float8_all_gather` optional boolean to input config schema with descriptions. |
| src/axolotl/utils/schemas/validation.py | Added validator to warn if FP8 is enabled without torch.compile or with FSDP activation checkpointing; enforces FSDP2 for all-gather. |
| tests/e2e/integrations/test_fp8.py | Added single-GPU FP8 mixed precision smoke test. |
| tests/e2e/multigpu/test_fp8_fsdp2.py | Added multi-GPU (FSDP2) FP8 mixed precision smoke tests, including training success verification. |
| tests/monkeypatch/test_trainer_accelerator_args.py | Added unit test to verify patch compatibility for FP8 accelerator args monkeypatch. |
## Estimated code review effort
3 (~45 minutes)
## Possibly related PRs
- axolotl-ai-cloud/axolotl#2760: Related to FSDP2 infrastructure improvements and FP8 integration with FSDP2 float8 all-gather features.
- axolotl-ai-cloud/axolotl#2680: Related refactor of loaders and patch management that the FP8 patch builds upon.
## Suggested labels
`scheduled_release`, `ready to merge`
## Suggested reviewers
- winglian
📖 Documentation Preview: https://687fee83f99faab581f678cc--resonant-treacle-0fd729.netlify.app (deployed on Netlify from commit 9f8cc4d)
Quick notes from today: testing llama 3.1 8B, for large enough sequence lengths we see speed improvements (~10-30% more iter/s) using torchao fp8 + torch.compile + either SDPA or flex attention on a single GPU, across several different settings of sequence length and batch size. Flash attention cannot be torch compiled and results in slower overall training times when used in conjunction with torchao fp8. Multi-GPU is another story: FSDP2 + torchao fp8 + torch.compile results in significantly (~25%) fewer iter/s for 2x H100s, while DDP just fails with this error.
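To make the iter/s comparison concrete, here is a hedged micro-benchmark sketch (not the actual llama 3.1 8B runs above): it times a toy bf16 MLP against the same model converted with torchao fp8, both under `torch.compile`. It assumes a CUDA GPU with fp8 support (H100-class) and torchao installed; the model, shapes, and step counts are purely illustrative.

```python
# Hedged micro-benchmark sketch: compare iter/s of a compiled bf16 MLP vs.
# the same MLP converted to fp8 with torchao. Requires a GPU with fp8
# support (H100-class); shapes and step counts are illustrative only.
import time

import torch
import torch.nn as nn
from torchao.float8 import convert_to_float8_training

def make_mlp(dim: int = 4096, depth: int = 8) -> nn.Module:
    layers = [nn.Linear(dim, dim, bias=False) for _ in range(depth)]
    return nn.Sequential(*layers).to("cuda", dtype=torch.bfloat16)

def bench(model: nn.Module, steps: int = 50, bs: int = 8, dim: int = 4096) -> float:
    opt = torch.optim.AdamW(model.parameters())
    x = torch.randn(bs, dim, device="cuda", dtype=torch.bfloat16)
    for _ in range(5):  # warmup; also triggers torch.compile compilation
        model(x).sum().backward()
        opt.step()
        opt.zero_grad()
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(steps):
        model(x).sum().backward()
        opt.step()
        opt.zero_grad()
    torch.cuda.synchronize()
    return steps / (time.perf_counter() - start)  # iterations per second

bf16_ips = bench(torch.compile(make_mlp()))

fp8_model = make_mlp()
convert_to_float8_training(fp8_model)  # swaps nn.Linear -> Float8Linear in place
fp8_ips = bench(torch.compile(fp8_model))

print(f"bf16: {bf16_ips:.2f} iter/s, fp8: {fp8_ips:.2f} iter/s")
```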
So there is this very helpful chart in the torchao docs that should tell you a lot about where to expect what kind of speedup. Regarding distributed, I haven't really tested more configurations, but on a single node with a large enough seq-len (8192) you can see quite a nice improvement. I have a script to reproduce that in Accelerate: here
Thx! Yeah, that chart is nice. I'm seeing speedups in line with what they state, for single GPU at least. Something appears to be bugged on our side for DDP and, to a lesser extent, FSDP2. Appreciate the script!
What if you disable checkpointing/offloading?
Same result, unfortunately.
I'm dumb: disabling FSDP gradient checkpointing allows fp8 training to go ~10-15% faster than bf16 for the models/hyperparams I tested. @SalmanMohammadi pointed out that torchtune's FFT uses gradient checkpointing by default, and their benchmarks show a ~15% improvement, so something is fishy here. Forgive the long run names; I thought I'd place all the relevant config info in them to tell the full story:
(I can't get wandb to not cut off the green run's name, but that's fp8 + fsdp2 + no activation checkpointing)
Actionable comments posted: 6
🧹 Nitpick comments (1)
examples/llama-3/3b-fp8-fsdp2.yaml (1)
51-53: Enable gradient scaling unless confirmed unnecessary. FP8 + FSDP all-gather occasionally under-flows gradients. Consider adding `gradient_scaling: dynamic` (or a similar hook) to stay safe.
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (10)
- `_quarto.yml` (1 hunks)
- `docs/mixed_precision.qmd` (1 hunks)
- `examples/llama-3/3b-fp8-fsdp2.yaml` (1 hunks)
- `src/axolotl/core/trainers/base.py` (2 hunks)
- `src/axolotl/loaders/patch_manager.py` (1 hunks)
- `src/axolotl/monkeypatch/trainer_accelerator_args.py` (3 hunks)
- `src/axolotl/utils/schemas/config.py` (1 hunks)
- `src/axolotl/utils/schemas/validation.py` (1 hunks)
- `tests/e2e/integrations/test_fp8.py` (1 hunks)
- `tests/e2e/multigpu/test_fp8_fsdp2.py` (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (9)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: PyTest (3.11, 2.7.1)
- GitHub Check: pre-commit
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: preview
- GitHub Check: pre-commit
🔇 Additional comments (11)
_quarto.yml (1)
271-272: LGTM! Good documentation structure. The addition of mixed precision and gradient accumulation guides to the "Core Concepts" section aligns well with the FP8 training feature introduction and provides users with essential documentation.
src/axolotl/loaders/patch_manager.py (1)
157-159: LGTM! Configuration propagation is correct. The modification properly passes the new FP8 FSDP configuration to the patching function. The change is well-gated by the existing `if self.cfg.fp8:` condition and maintains consistency with the broader FP8 training infrastructure.
src/axolotl/utils/schemas/validation.py (1)
363-381: LGTM: Well-implemented FP8 configuration validation. The validation logic effectively guides users toward optimal FP8 configurations by warning about:
- Missing `torch_compile`, which is essential for FP8 performance gains (as noted in the PR objectives)
- Potential performance degradation with FSDP activation checkpointing
The implementation follows established patterns in the codebase and provides clear, actionable guidance; a sketch of this kind of validator follows.
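A minimal sketch of that style of validator, assuming pydantic v2 and simplified field names (the real implementation in `src/axolotl/utils/schemas/validation.py` covers more cases):

```python
# Hedged sketch of the validation behavior described above, assuming
# pydantic v2 and simplified field names; the real validator in
# src/axolotl/utils/schemas/validation.py handles more configuration keys.
import logging

from pydantic import BaseModel, model_validator

LOG = logging.getLogger(__name__)

class FP8Settings(BaseModel):
    fp8: bool = False
    torch_compile: bool = False
    fp8_enable_fsdp_float8_all_gather: bool = False
    fsdp_activation_checkpointing: bool = False

    @model_validator(mode="after")
    def warn_on_suboptimal_fp8(self):
        if self.fp8 and not self.torch_compile:
            LOG.warning("fp8 without torch_compile usually yields little or no speedup")
        if self.fp8 and self.fsdp_activation_checkpointing:
            LOG.warning("fp8 with FSDP activation checkpointing may train slower than bf16")
        return self
```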
src/axolotl/core/trainers/base.py (2)
10-10: LGTM: Proper type annotations and method signature. The addition of explicit type annotations and a return type improves code clarity and type safety. The default values for both FP8 parameters are appropriate.
Also applies to: 525-528
532-542: LGTM: Enhanced FP8 configuration with proper parameterization. The updated implementation provides better control over FP8 behavior:
- Uses `Float8LinearConfig` for fine-grained configuration
- Properly passes the `enable_fsdp_float8_all_gather` parameter
- Includes helpful documentation about the tensorwise scaling strategy
- Maintains proper conditional imports
This is a significant improvement over a basic configuration approach; a sketch of this wiring is shown below.
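For illustration only, a hedged sketch of how the flag might be threaded into torchao (the helper name and module filter are assumptions, not the code in `src/axolotl/core/trainers/base.py`):

```python
# Hedged sketch, not Axolotl's exact code: build a Float8LinearConfig with
# the FSDP float8 all-gather flag and convert the model's Linear layers.
# The default Float8LinearConfig uses dynamic tensorwise scaling; the
# module filter below is an illustrative assumption.
import torch.nn as nn

def apply_fp8_training(model: nn.Module, enable_fsdp_float8_all_gather: bool = False) -> nn.Module:
    # conditional import keeps torchao an optional dependency
    from torchao.float8 import Float8LinearConfig, convert_to_float8_training

    config = Float8LinearConfig(
        enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
    )

    def module_filter_fn(module: nn.Module, fqn: str) -> bool:
        # skip the LM head, whose dims are often not fp8-friendly
        # (an assumption, not necessarily Axolotl's rule)
        return "lm_head" not in fqn

    convert_to_float8_training(model, config=config, module_filter_fn=module_filter_fn)
    return model
```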
tests/e2e/multigpu/test_fp8_fsdp2.py (3)
21-48: LGTM: Comprehensive training success validation. The helper function provides thorough validation by checking:
- Model file artifacts (.bin/.safetensors)
- Checkpoint directory creation
- Training loss values from TensorBoard logs (ensuring no NaN losses)
This multi-layered approach effectively verifies that FP8 training completed successfully; a sketch of such a check appears below.
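A hedged sketch of that kind of helper (the TensorBoard tag name and directory layout are assumptions; the real checks live in `tests/e2e/multigpu/test_fp8_fsdp2.py`):

```python
# Hedged sketch of a training-success helper: check for model artifacts,
# checkpoints, and finite TensorBoard losses. The "train/train_loss" tag and
# directory layout are assumptions, not necessarily what the test uses.
import math
from pathlib import Path

from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

def verify_training_success(output_dir: Path) -> None:
    # 1. model artifacts were written
    artifacts = list(output_dir.glob("*.safetensors")) + list(output_dir.glob("*.bin"))
    assert artifacts, f"no model files found in {output_dir}"

    # 2. at least one checkpoint directory exists
    assert list(output_dir.glob("checkpoint-*")), "no checkpoints written"

    # 3. logged losses are finite (no NaN/inf from fp8 under/overflow)
    for event_file in output_dir.rglob("events.out.tfevents.*"):
        acc = EventAccumulator(str(event_file.parent))
        acc.Reload()
        if "train/train_loss" in acc.Tags().get("scalars", []):
            losses = [s.value for s in acc.Scalars("train/train_loss")]
            assert losses and all(math.isfinite(v) for v in losses), "non-finite loss logged"
```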
53-119: LGTM: Well-configured FP8+FSDP2 smoke test. The test properly configures:
- FP8 mixed precision with FSDP float8 all-gather enabled
- FSDP2 with appropriate settings for the SmolLM2 architecture
- Minimal training steps for efficient smoke testing
- Proper distributed execution with 2 GPUs
The configuration aligns with the PR objectives for testing FP8 training with FSDP optimizations; a rough config sketch follows.
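Roughly, the core of such a config might look like the following (not the test's literal contents; only `fp8`, `fp8_enable_fsdp_float8_all_gather`, and `torch_compile` come from this PR, and the model/dataset IDs and remaining keys are illustrative assumptions):

```python
# Hedged sketch of a minimal FP8 + FSDP2 smoke-test config. Only fp8,
# fp8_enable_fsdp_float8_all_gather, and torch_compile come from this PR;
# the model/dataset IDs and remaining keys are illustrative assumptions.
from axolotl.utils.dict import DictDefault

cfg = DictDefault(
    base_model="HuggingFaceTB/SmolLM2-135M",  # small model keeps the smoke test cheap
    fp8=True,
    fp8_enable_fsdp_float8_all_gather=True,
    torch_compile=True,  # recommended alongside fp8 by the new validator
    sequence_len=2048,
    micro_batch_size=1,
    gradient_accumulation_steps=1,
    max_steps=5,
    learning_rate=1e-4,
    optimizer="adamw_torch",
    datasets=[{"path": "mhenrichsen/alpaca_2k_test", "type": "alpaca"}],
    output_dir="./outputs/fp8-fsdp2-smoke",
    # FSDP2 sharding settings are omitted here; see the example YAML in this
    # PR (examples/llama-3/3b-fp8-fsdp2.yaml) for a complete distributed config.
)
```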
121-193: LGTM: Comprehensive FP8+FSDP2+LoRA compatibility test. This test validates the complex combination of:
- FP8 mixed precision training
- FSDP2 distributed training
- LoRA adapters
The LoRA configuration is appropriate for testing, and maintaining the same FP8/FSDP2 settings ensures we're testing the interaction between all three technologies.
tests/e2e/integrations/test_fp8.py (1)
18-61: LGTM: Well-structured single-GPU FP8 smoke test. The test provides good coverage for single-GPU FP8 training:
- Enables both `fp8` and `torch_compile`, as recommended by the validation logic
- Uses minimal configuration for efficient smoke testing
- Follows the proper Axolotl training pipeline (validate → normalize → load datasets → train)
- Verifies successful completion through output file checking
This complements the multi-GPU tests nicely by covering the single-GPU use case; a sketch of that pipeline shape is below.
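A hedged sketch of that pipeline shape; the import paths and signatures below are assumptions inferred from the step names, so consult `tests/e2e/integrations/test_fp8.py` for the exact calls:

```python
# Hedged sketch of the validate -> normalize -> load datasets -> train flow.
# Import paths and signatures are assumptions inferred from the step names
# above and may not match the current Axolotl API exactly.
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault

raw_cfg = DictDefault(
    base_model="HuggingFaceTB/SmolLM2-135M",  # placeholder small model
    fp8=True,
    torch_compile=True,  # pairs with fp8 per the validation logic
    sequence_len=1024,
    micro_batch_size=1,
    max_steps=5,
    datasets=[{"path": "mhenrichsen/alpaca_2k_test", "type": "alpaca"}],
    output_dir="./outputs/fp8-single-gpu-smoke",
)

cfg = validate_config(raw_cfg)  # schema checks plus the new FP8 warnings
normalize_config(cfg)           # fill in derived fields (world size, dtypes, ...)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
```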
docs/mixed_precision.qmd (1)
88-93: Good emphasis on compile requirement. Clear warning, nicely highlighted.
examples/llama-3/3b-fp8-fsdp2.yaml (1)
1-1: Base-model ID is valid on HF Hub. The Hugging Face API returns HTTP 200 for `meta-llama/Llama-3.2-3B`, confirming this model exists and can be fetched. No changes needed.
• File: `examples/llama-3/3b-fp8-fsdp2.yaml`, line 1.
Likely an incorrect or invalid review comment.
winglian
left a comment
Can we add a quick test case like `assert check_create_accelerate_code_is_patchable()` so we get a test failure if the patched code changes upstream?
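A hedged sketch of such a guard test (the class and test names are placeholders; the helper import mirrors the monkeypatch module touched in this PR):

```python
# Hedged sketch of such a guard test; class and test names are placeholders,
# and the helper import mirrors the monkeypatch module touched in this PR.
import unittest

from axolotl.monkeypatch.trainer_accelerator_args import (
    check_create_accelerate_code_is_patchable,
)

class TestTrainerAcceleratorArgsPatch(unittest.TestCase):
    def test_patch_target_unchanged(self):
        # fails loudly in CI if the upstream accelerate-args code we patch changes
        self.assertTrue(check_create_accelerate_code_is_patchable())

if __name__ == "__main__":
    unittest.main()
```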
winglian
left a comment
Just a minor test addition and a nit about informing the next step on validation. Good to go otherwise once those are addressed.
Actionable comments posted: 0
🧹 Nitpick comments (1)
tests/monkeypatch/test_trainer_accelerator_args.py (1)
17-22: Use unittest assertion methods instead of raw assert. While the test logic is sound, consider using `self.assertTrue()` instead of a raw `assert` for better error reporting and to avoid issues when Python optimizations are enabled. Suggested change: replace `assert check_create_accelerate_code_is_patchable()` with `self.assertTrue(check_create_accelerate_code_is_patchable())`.
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- `src/axolotl/utils/schemas/validation.py` (1 hunks)
- `tests/monkeypatch/test_trainer_accelerator_args.py` (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- src/axolotl/utils/schemas/validation.py
🧰 Additional context used
🧬 Code Graph Analysis (1)
tests/monkeypatch/test_trainer_accelerator_args.py (1)
src/axolotl/monkeypatch/trainer_accelerator_args.py (1)
- `check_create_accelerate_code_is_patchable` (35-38)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (12)
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: PyTest (3.11, 2.7.1)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: pre-commit
- GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.0, vllm, 2, true)
- GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.6.0, 2, true)
- GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.1, 2, true)
- GitHub Check: pre-commit
- GitHub Check: preview
🔇 Additional comments (3)
tests/monkeypatch/test_trainer_accelerator_args.py (3)
1-10: Well-structured imports and documentation. The file header and import organization follow best practices with clear documentation and appropriate import specificity.
12-16: Proper test class structure. The test class follows unittest conventions with clear documentation and appropriate naming.
25-27: Standard test execution pattern. The main execution block follows Python testing conventions appropriately.
Going to merge and hope to address a few follow-ups in additional PRs.

Description
This PR updates the existing logic gated by the `fp8` config to work correctly with the current state of accelerate torchao fp8 support. It also adds the `fp8_enable_fsdp_float8_all_gather` config, which is passed through to the `Float8LinearConfig` and can speed up runs with FSDP enabled by up to ~50%, depending on the configuration.
Motivation and Context
`torchao.float8.convert_to_float8_training` converts `Linear` layers in the given model to `Float8Linear` layers with a cast config (e.g., `cast_configs=i:dyn_ten_e4m3,w:dyn_ten_e4m3,go:dyn_ten_e5m2`). Model weights remain in the dtype specified in the user's config (e.g., bfloat16), but ultimately get converted to fp8 at matmul time (forward and backward); see here. This kind of fp8 training should be faster than bfloat16 training with `torch.compile` applied.
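A hedged illustration of this conversion on a toy module (it assumes torchao is installed; exact export paths may differ by version):

```python
# Hedged illustration of the conversion described above: nn.Linear modules
# become torchao Float8Linear, while their weights stay in the original
# dtype (bf16 here); the fp8 casts happen inside the matmul at runtime.
import torch
import torch.nn as nn
from torchao.float8 import Float8Linear, convert_to_float8_training

model = nn.Sequential(
    nn.Linear(256, 256, bias=False),
    nn.Linear(256, 256, bias=False),
).to(dtype=torch.bfloat16)

convert_to_float8_training(model)

for name, module in model.named_modules():
    if isinstance(module, Float8Linear):
        # weights are still bf16; casts to e4m3/e5m2 happen per matmul
        print(name, type(module).__name__, module.weight.dtype)
```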
`torch.compile` currently breaks (multi-GPU only + DDP) with:
How has this been tested?
Not 100% sure, but some of the marked tested settings seem flaky depending on the run / hardware?
TODO:
Screenshots (if appropriate)
Types of changes
Social Handles (Optional)
Summary by CodeRabbit