[Megatron] Add checkpointing support#298
Conversation
Note: These trainer updates may need to be changed after #297 is merged
```python
self.init_weight_sync_state()

# Load policy model to GPU before loading checkpoint.
if self.cfg.trainer.placement.colocate_all:
```
Note: the policy model needs to be on GPU for the Megatron `load_checkpoint`, as required by Megatron's `dist_checkpointing` library
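The ordering constraint can be sketched as follows. This is a minimal illustration, not the trainer's actual code: `resume_sequence` and `backload_policy_to_gpu` are hypothetical names, standing in for the fact that Megatron's `dist_checkpointing` loads sharded tensors in place, so parameters must already live on the GPU before `load_checkpoint` runs.

```python
# Hedged sketch of the resume ordering with hypothetical helper names.
# When all models are colocated on the same GPUs, the policy model is
# offloaded to CPU between steps; it must be backloaded to GPU before
# Megatron's dist_checkpointing resolves sharded tensors into it.

def resume_sequence(colocate_all: bool) -> list[str]:
    """Return the ordered steps performed when resuming from a checkpoint."""
    steps = []
    if colocate_all:
        # Bring the offloaded policy model back to GPU first.
        steps.append("backload_policy_to_gpu")
    # Only then can dist_checkpointing load shards in place.
    steps.append("load_checkpoint")
    return steps

print(resume_sequence(colocate_all=True))
```

With `colocate_all=False` the model never left the GPU, so the backload step is skipped and only `load_checkpoint` runs.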
```python
print("Phase 3: Verify state consistency")

# Compare captured states
for key in state_before:
```
Note: this was only confirming `global_step`, which we already check
erictang000 left a comment:
Pretty much LGTM, thanks! Let's just merge `main` after #297 is merged to keep the trainer changes consistent.
As for multi-node checkpointing and `save_hf_model`: we'll want these, but I can help test multi-node as I get model training running, and we can rely on external scripts for converting Megatron checkpoints to HF for a little while.
```python
# All ranks wait for the checkpoint directory to be created before saving.
dist.barrier()

# Collect the sharded state dicts for model and optimizer, and full state dict for the scheduler.
```
So other than the scheduler, there's no additional memory load or communication here?
If I understand your question correctly, no! We should just be loading the sharded model and optimizer state dicts.
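A framework-free analogue of why there is no extra communication or memory cost: each rank persists only its local shard, after a barrier that guarantees the checkpoint directory exists. This is a toy illustration only; real code would use `torch.distributed.barrier()` and Megatron's `dist_checkpointing.save`, and the function names below are made up for the sketch.

```python
# Toy stand-in for per-rank sharded checkpointing (no gather, no replication).
import json
import os
import tempfile

def save_sharded(ckpt_dir: str, rank: int, local_shard: dict) -> str:
    """Each rank writes only its own shard; no cross-rank communication."""
    os.makedirs(ckpt_dir, exist_ok=True)  # stands in for mkdir-on-rank-0 + barrier
    path = os.path.join(ckpt_dir, f"shard_rank{rank}.json")
    with open(path, "w") as f:
        json.dump(local_shard, f)
    return path

def load_sharded(ckpt_dir: str, rank: int) -> dict:
    """Each rank reads back only its own shard."""
    with open(os.path.join(ckpt_dir, f"shard_rank{rank}.json")) as f:
        return json.load(f)

ckpt = tempfile.mkdtemp()
save_sharded(ckpt, rank=0, local_shard={"w": [1.0, 2.0]})
save_sharded(ckpt, rank=1, local_shard={"w": [3.0, 4.0]})
print(load_sharded(ckpt, rank=1))  # {'w': [3.0, 4.0]}
```

The only full (unsharded) object in the real path is the scheduler state, which is small, hence the reviewer's question above.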
# What does this PR do?

#298 broke GPU CI a bit:

1. Megatron-related dependencies have not been resolved properly on `main` yet, and this test should be skipped.
2. We use a simple 4xL4 instance, but the test was modified in #298 to request 8 GPUs (non-colocated training).

Signed-off-by: SumanthRH <sumanthrh99@gmail.com>
## What does this PR do?

This PR implements support for `save_checkpoint` and `load_checkpoint` for the Megatron training backend. We use Megatron's `dist_checkpointing` library to perform checkpointing in parallel across ranks, which also allows for reloading the checkpoints in a different parallelism scheme.

Other minor changes:

* Rename `save/load_ckpt` to `save/load_checkpoint`
* Remove unused arguments to `save/load_checkpoint`, primarily the inclusion of non-backend-specific state (`tag`, `client_state`, `global_step`). This change keeps training backend checkpointing logic focused on the training backend's state.

## Testing

* Extended two GPU checkpointing tests to cover Megatron
* Also moves `test_save_load_checkpoint.py` into `gpu_ci`. Note, however, that Megatron tests are disabled in CI because they currently require a different `flash-attn` install.
* Manually saved and resumed several times:

<img width="426" height="292" alt="Screenshot 2025-09-15 at 3 30 17 PM" src="https://github.com/user-attachments/assets/2a4170fe-fddd-4086-961f-ca3170e654ab" />

## What's next?

- [ ] Test multi-node checkpointing
- [ ] Implement `save_hf_model`
# What does this PR do?

Fixes the async trainer example after NovaSky-AI#298. We renamed `setup_policy_and_generator` to `init_weight_sync_state` but missed the update in some places.

Signed-off-by: SumanthRH <sumanthrh@anyscale.com>