
Commit 63a37d0

[megatron] fix: MTP patch for newer mcore
There's no compute_output_layer_and_language_model_loss in newer mcore, so everything needs to be handled by process_mtp_loss.

Signed-off-by: Hollow Man <hollowman@opensuse.org>
1 parent 5e2f2b2 commit 63a37d0
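
The patch keeps working on both API generations: the newer process_mtp_loss helper is resolved once at import time, and the original per-layer loop is kept as a fallback. A minimal sketch of that detection pattern, simplified from the diff below (names match the patch):

from typing import Callable

# Resolve the newer Megatron-core helper once at import time. None means the
# legacy compute_output_layer_and_language_model_loss path must be used instead.
try:
    import megatron.core.transformer.multi_token_prediction as _mtp_module
except ImportError:
    _PROCESS_MTP_LOSS: Callable | None = None
else:
    _PROCESS_MTP_LOSS: Callable | None = getattr(_mtp_module, "process_mtp_loss", None)

# Inside the patched postprocess step (simplified):
#   if _PROCESS_MTP_LOSS:   # newer mcore: upstream helper handles the MTP loss
#       ...
#   else:                   # older mcore: original per-layer MTP loss loop
#       ...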

1 file changed: verl/models/mcore/mtp_patch.py (80 additions, 48 deletions)
@@ -26,6 +26,13 @@
     roll_tensor,
 )
 
+try:
+    import megatron.core.transformer.multi_token_prediction as _mtp_module
+except ImportError:
+    _PROCESS_MTP_LOSS: Callable | None = None
+else:
+    _PROCESS_MTP_LOSS: Callable | None = getattr(_mtp_module, "process_mtp_loss", None)
+
 try:
     from megatron.core.utils import unwrap_model
 except ImportError:
@@ -78,6 +85,7 @@ def _megatron_gptmodel_postprocess(
     runtime_gather_output=None,
     extra_block_kwargs=None,
     inference_context=None,
+    **kwargs,
 ):
     """Postprocesses decoder hidden states to generate logits or compute loss.
 
@@ -111,58 +119,82 @@ def _megatron_gptmodel_postprocess(
 
     # Skip when mtp_num_layers is None or 0
     if self.config.mtp_num_layers and labels is not None:
-        mtp_labels = labels.clone()
-
-        hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0)
-        hidden_states = hidden_states_list[0]
-        if loss_mask is None:
-            # if loss_mask is not provided, use all ones as loss_mask
-            loss_mask = torch.ones_like(mtp_labels)
-        for mtp_layer_number in range(self.config.mtp_num_layers):
-            # Calc loss for the current Multi-Token Prediction (MTP) layers.
-            mtp_labels, _ = roll_tensor(
-                mtp_labels,
-                shifts=-1,
-                dims=-1,
-                cp_group=self.cp_group,
-                packed_seq_params=packed_seq_params,
-            )
-            loss_mask, num_tokens = roll_tensor(
-                loss_mask,
-                shifts=-1,
-                dims=-1,
-                cp_group=self.cp_group,
+        # Prefer upstream helper when available (newer Megatron-LM), using
+        # a cached reference resolved at module import time.
+        if _PROCESS_MTP_LOSS:
+            cp_group = None
+            if getattr(self, "pg_collection", None) is not None:
+                cp_group = self.pg_collection.cp
+            elif hasattr(self, "cp_group"):
+                cp_group = self.cp_group
+
+            hidden_states = _PROCESS_MTP_LOSS(
+                hidden_states=hidden_states,
+                labels=labels,
+                loss_mask=loss_mask,
+                output_layer=self.output_layer,
+                output_weight=output_weight,
+                runtime_gather_output=runtime_gather_output,
+                is_training=self.training,
+                compute_language_model_loss=self.compute_language_model_loss,
+                config=self.config,
+                cp_group=cp_group,
                 packed_seq_params=packed_seq_params,
             )
+        else:
+            # Fallback for older Megatron-LM versions without process_mtp_loss API.
+            mtp_labels = labels.clone()
+
+            hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0)
+            hidden_states = hidden_states_list[0]
+            if loss_mask is None:
+                # if loss_mask is not provided, use all ones as loss_mask
+                loss_mask = torch.ones_like(mtp_labels)
+            for mtp_layer_number in range(self.config.mtp_num_layers):
+                # Calc loss for the current Multi-Token Prediction (MTP) layers.
+                mtp_labels, _ = roll_tensor(
+                    mtp_labels,
+                    shifts=-1,
+                    dims=-1,
+                    cp_group=self.cp_group,
+                    packed_seq_params=packed_seq_params,
+                )
+                loss_mask, num_tokens = roll_tensor(
+                    loss_mask,
+                    shifts=-1,
+                    dims=-1,
+                    cp_group=self.cp_group,
+                    packed_seq_params=packed_seq_params,
+                )
 
-            # Compute mtp loss without storing logits to save memory.
-            mtp_loss = self.compute_output_layer_and_language_model_loss(
-                hidden_states_list[mtp_layer_number + 1],
-                labels=mtp_labels,
-                weight=self.shared_embedding_or_output_weight(),
-                sequence_parallel_enabled=self.output_layer.sequence_parallel,
-                column_parallel_linear=self.output_layer,
-                col_linear_kwargs={
-                    "weight": output_weight,
-                    "runtime_gather_output": runtime_gather_output,
-                },
-            )
-
-            mtp_loss = loss_mask * mtp_loss
-            if self.training:
-                # TODO(shifangx): remove the use of parallel_state here
-                # after moving loss logging to loss_func in pretrain_gpt.py
-                MTPLossLoggingHelper.save_loss_to_tracker(
-                    torch.sum(mtp_loss) / num_tokens,
-                    mtp_layer_number,
-                    self.config.mtp_num_layers,
-                    avg_group=parallel_state.get_data_parallel_group(with_context_parallel=True),
+                # Compute mtp loss without storing logits to save memory.
+                mtp_loss = self.compute_output_layer_and_language_model_loss(
+                    hidden_states_list[mtp_layer_number + 1],
+                    labels=mtp_labels,
+                    weight=self.shared_embedding_or_output_weight(),
+                    sequence_parallel_enabled=self.output_layer.sequence_parallel,
+                    column_parallel_linear=self.output_layer,
+                    col_linear_kwargs={
+                        "weight": output_weight,
+                        "runtime_gather_output": runtime_gather_output,
+                    },
                 )
-            mtp_loss_scale = self.config.mtp_loss_scaling_factor / self.config.mtp_num_layers
-            if self.config.calculate_per_token_loss:
-                hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss)
-            else:
-                hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss / num_tokens)
+
+                mtp_loss = loss_mask * mtp_loss
+                if self.training:
+                    # TODO(shifangx): remove the use of parallel_state here
+                    # after moving loss logging to loss_func in pretrain_gpt.py
+                    MTPLossLoggingHelper.save_loss_to_tracker(
+                        torch.sum(mtp_loss) / num_tokens,
+                        mtp_layer_number,
+                        self.config.mtp_num_layers,
+                        avg_group=parallel_state.get_data_parallel_group(with_context_parallel=True),
+                    )
+                mtp_loss_scale = self.config.mtp_loss_scaling_factor / self.config.mtp_num_layers
+                if self.config.calculate_per_token_loss:
+                    hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss)
+                else:
+                    hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss / num_tokens)
 
     logits, _ = self.output_layer(hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output)
     # [s b h] => [b s h]

0 comments