Conversation

@djsaunde
Collaborator

@djsaunde djsaunde commented Jul 15, 2025

Description

This PR updates the existing logic gated by the fp8 config to work correctly with the current state of accelerate's torchao fp8 support. It also adds the fp8_enable_fsdp_float8_all_gather config option, which is passed through to the Float8LinearConfig and can speed up runs with FSDP enabled by up to ~50%, depending on the configuration.
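
For reference, a minimal config sketch using the new options; everything besides the fp8-related keys is illustrative and borrowed from the FSDP2 example added in this PR:

```yaml
# Sketch only: fp8, fp8_enable_fsdp_float8_all_gather, and torch_compile are the
# options discussed in this PR; the remaining keys are illustrative.
base_model: meta-llama/Llama-3.2-3B

fp8: true
fp8_enable_fsdp_float8_all_gather: true  # only meaningful with FSDP2
torch_compile: true  # fp8 speedups generally require torch.compile

# ... FSDP2 and dataset settings as in examples/llama-3/3b-fp8-fsdp2.yaml
```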

Motivation and Context

torchao.float8.convert_to_float8_training converts Linear layers in the given model to Float8Linear layers with a cast config (e.g., cast_configs="i:dyn_ten_e4m3,w:dyn_ten_e4m3,go:dyn_ten_e5m2"). Model weights remain in the dtype specified in the user's config (e.g., bfloat16), but are cast to fp8 at matmul time (forward and backward); see here.
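
As a standalone sketch of what that conversion step does (not the actual axolotl/accelerate call site):

```python
# Standalone sketch of torchao's float8 conversion; not the axolotl call site.
import torch
from torchao.float8 import Float8LinearConfig, convert_to_float8_training

# Toy model; weights stay in bf16 after conversion.
model = torch.nn.Sequential(
    torch.nn.Linear(4096, 4096, bias=False),
    torch.nn.Linear(4096, 4096, bias=False),
).to(torch.bfloat16).to("cuda")

# enable_fsdp_float8_all_gather is the flag surfaced by the new
# fp8_enable_fsdp_float8_all_gather config option.
config = Float8LinearConfig(enable_fsdp_float8_all_gather=True)

# Swaps nn.Linear -> Float8Linear in place; tensors are cast to fp8 only at
# matmul time in the forward and backward passes.
convert_to_float8_training(model, config=config)
```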

With torch.compile applied, this kind of fp8 training should be faster than bfloat16 training. torch.compile currently breaks in multi-GPU (DDP) runs with:

[rank0]: RuntimeError: 
[rank0]: During the backward, we encountered a tensor subclass where we guessed its
[rank0]: metadata incorrectly.

[rank0]: Expected metadata: {'_orig_dtype': torch.bfloat16, '_linear_mm_config': LinearMMConfig(output=ScaledMMConfig(emulate=False, use_fast_accum=True, fp8_output=False, pad_inner_dim=False), grad_input=ScaledMMConfig(emulate=False, use_fast_accum=False, fp8_output=False, pad_inner_dim=False), grad_weight=ScaledMMConfig(emulate=False, use_fast_accum=False, fp8_output=False, pad_inner_dim=False)), '_gemm_input_role': <GemmInputRole.WEIGHT: 'weight'>, '_axiswise_dim': None}, expected type: <class 'torchao.float8.float8_tensor.Float8Tensor'>

How has this been tested?

  • Manual testing
    • Single GPU
    • Multi GPU (DDP)
    • Multi GPU (FSDP)
    • Single GPU + torch.compile
    • Multi GPU (DDP) + torch.compile
    • Multi GPU (FSDP) + torch.compile
  • Add smoke tests

Not 100% sure, but some of the settings marked as tested above seem flaky depending on the run and hardware.

TODO:

  • Fix torch.compile metadata error for multi-GPU settings
  • Enumerate all other common features this is compatible with
  • Understand performance gain / degradation depending on features / hyperparams
  • Add docs

Screenshots (if appropriate)

Types of changes

Social Handles (Optional)

Summary by CodeRabbit

  • New Features
    • Added comprehensive documentation on mixed precision training, including FP16, BF16, and experimental FP8 formats.
    • Introduced support for FP8 mixed precision training with new configuration options, including FSDP float8 all-gather.
    • Provided an example configuration for training LLaMA-3 3B with FP8 and FSDP2.
  • Bug Fixes
    • Added validation and warnings for optimal FP8 training setup.
  • Tests
    • Implemented new end-to-end tests for FP8 training on both single and multi-GPU setups, including FSDP2 and LoRA configurations.
    • Added unit tests to ensure compatibility of FP8-related monkeypatches with upstream code.
  • Documentation
    • Updated website sidebar to include new mixed precision and gradient accumulation guides.

@djsaunde djsaunde self-assigned this Jul 15, 2025
@coderabbitai
Contributor

coderabbitai bot commented Jul 15, 2025

## Walkthrough

This change introduces support and documentation for FP8 mixed precision training in the Axolotl framework. It adds new documentation, configuration options, and internal logic to enable and validate FP8 and related FSDP options. New end-to-end tests are included for both single-GPU and multi-GPU (FSDP2) FP8 training scenarios, along with an example configuration.

## Changes

| Files/Paths                                                                 | Change Summary                                                                                                             |
|-----------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------|
| docs/mixed_precision.qmd, docs/gradient_accumulation.qmd, _quarto.yml       | Added documentation for mixed precision and gradient accumulation; updated sidebar configuration to include new docs.      |
| examples/llama-3/3b-fp8-fsdp2.yaml                                          | Added example YAML config for LLaMA-3 3B FP8 training with FSDP2 and Liger plugin.                                        |
| src/axolotl/core/trainers/base.py                                           | Updated `additional_accelerator_args` method to support FP8 and FSDP float8 all-gather; changed signature and logic.      |
| src/axolotl/monkeypatch/trainer_accelerator_args.py                         | Modified monkeypatch function to accept and propagate the FSDP float8 all-gather flag in patched code.                    |
| src/axolotl/loaders/patch_manager.py                                        | Passed new config flag to FP8 patching function within patch manager.                                                     |
| src/axolotl/utils/schemas/config.py                                         | Added `fp8_enable_fsdp_float8_all_gather` optional boolean to input config schema with descriptions.                      |
| src/axolotl/utils/schemas/validation.py                                     | Added validator to warn if FP8 is enabled without torch.compile or with FSDP activation checkpointing; enforces FSDP2 for all-gather. |
| tests/e2e/integrations/test_fp8.py                                          | Added single-GPU FP8 mixed precision smoke test.                                                                          |
| tests/e2e/multigpu/test_fp8_fsdp2.py                                        | Added multi-GPU (FSDP2) FP8 mixed precision smoke tests, including training success verification.                         |
| tests/monkeypatch/test_trainer_accelerator_args.py                          | Added unit test to verify patch compatibility for FP8 accelerator args monkeypatch.                                       |

## Estimated code review effort

3 (~45 minutes)

## Possibly related PRs

- axolotl-ai-cloud/axolotl#2760: Related to FSDP2 infrastructure improvements and FP8 integration with FSDP2 float8 all-gather features.
- axolotl-ai-cloud/axolotl#2680: Related refactor of loaders and patch management that the FP8 patch builds upon.

## Suggested labels

`scheduled_release`, `ready to merge`

## Suggested reviewers

- winglian

@github-actions
Contributor

github-actions bot commented Jul 15, 2025

📖 Documentation Preview: https://687fee83f99faab581f678cc--resonant-treacle-0fd729.netlify.app

Deployed on Netlify from commit 9f8cc4d

@codecov

codecov bot commented Jul 15, 2025

Codecov Report

Attention: Patch coverage is 85.00000% with 3 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| src/axolotl/utils/schemas/validation.py | 70.00% | 3 Missing ⚠️ |


@djsaunde
Collaborator Author

djsaunde commented Jul 16, 2025

Quick notes from today:

Testing llama 3.1 8B, for large enough sequence length, we see speed improvements (~10-30% more iter/s) using torchao fp8 + torch.compile + either SDPA or flex attention on a single GPU, across several different settings of sequence length and batch size. Flash attention cannot be torch compiled and results in slower overall training times when used in conjunction with torchao fp8.

Multi-GPU is another story. FSDP2 + torchao fp8 + torch.compile results in significantly (~25%) fewer iter/s for 2x H100s, while DDP just fails with this error.

@S1ro1

S1ro1 commented Jul 16, 2025

> Testing llama 3.1 8B, for large enough sequence length, we see speed improvements (~10-30% more iter/s) using torchao fp8 + torch.compile + either SDPA or flex attention on a single GPU, across several different settings of sequence length and batch size. Flash attention cannot be torch compiled and results in slower overall training times when used in conjunction with torchao fp8.
>
> Multi-GPU is another story. FSDP2 + torchao fp8 + torch.compile results in significantly (~25%) fewer iter/s for 2x H100s, while DDP just fails with this error.

So there is this very helpful chart in torchao docs that should tell a lot about where to expect what speedup. Regarding distributed, I haven't really tested more configurations, but on a single node with large enough seq-len (8192) you can see quite a nice improvement. Have a script to reproduce that in Accelerate: here

@djsaunde
Collaborator Author

> > Testing llama 3.1 8B, for large enough sequence length, we see speed improvements (~10-30% more iter/s) using torchao fp8 + torch.compile + either SDPA or flex attention on a single GPU, across several different settings of sequence length and batch size. Flash attention cannot be torch compiled and results in slower overall training times when used in conjunction with torchao fp8.
> >
> > Multi-GPU is another story. FSDP2 + torchao fp8 + torch.compile results in significantly (~25%) fewer iter/s for 2x H100s, while DDP just fails with this error.
>
> So there is this very helpful chart in torchao docs that should tell a lot about where to expect what speedup. Regarding distributed, I haven't really tested more configurations, but on a single node with large enough seq-len (8192) you can see quite a nice improvement. Have a script to reproduce that in Accelerate: here

Thx! Yeah, that chart is nice. I'm seeing speedups in line with what they state for single GPU at least. Something appears to be bugged on our side for DDP and, to a lesser extent, FSDP2. Appreciate the script!

@winglian
Collaborator

What if you disable checkpointing/offloading?

@djsaunde
Collaborator Author

> What if you disable checkpointing/offloading?

Same result unfortunately.

@djsaunde
Collaborator Author

djsaunde commented Jul 17, 2025

> > What if you disable checkpointing/offloading?
>
> Same result unfortunately.

I'm dumb: disabling FSDP gradient checkpointing allows fp8 training to go ~10-15% faster than bf16 for the models / hyperparams I tested. @SalmanMohammadi pointed out that torchtune's FFT uses gradient checkpointing by default, and their benchmarks still show a ~15% improvement, so something is fishy here.

Forgive the long run names; I thought I'd place all the relevant config info in them to tell the full story:

[W&B chart, 7/17/2025: training throughput comparison for the runs described above]

(I can't get wandb to not cut off the green run's name, but that's fp8 + fsdp2 + no activation checkpointing)

@djsaunde djsaunde marked this pull request as ready for review July 22, 2025 14:17
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 6

🧹 Nitpick comments (1)
examples/llama-3/3b-fp8-fsdp2.yaml (1)

51-53: Enable gradient scaling unless confirmed unnecessary

FP8 + FSDP all-gather occasionally under-flows gradients. Consider adding gradient_scaling: dynamic (or similar hook) to stay safe.

📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between b86a1d4 and 4b246f7.

📒 Files selected for processing (10)
  • _quarto.yml (1 hunks)
  • docs/mixed_precision.qmd (1 hunks)
  • examples/llama-3/3b-fp8-fsdp2.yaml (1 hunks)
  • src/axolotl/core/trainers/base.py (2 hunks)
  • src/axolotl/loaders/patch_manager.py (1 hunks)
  • src/axolotl/monkeypatch/trainer_accelerator_args.py (3 hunks)
  • src/axolotl/utils/schemas/config.py (1 hunks)
  • src/axolotl/utils/schemas/validation.py (1 hunks)
  • tests/e2e/integrations/test_fp8.py (1 hunks)
  • tests/e2e/multigpu/test_fp8_fsdp2.py (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (9)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: pre-commit
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: preview
  • GitHub Check: pre-commit
🔇 Additional comments (11)
_quarto.yml (1)

271-272: LGTM! Good documentation structure.

The addition of mixed precision and gradient accumulation guides to the "Core Concepts" section aligns well with the FP8 training feature introduction and provides users with essential documentation.

src/axolotl/loaders/patch_manager.py (1)

157-159: LGTM! Configuration propagation is correct.

The modification properly passes the new FP8 FSDP configuration to the patching function. The change is well-gated by the existing if self.cfg.fp8: condition and maintains consistency with the broader FP8 training infrastructure.

src/axolotl/utils/schemas/validation.py (1)

363-381: LGTM: Well-implemented FP8 configuration validation.

The validation logic effectively guides users toward optimal FP8 configurations by warning about:

  1. Missing torch_compile which is essential for FP8 performance gains (as noted in PR objectives)
  2. Potential performance degradation with FSDP activation checkpointing

The implementation follows established patterns in the codebase and provides clear, actionable guidance.
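
For illustration, a sketch of this style of validator; the field names and messages here are assumptions, not the actual axolotl schema:

```python
# Illustrative sketch only; field names and messages are assumptions, not axolotl's schema.
import logging

from pydantic import BaseModel, model_validator

LOG = logging.getLogger(__name__)


class FP8Config(BaseModel):
    fp8: bool = False
    fp8_enable_fsdp_float8_all_gather: bool = False
    torch_compile: bool = False
    fsdp_version: int | None = None  # hypothetical stand-in for the FSDP config

    @model_validator(mode="after")
    def validate_fp8_settings(self):
        # Warn when fp8 is enabled without torch.compile (fp8 gains depend on it).
        if self.fp8 and not self.torch_compile:
            LOG.warning("fp8 training is unlikely to be faster without torch_compile")
        # The float8 all-gather optimization only applies to FSDP2.
        if self.fp8_enable_fsdp_float8_all_gather and self.fsdp_version != 2:
            raise ValueError("fp8_enable_fsdp_float8_all_gather requires FSDP2")
        return self
```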

src/axolotl/core/trainers/base.py (2)

10-10: LGTM: Proper type annotations and method signature.

The addition of explicit type annotations and return type improves code clarity and type safety. The default values for both FP8 parameters are appropriate.

Also applies to: 525-528


532-542: LGTM: Enhanced FP8 configuration with proper parameterization.

The updated implementation provides better control over FP8 behavior:

  • Uses Float8LinearConfig for fine-grained configuration
  • Properly passes the enable_fsdp_float8_all_gather parameter
  • Includes helpful documentation about the tensorwise scaling strategy
  • Maintains proper conditional imports

This is a significant improvement over a basic configuration approach.

tests/e2e/multigpu/test_fp8_fsdp2.py (3)

21-48: LGTM: Comprehensive training success validation.

The helper function provides thorough validation by checking:

  • Model file artifacts (.bin/.safetensors)
  • Checkpoint directory creation
  • Training loss values from TensorBoard logs (ensuring no NaN losses)

This multi-layered approach effectively verifies that FP8 training completed successfully.
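
A loss check of that sort might look roughly like the following sketch; the event-file layout and the "train/loss" scalar tag are assumptions about the test setup, not the actual test code:

```python
# Sketch of a TensorBoard-based loss sanity check; event-file location and the
# "train/loss" tag are assumptions about the test setup.
import math
from pathlib import Path

from tensorboard.backend.event_processing.event_accumulator import EventAccumulator


def assert_no_nan_train_loss(output_dir: str, tag: str = "train/loss") -> None:
    event_files = list(Path(output_dir).rglob("events.out.tfevents.*"))
    assert event_files, f"no TensorBoard event files found under {output_dir}"

    # Load scalars from the first run directory found.
    acc = EventAccumulator(str(event_files[0].parent))
    acc.Reload()
    losses = [scalar.value for scalar in acc.Scalars(tag)]

    assert losses, f"no '{tag}' scalars were logged"
    assert all(not math.isnan(loss) for loss in losses), "found NaN training loss"
```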


53-119: LGTM: Well-configured FP8+FSDP2 smoke test.

The test properly configures:

  • FP8 mixed precision with FSDP float8 all-gather enabled
  • FSDP2 with appropriate settings for the SmolLM2 architecture
  • Minimal training steps for efficient smoke testing
  • Proper distributed execution with 2 GPUs

The configuration aligns with the PR objectives for testing FP8 training with FSDP optimizations.


121-193: LGTM: Comprehensive FP8+FSDP2+LoRA compatibility test.

This test validates the complex combination of:

  • FP8 mixed precision training
  • FSDP2 distributed training
  • LoRA adapters

The LoRA configuration is appropriate for testing, and maintaining the same FP8/FSDP2 settings ensures we're testing the interaction between all three technologies.

tests/e2e/integrations/test_fp8.py (1)

18-61: LGTM: Well-structured single-GPU FP8 smoke test.

The test provides good coverage for single-GPU FP8 training:

  • Enables both fp8 and torch_compile as recommended by the validation logic
  • Uses minimal configuration for efficient smoke testing
  • Follows proper Axolotl training pipeline (validate → normalize → load datasets → train)
  • Verifies successful completion through output file checking

This complements the multi-GPU tests nicely by covering the single-GPU use case.

docs/mixed_precision.qmd (1)

88-93: Good emphasis on compile requirement

Clear warning, nicely highlighted.

examples/llama-3/3b-fp8-fsdp2.yaml (1)

1-1: Base-model ID is valid on HF Hub

The Hugging Face API returns HTTP 200 for meta-llama/Llama-3.2-3B, confirming this model exists and can be fetched. No changes needed.

• File: examples/llama-3/3b-fp8-fsdp2.yaml, line 1.

Likely an incorrect or invalid review comment.

@djsaunde djsaunde changed the title basic torchao fp8 training basic torchao fp8 mixed precision training Jul 22, 2025
Collaborator

@winglian winglian left a comment


Can we add a quick test case like assert check_create_accelerate_code_is_patchable() so we get a test failure if the patched code changes upstream?
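
Something like this minimal sketch would cover that, assuming the helper lives in src/axolotl/monkeypatch/trainer_accelerator_args.py as referenced elsewhere in this PR:

```python
# Minimal sketch of the requested test; module path is taken from this PR's file list.
import unittest

from axolotl.monkeypatch.trainer_accelerator_args import (
    check_create_accelerate_code_is_patchable,
)


class TestTrainerAcceleratorArgsPatch(unittest.TestCase):
    def test_patch_targets_still_match_upstream(self):
        # Fails if the upstream accelerator-args code drifts from what the patch expects.
        self.assertTrue(check_create_accelerate_code_is_patchable())


if __name__ == "__main__":
    unittest.main()
```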

Collaborator

@winglian winglian left a comment


Just a minor test addition and a nit about the validation message pointing to the next step. Otherwise good to go once those are addressed.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (1)
tests/monkeypatch/test_trainer_accelerator_args.py (1)

17-22: Use unittest assertion methods instead of raw assert.

While the test logic is sound, consider using self.assertTrue() instead of raw assert for better error reporting and to avoid issues when Python optimizations are enabled.

-        assert check_create_accelerate_code_is_patchable()
+        self.assertTrue(check_create_accelerate_code_is_patchable())
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between d0d1160 and 9f8cc4d.

📒 Files selected for processing (2)
  • src/axolotl/utils/schemas/validation.py (1 hunks)
  • tests/monkeypatch/test_trainer_accelerator_args.py (1 hunks)
🧬 Code Graph Analysis (1)
tests/monkeypatch/test_trainer_accelerator_args.py (1)
src/axolotl/monkeypatch/trainer_accelerator_args.py (1)
  • check_create_accelerate_code_is_patchable (35-38)
🚧 Files skipped from review as they are similar to previous changes (1)
  • src/axolotl/utils/schemas/validation.py
🧰 Additional context used
🧬 Code Graph Analysis (1)
tests/monkeypatch/test_trainer_accelerator_args.py (1)
src/axolotl/monkeypatch/trainer_accelerator_args.py (1)
  • check_create_accelerate_code_is_patchable (35-38)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (12)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: pre-commit
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.0, vllm, 2, true)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.6.0, 2, true)
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.1, 2, true)
  • GitHub Check: pre-commit
  • GitHub Check: preview
🔇 Additional comments (3)
tests/monkeypatch/test_trainer_accelerator_args.py (3)

1-10: Well-structured imports and documentation.

The file header and import organization follow best practices with clear documentation and appropriate import specificity.


12-16: Proper test class structure.

The test class follows unittest conventions with clear documentation and appropriate naming.


25-27: Standard test execution pattern.

The main execution block follows Python testing conventions appropriately.

@djsaunde
Collaborator Author

Going to merge and hope to address a few follow-ups in additional PRs.

@djsaunde djsaunde merged commit 208fb7b into main Jul 22, 2025
15 of 17 checks passed
@djsaunde djsaunde deleted the fp8-support-v2 branch July 22, 2025 20:27