|
26 | 26 | roll_tensor, |
27 | 27 | ) |
28 | 28 |
|
| 29 | +try: |
| 30 | + import megatron.core.transformer.multi_token_prediction as _mtp_module |
| 31 | +except ImportError: |
| 32 | + _PROCESS_MTP_LOSS: Callable | None = None |
| 33 | +else: |
| 34 | + _PROCESS_MTP_LOSS: Callable | None = getattr(_mtp_module, "process_mtp_loss", None) |
| 35 | + |
29 | 36 | try: |
30 | 37 | from megatron.core.utils import unwrap_model |
31 | 38 | except ImportError: |
@@ -78,6 +85,7 @@ def _megatron_gptmodel_postprocess( |
78 | 85 | runtime_gather_output=None, |
79 | 86 | extra_block_kwargs=None, |
80 | 87 | inference_context=None, |
| 88 | + **kwargs, |
81 | 89 | ): |
82 | 90 | """Postprocesses decoder hidden states to generate logits or compute loss. |
83 | 91 |
|
@@ -111,58 +119,82 @@ def _megatron_gptmodel_postprocess( |
111 | 119 |
|
112 | 120 | # Skip when mtp_num_layers is None or 0 |
113 | 121 | if self.config.mtp_num_layers and labels is not None: |
114 | | - mtp_labels = labels.clone() |
115 | | - |
116 | | - hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0) |
117 | | - hidden_states = hidden_states_list[0] |
118 | | - if loss_mask is None: |
119 | | - # if loss_mask is not provided, use all ones as loss_mask |
120 | | - loss_mask = torch.ones_like(mtp_labels) |
121 | | - for mtp_layer_number in range(self.config.mtp_num_layers): |
122 | | - # Calc loss for the current Multi-Token Prediction (MTP) layers. |
123 | | - mtp_labels, _ = roll_tensor( |
124 | | - mtp_labels, |
125 | | - shifts=-1, |
126 | | - dims=-1, |
127 | | - cp_group=self.cp_group, |
128 | | - packed_seq_params=packed_seq_params, |
129 | | - ) |
130 | | - loss_mask, num_tokens = roll_tensor( |
131 | | - loss_mask, |
132 | | - shifts=-1, |
133 | | - dims=-1, |
134 | | - cp_group=self.cp_group, |
| 122 | + # Prefer upstream helper when available (newer Megatron-LM), using |
| 123 | + # a cached reference resolved at module import time. |
| 124 | + if _PROCESS_MTP_LOSS: |
| 125 | + cp_group = None |
| 126 | + if getattr(self, "pg_collection", None) is not None: |
| 127 | + cp_group = self.pg_collection.cp |
| 128 | + elif hasattr(self, "cp_group"): |
| 129 | + cp_group = self.cp_group |
| 130 | + |
| 131 | + hidden_states = _PROCESS_MTP_LOSS( |
| 132 | + hidden_states=hidden_states, |
| 133 | + labels=labels, |
| 134 | + loss_mask=loss_mask, |
| 135 | + output_layer=self.output_layer, |
| 136 | + output_weight=output_weight, |
| 137 | + runtime_gather_output=runtime_gather_output, |
| 138 | + is_training=self.training, |
| 139 | + compute_language_model_loss=self.compute_language_model_loss, |
| 140 | + config=self.config, |
| 141 | + cp_group=cp_group, |
135 | 142 | packed_seq_params=packed_seq_params, |
136 | 143 | ) |
| 144 | + else: |
| 145 | + # Fallback for older Megatron-LM versions without process_mtp_loss API. |
| 146 | + mtp_labels = labels.clone() |
| 147 | + |
| 148 | + hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0) |
| 149 | + hidden_states = hidden_states_list[0] |
| 150 | + if loss_mask is None: |
| 151 | + # if loss_mask is not provided, use all ones as loss_mask |
| 152 | + loss_mask = torch.ones_like(mtp_labels) |
| 153 | + for mtp_layer_number in range(self.config.mtp_num_layers): |
| 154 | + # Calc loss for the current Multi-Token Prediction (MTP) layers. |
| 155 | + mtp_labels, _ = roll_tensor( |
| 156 | + mtp_labels, |
| 157 | + shifts=-1, |
| 158 | + dims=-1, |
| 159 | + cp_group=self.cp_group, |
| 160 | + packed_seq_params=packed_seq_params, |
| 161 | + ) |
| 162 | + loss_mask, num_tokens = roll_tensor( |
| 163 | + loss_mask, |
| 164 | + shifts=-1, |
| 165 | + dims=-1, |
| 166 | + cp_group=self.cp_group, |
| 167 | + packed_seq_params=packed_seq_params, |
| 168 | + ) |
137 | 169 |
|
138 | | - # Compute mtp loss without storing logits to save memory. |
139 | | - mtp_loss = self.compute_output_layer_and_language_model_loss( |
140 | | - hidden_states_list[mtp_layer_number + 1], |
141 | | - labels=mtp_labels, |
142 | | - weight=self.shared_embedding_or_output_weight(), |
143 | | - sequence_parallel_enabled=self.output_layer.sequence_parallel, |
144 | | - column_parallel_linear=self.output_layer, |
145 | | - col_linear_kwargs={ |
146 | | - "weight": output_weight, |
147 | | - "runtime_gather_output": runtime_gather_output, |
148 | | - }, |
149 | | - ) |
150 | | - |
151 | | - mtp_loss = loss_mask * mtp_loss |
152 | | - if self.training: |
153 | | - # TODO(shifangx): remove the use of parallel_state here |
154 | | - # after moving loss logging to loss_func in pretrain_gpt.py |
155 | | - MTPLossLoggingHelper.save_loss_to_tracker( |
156 | | - torch.sum(mtp_loss) / num_tokens, |
157 | | - mtp_layer_number, |
158 | | - self.config.mtp_num_layers, |
159 | | - avg_group=parallel_state.get_data_parallel_group(with_context_parallel=True), |
| 170 | + # Compute mtp loss without storing logits to save memory. |
| 171 | + mtp_loss = self.compute_output_layer_and_language_model_loss( |
| 172 | + hidden_states_list[mtp_layer_number + 1], |
| 173 | + labels=mtp_labels, |
| 174 | + weight=self.shared_embedding_or_output_weight(), |
| 175 | + sequence_parallel_enabled=self.output_layer.sequence_parallel, |
| 176 | + column_parallel_linear=self.output_layer, |
| 177 | + col_linear_kwargs={ |
| 178 | + "weight": output_weight, |
| 179 | + "runtime_gather_output": runtime_gather_output, |
| 180 | + }, |
160 | 181 | ) |
161 | | - mtp_loss_scale = self.config.mtp_loss_scaling_factor / self.config.mtp_num_layers |
162 | | - if self.config.calculate_per_token_loss: |
163 | | - hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss) |
164 | | - else: |
165 | | - hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss / num_tokens) |
| 182 | + |
| 183 | + mtp_loss = loss_mask * mtp_loss |
| 184 | + if self.training: |
| 185 | + # TODO(shifangx): remove the use of parallel_state here |
| 186 | + # after moving loss logging to loss_func in pretrain_gpt.py |
| 187 | + MTPLossLoggingHelper.save_loss_to_tracker( |
| 188 | + torch.sum(mtp_loss) / num_tokens, |
| 189 | + mtp_layer_number, |
| 190 | + self.config.mtp_num_layers, |
| 191 | + avg_group=parallel_state.get_data_parallel_group(with_context_parallel=True), |
| 192 | + ) |
| 193 | + mtp_loss_scale = self.config.mtp_loss_scaling_factor / self.config.mtp_num_layers |
| 194 | + if self.config.calculate_per_token_loss: |
| 195 | + hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss) |
| 196 | + else: |
| 197 | + hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss / num_tokens) |
166 | 198 |
|
167 | 199 | logits, _ = self.output_layer(hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output) |
168 | 200 | # [s b h] => [b s h] |
|
0 commit comments