
[trainer] Initial Megatron TP + PP Support#223

Merged
erictang000 merged 22 commits into NovaSky-AI:main from erictang000:megatron on Sep 5, 2025

Conversation


@erictang000 erictang000 commented Aug 29, 2025

Overview

This PR initializes Megatron support for GRPO with TP + PP (and DP) supported. We use mbridge to translate from huggingface to megatron, and plan to migrate to Megatron-Bridge once the repo is more mature and we can safely upgrade some dependencies.

What's not Included

For a full set of limitations, see `validate_megatron_cfg(cfg: DictConfig)` in `skyrl-train/skyrl_train/utils/utils.py`.

Note: there is a known issue with Qwen2.5-1.5B and Qwen2.5-3B models, which don't specify lm_head.weight in the model.safetensors file, causing a weight loading issue with mbridge - issue here: ISEEKYAN/mbridge#10.

Followups

For a full task list needed for full Megatron support, see #203.

Reproduction runs (gsm8k with Qwen3-0.6B)

(image: reproduction run curves)

TODOs

  • Port GPU test to be runnable without locally saved training batch input
  • fix entropy calculation
  • dependencies
  • add any needed doc-strings
  • code-attribution if i missed anything

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces initial support for Megatron-LM, enabling Tensor Parallelism (TP) and Pipeline Parallelism (PP) for training. The changes are extensive, adding new strategies, workers, configurations, and utilities for Megatron integration. Key changes include a new MegatronStrategy, MegatronPPOPolicy, and corresponding workers. The PR also includes valuable refactoring, such as moving common distributed functions to the base DistributedStrategy class, and adds validation for Megatron configurations to improve user experience. The addition of comprehensive tests for the new Megatron functionality is commendable.

My review includes a couple of suggestions for improvement. One is to implement the policy entropy calculation, which is currently missing. The other is to make a new utility function more robust.

I also noticed a few limitations with this initial implementation, such as the lack of checkpointing support and support for only the GRPO algorithm. It would be beneficial to document these limitations in the pull request description to set clear expectations for users.


[dependency-groups]

# This is a default group so that we install these even with bare `uv sync`
Collaborator Author

TODO: figure out how to reconcile deps. flash-rl requires torch 2.7.0, but in order to build transformer-engine-pytorch, I need torch 2.7.1. I think I can get this working by asking the user to build transformer-engine first, but I need to test installation in a fresh env.

Member

FlashRL extra is currently standalone, so for the initial integration we can just say that the two dep groups mcore and flashrl conflict, and that should be enough right?

Collaborator Author

The problem was that I needed to include the build dependency group as a default, which clashes torch versions with flash-rl. But that group may only be needed to build transformer-engine, so we can give instructions to build it separately and include it as a no-build-isolation dep (btw, building it takes more than 10 minutes :(. )

Will see how that goes, though.

Member

Ah I see
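Declaring the two dependency groups as mutually exclusive could look roughly like the following sketch; the group names, package lists, and version pins here are assumptions for illustration, not the project's actual pyproject.toml:

```toml
# hypothetical sketch: mutually exclusive dependency groups under uv;
# group contents and pins are illustrative only
[dependency-groups]
mcore = ["megatron-core", "transformer-engine", "torch==2.7.1"]
flashrl = ["flash-rl", "torch==2.7.0"]

[tool.uv]
# uv refuses to resolve both groups into the same environment
conflicts = [
    [{ group = "mcore" }, { group = "flashrl" }],
]
```

With this, `uv sync --group mcore --group flashrl` would error out instead of silently picking one torch pin.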

expert_tensor_parallel_size: 1

# Allow to override Distributed Data Parallel (DDP) config
# TODO (erictang000): understand exactly what is going on here
Collaborator Author

pending...

@erictang000
Collaborator Author

/gemini review

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces initial support for Megatron-Core, enabling training with Tensor Parallelism (TP) and Pipeline Parallelism (PP). The integration relies on mbridge for converting Hugging Face models to the Megatron format. The changes include new dependencies, Megatron-specific configurations, distributed strategies, and worker implementations. The implementation correctly handles pipeline parallelism by collecting outputs from the last stage and introduces custom autograd functions for distributed log-probability calculations. Overall, this is a substantial and well-structured addition. My feedback focuses on improving configuration clarity, addressing potential brittleness in the implementation, and noting current limitations for future work.

@erictang000 erictang000 requested a review from SumanthRH August 30, 2025 00:48
Member

@SumanthRH SumanthRH left a comment

This is great! Leaving some comments after a first pass

average_in_collective: true
data_parallel_sharding_strategy: "optim_grads_params"

seed: 42
Member

Nit: is this seed argument needed here? It looks like we pass cfg.trainer.seed to MegatronStrategy at init, so we should not have it in the config here

Collaborator Author

good catch

def offload_tensor_to_cpu(tensor):
    if tensor is None:
        return
    tensor.data = tensor.data.to("cpu", non_blocking=True)
Member

Suggested change:
-    tensor.data = tensor.data.to("cpu", non_blocking=True)
+    tensor = tensor.to("cpu", non_blocking=True)

Collaborator Author

interestingly, using .data here seems to actually impact offloaded memory!

using .data: (image: memory profile)

not using .data: (image: memory profile)

maybe we can leave these as is?
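The difference between the two forms comes down to Python name binding: rebinding the parameter `tensor` inside the function only changes the local name, so the caller's tensor never moves, while assigning to `tensor.data` mutates the object that the caller (and the optimizer) still holds. A minimal stand-in with a plain object, no torch required:

```python
class FakeTensor:
    """Stand-in for a torch.Tensor with a .data attribute."""
    def __init__(self, device):
        self.data = device

def offload_rebind(tensor):
    # rebinds the local name only; the caller's object is unchanged
    tensor = FakeTensor("cpu")

def offload_inplace(tensor):
    # mutates the object itself, which every holder of a reference sees
    tensor.data = "cpu"

t = FakeTensor("cuda:0")
offload_rebind(t)
assert t.data == "cuda:0"  # nothing happened from the caller's view
offload_inplace(t)
assert t.data == "cpu"     # the storage attribute actually changed
```

This is why the `.data` version frees GPU memory while the suggested rebinding version would be a no-op for the caller.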

if tensor is None:
    return
device_id = torch.cuda.current_device()
tensor.data = tensor.data.to(device_id, non_blocking=True)
Member

Suggested change:
-    tensor.data = tensor.data.to(device_id, non_blocking=True)
+    tensor = tensor.to(device_id, non_blocking=True)

logger.info("Synced registries to ray actor")


def _safe_exp_delta(delta: torch.Tensor, clip: float = 20.0, out_dtype=None) -> torch.Tensor:
Member

a small comment on overflow/ underflow here would be nice
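For reference, the overflow concern is that `exp` saturates quickly: float32 `exp(x)` overflows to inf for x above roughly 88 and underflows to 0 below roughly -87. A clamped scalar version of the same idea (plain `math` on floats, not the actual torch implementation) might read:

```python
import math

def safe_exp_delta(delta: float, clip: float = 20.0) -> float:
    # Clamp the log-space delta before exponentiating so the ratio stays in
    # [e^-clip, e^clip]; for clip=20 that is roughly [2.1e-9, 4.9e8], which
    # avoids inf (overflow) and exact 0 (underflow) in downstream products.
    return math.exp(max(-clip, min(clip, delta)))

assert safe_exp_delta(1000.0) == math.exp(20.0)   # clamped from above
assert safe_exp_delta(-1000.0) == math.exp(-20.0) # clamped from below
assert safe_exp_delta(0.5) == math.exp(0.5)       # passthrough in range
```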

if not torch.distributed.is_initialized():
    torch.distributed.init_process_group(backend="nccl")

# override the init_process_group method to use megatron distributed setup to create the mesh
Member

comment doesn't match the code? (or maybe is insufficient)

    self.tokenizer = tokenizer

def make_megatron_module(self, model_config_kwargs, wrap_with_ddp=True):
    model = self.bridge.get_model(
Member

Nit: Type hint?

Looks like this is no longer using the model classes in models.py, good to highlight. (Is this just an AutoBridge instance, or anything else?)

Collaborator Author

added


seq_len = sequences.shape[1]

new_sequences, new_attention_mask, new_position_ids = remove_left_padding(
Member

Btw why is only left padding removed here? Is this related to a quirk for Mbridge models? We currently pad on both sides

Collaborator Author

yes, this is needed for correctness for Megatron models (it should be unrelated to the mbridge library). When I tested without it, the results were wrong.

took the util function from verl where they do this: https://github.com/volcengine/verl/blob/e90f18c40aa639cd25092b78a5ff7e2d2508c088/verl/models/mcore/model_forward.py#L66
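The left-padding removal being discussed can be sketched on plain Python lists (the real util operates on batched tensors; the names below are assumed for illustration):

```python
def remove_left_padding(token_ids, attention_mask):
    # Drop leading pad positions (mask == 0) so the sequence starts at its
    # first real token, then rebuild position ids from 0. Megatron-style
    # forward passes compute positions from the left edge, so leftover left
    # padding shifts every position id and corrupts the outputs.
    start = attention_mask.index(1)  # index of the first real token
    new_ids = token_ids[start:]
    new_mask = attention_mask[start:]
    new_positions = list(range(len(new_ids)))
    return new_ids, new_mask, new_positions

ids, mask, pos = remove_left_padding([0, 0, 5, 6, 7], [0, 0, 1, 1, 1])
assert ids == [5, 6, 7] and mask == [1, 1, 1] and pos == [0, 1, 2]
```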

self.actor_optimizer = actor_optimizer
self.policy_loss_fn = policy_loss_fn

# NOTE (erictang000): this is a potentially brittle way to disable the finalize_model_grads_func
Member

finalize_model_grads_func is this function : https://github.com/NVIDIA/Megatron-LM/blob/02a1dd02fc72307c1c6327c0667c9e7cecf2eaff/megatron/core/distributed/finalize_model_grads.py#L376

Doing a quick search, it is called at the end of the forward + backward op, ex: https://github.com/NVIDIA/Megatron-LM/blob/02a1dd02fc72307c1c6327c0667c9e7cecf2eaff/megatron/core/pipeline_parallel/schedules.py#L640-L647

Looking at the function, it's performing

  1. grad all reduce across dp replicas
  2. grad all reduce across SP ranks for layernorm layers
  3. grad all reduce for embedding layers for PP
  4. optionally for moe models run an update_expert_token_bias function
  5. scale gradient if num_tokens is provided

1, 2, and 3 should be unaffected if we do one finalize_ call at the end of the mini batch. As far as I understand, there should also not be any effect on memory usage here.

For 4., at a quick glance the function seems to be compatible again with making this finalize_ call at the end of the mini batch, but I haven't looked too closely. The local_token_per_expert attribute for the model is averaged across ranks. This value should accumulate over micro batches, since I only see it being updated in the forward pass here: https://github.com/NVIDIA/Megatron-LM/blob/cca55ccdf408d56f97348f844a6aeb8f34d07672/megatron/core/transformer/moe/router.py#L448 , apart from the update on finalize_: https://github.com/NVIDIA/Megatron-LM/blob/cca55ccdf408d56f97348f844a6aeb8f34d07672/megatron/core/distributed/finalize_model_grads.py#L294-L295 .

There are actually some subtle ways this change could have an effect; for example, this apply_z_loss function for the router depends on it: https://github.com/NVIDIA/Megatron-LM/blob/cca55ccdf408d56f97348f844a6aeb8f34d07672/megatron/core/transformer/moe/router.py#L333-L366 . This doesn't seem relevant to us AFAIK because we set calculate_per_token_loss=False.

For 5., I again don't think it's relevant since we set calculate_per_token_loss=False.

Btw, I see another comment on the utility of this finalize_ call after each micro batch here: https://github.com/NVIDIA/Megatron-LM/blob/cca55ccdf408d56f97348f844a6aeb8f34d07672/megatron/core/models/huggingface/module.py#L28-L34, but I haven't given this much thought.

Wdyt?

Collaborator Author

thanks for helping dig in more - maybe the conclusion from your last few points is that it would be better to refactor this to use the built-in megatron forward_backward_func instead of fighting their abstraction to fit into our previous setup? Especially since there are places where we're uncertain whether it would affect correctness.

I'll try writing a version of the forward_backward_func that does the whole training step and make sure metrics match with the current implementation.

Member

Makes sense to me, thanks!

Comment on lines 32 to 34
# NOTE (erictang000): in the normal training flow, the model weights get downloaded in the HF cache by the inference engine
# we should figure out a way to download the model weights in the test environment (for now trying to do this by
# ordering the tests so the inference engine is initialized first). Also will be an issue for disaggregated training
Member

Could you explain the exact issue you faced here?

Collaborator Author

for the megatron worker, it just does AutoConfig("Qwen/Qwen3-0.6B"), and inside the mbridge library, weight loading assumes the model weights are also present at a path like /home/ray/.cache/huggingface/hub/models/Qwen/Qwen3-0.6B.

This is true if we are colocated and the inference engine is initialized first (since it'll download the weights and configs to the cache), but not true if we just run a worker test without inference engine init.

Member

@SumanthRH SumanthRH Sep 4, 2025

Hmm okay.

Ideally we do this download ourselves then before init, it's pretty simple:

from huggingface_hub import snapshot_download
if rank == 0:
    snapshot_download(model_name) # will be no-op if already downloaded
torch.distributed.barrier()

Collaborator Author

hmm, does the rank == 0 check work multi-node? This won't be an issue colocated, but for disagg we might need to download on each machine. Can leave a note to handle this later.

Member

*local rank == 0

Collaborator Author

oh sick i forgot we had that
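The pattern being settled on (download on local rank 0 only, everyone else waits) can be sketched with the collaborators injected so the control flow is testable without torch or huggingface_hub; in the real worker, `download_fn` would be `huggingface_hub.snapshot_download` and `barrier_fn` would be `torch.distributed.barrier`:

```python
def download_once_per_node(model_name, local_rank, download_fn, barrier_fn):
    # One download per node: only local rank 0 touches the Hub, and the
    # barrier keeps the other ranks from reading the cache before it exists.
    if local_rank == 0:
        download_fn(model_name)  # no-op if already cached
    barrier_fn()

calls = []
for rank in range(4):  # simulate four ranks on one node
    download_once_per_node("Qwen/Qwen3-0.6B", rank, calls.append, lambda: None)
assert calls == ["Qwen/Qwen3-0.6B"]  # downloaded exactly once
```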

Comment on lines 157 to 166
output = forward_backward_func(
forward_step_func=forward_step,
data_iterator=batch_generator,
model=self.actor_module,
num_microbatches=1,
seq_length=seq_len, # no use when input_shapes was set
micro_batch_size=micro_batch_size, # no use when input_shapes was set
forward_only=True,
)

Member

@SumanthRH SumanthRH Sep 4, 2025

This seems incorrect; from what I understand, we will end up with no pipelining for pipeline parallelism (no interleaving), since num_microbatches=1.

I actually assumed that the "microbatch" for pipelining is the same as the "micro batch" in our terminology. Can we modify the code to do that?

We currently have a synchronization point at the end of each forward_backward_micro_batch where we materialize loss tensors to CPU, so there's no overlap. For pipelining, it would probably be best to just feed a mini batch directly to megatron and use num_microbatches=mini_batch_size_per_gpu/micro_batch_size_per_gpu or so
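Concretely, the proposal is to hand Megatron the whole mini batch and derive the microbatch count from the two batch sizes, so the pipeline scheduler has several microbatches to interleave across stages. A small sketch (names assumed from our config terminology):

```python
def num_pp_microbatches(mini_batch_size_per_gpu: int,
                        micro_batch_size_per_gpu: int) -> int:
    # With num_microbatches == 1 each pipeline stage sits idle while the
    # others run; splitting the mini batch gives the scheduler work to
    # overlap across stages.
    assert mini_batch_size_per_gpu % micro_batch_size_per_gpu == 0
    return mini_batch_size_per_gpu // micro_batch_size_per_gpu

assert num_pp_microbatches(32, 4) == 8  # 8 microbatches in flight per mini batch
```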

Collaborator Author

resolved by above refactoring

Comment on lines 96 to 97
"vllm==0.10.0",
"torch==2.7.1",
Member

rebase on main?

@erictang000
Collaborator Author

/gemini review

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces initial support for Megatron-LM, enabling tensor and pipeline parallelism for training. The changes are extensive, adding new strategies, workers, configurations, and utilities. The code is well-structured, and the inclusion of new GPU tests is a great step towards ensuring correctness. My feedback focuses on a few key areas: completing the implementation of features like checkpointing, improving performance in weight synchronization and function compilation, and addressing a couple of correctness and robustness issues. Overall, this is a solid foundation for full Megatron support.

@erictang000
Collaborator Author

erictang000 commented Sep 4, 2025

note: after refactoring I'm seeing slightly slower convergence with DP=4, and grad norm is almost 2x higher. Tried tuning all the ddp config knobs, but none seem to be relevant. DP=1 runs seem to converge slightly faster and also have lower grad norm.

(image: DP=4 vs DP=1 convergence and grad norm curves)

Member

@SumanthRH SumanthRH left a comment

Approving, let's disable Megatron DDP for now and fix in a future PR

@erictang000
Collaborator Author

/gemini review

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces initial support for Megatron-LM with Tensor Parallelism (TP) and Pipeline Parallelism (PP), which is a significant and valuable addition. The implementation is comprehensive, including a new distributed strategy, worker implementations, utility functions, configurations, and extensive tests. The code is well-structured and thoughtfully integrated into the existing framework.

My review focuses on correctness and performance opportunities. I've identified a potential correctness issue in a custom autograd function due to an in-place modification of a tensor saved for the backward pass, which could lead to incorrect gradients. Additionally, I've pointed out a performance bottleneck in the weight synchronization logic where collective communication operations are performed inside a loop, and I've suggested a more efficient batching approach.

The rest of the changes, including the refactoring of distributed strategies and configuration updates, look solid. The addition of thorough tests for the new Megatron functionality is particularly commendable and crucial for such a complex feature.

@erictang000 erictang000 merged commit ea97943 into NovaSky-AI:main Sep 5, 2025
3 checks passed
erictang000 added a commit that referenced this pull request Sep 8, 2025
Fix for problem described in:
#223 (comment).
Also adds a basic megatron only torch profiler for dev.

Issue was not setting finalize_model_grads (which is [default
None](https://github.com/NVIDIA/Megatron-LM/blob/ba97a7e282a8478a02d012bc9b9e45f3a6be216e/megatron/core/model_parallel_config.py#L95)
for some reason...)

Fixed loss and grad norm:

(image: loss and grad norm curves)

Fixed comparison to tp=4 training:

(image: tp=4 comparison curves)
ztcanddota added a commit to ztcanddota/skyagent that referenced this pull request Sep 28, 2025
SungjunlaLee added a commit to SungjunlaLee/SkyRL that referenced this pull request Jan 3, 2026
dzorlu referenced this pull request in fleet-ai/SkyRL Feb 4, 2026
dzorlu pushed a commit to fleet-ai/SkyRL that referenced this pull request Feb 4, 2026