Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ Contributed in collaboration with RedNote.

Memory is often the limiting factor for very large sparse MoE models such as DeepSeek-V3 and Qwen3-235B. Fine-grained recomputation lowers activation memory at the cost of extra compute. Offloading can use host-device bandwidth so that reload overlaps compute and keeps overhead small in many setups. Fine-grained activation offloading moves activations at module granularity so you can tune how much activation memory leaves the device and adjust training throughput.

Supported offloading modules are `"attn_norm"`, `"core_attn"`, `"attn_proj"`, `"mlp_norm"`, `"expert_fc1"`, and `"moe_act"`. They can be combined with fine-grained recomputation to free almost all activations for a transformer layer on the device.
Supported offloading modules are `"attn_norm"`, `"core_attn"`, `"attn_proj"`, `"mlp_norm"`, `"expert_fc1"`, `"moe_act"`, and `"group_mlp"`. They can be combined with fine-grained recomputation to free almost all activations for a transformer layer on the device. Use `"group_mlp"` for TE op-fuser GroupedMLP, where FC1, activation, router-prob scaling, and FC2 are fused and therefore offloaded as a single group instead of separate `"expert_fc1"` and `"moe_act"` groups.

## Features

Expand All @@ -33,10 +33,15 @@ Supported offloading modules are `"attn_norm"`, `"core_attn"`, `"attn_proj"`, `"
--fine-grained-activation-offloading

# Modules whose inputs are offloaded (refer to your training script for list or delimiter syntax).
# Choices: "attn_norm", "core_attn", "attn_proj", "mlp_norm", "expert_fc1", "moe_act".
# Choices: "attn_norm", "core_attn", "attn_proj", "mlp_norm", "expert_fc1", "moe_act",
# "group_mlp".
--offload-modules expert_fc1
```

`group_mlp` requires MoE, `--moe-grouped-gemm`, and `--use-transformer-engine-op-fuser`. It cannot be combined with `expert_fc1` or `moe_act` because the TE op-fuser path does not expose those two internal boundaries.
When the TE op fuser saves `GroupedTensor` activations, offloading moves the grouped tensor backing buffers such as row/column data, scales, and offsets independently and rebuilds the grouped wrapper on reload.
The minimum offload size is applied to each `GroupedTensor` backing buffer independently, so small scale or metadata buffers stay on GPU while large data buffers are offloaded.

## Compatible With Fine-Grained Recomputation

- For low-overhead modules such as LayerNorm or `moe_act`, use recomputation to save activation memory.
Expand Down
24 changes: 24 additions & 0 deletions megatron/core/extensions/transformer_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2932,6 +2932,30 @@ def set_save_original_input(module):
)


try:
from transformer_engine.pytorch.tensor.grouped_tensor import GroupedTensor as _TEGroupedTensor
except ImportError:
_TEGroupedTensor = None


def is_te_grouped_tensor(tensor: Any) -> bool:
"""Return whether a tensor is Transformer Engine's GroupedTensor wrapper."""
return _TEGroupedTensor is not None and isinstance(tensor, _TEGroupedTensor)


def te_grouped_tensor_prepare_for_saving(tensor: torch.Tensor):
"""Use TE's canonical GroupedTensor saved-buffer layout."""
if not is_te_grouped_tensor(tensor):
raise TypeError(f"Expected TE GroupedTensor, got {type(tensor).__name__}")
return tensor.prepare_for_saving()


def te_grouped_tensor_restore_from_saved(tensor_obj: Any, tensors: list[Optional[torch.Tensor]]):
"""Restore a TE GroupedTensor from buffers produced by prepare_for_saving."""
tensor_obj.restore_from_saved(list(tensors))
return tensor_obj


try:
# pylint: disable=unused-import
from transformer_engine.pytorch import cpu_offload_v1 as cpu_offload
Expand Down
146 changes: 127 additions & 19 deletions megatron/core/pipeline_parallel/fine_grained_activation_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,17 @@
DEBUG = False
DEBUG_RANK = 0

from megatron.core.extensions.transformer_engine import (
is_te_grouped_tensor,
te_grouped_tensor_prepare_for_saving,
te_grouped_tensor_restore_from_saved,
)
from megatron.core.transformer.cuda_graphs import is_graph_capturing
from megatron.core.utils import nvtx_range_pop, nvtx_range_push

_TE_GROUPED_TENSOR_STATE = "te_grouped_tensor"
_TE_GROUPED_TENSOR_RESIDENT_BUFFER_STATE = "te_grouped_tensor_resident_buffer"


def debug_rank(message):
"""Print debug message for a specific rank when DEBUG is enabled."""
Expand Down Expand Up @@ -95,6 +103,48 @@ def print_offload_summary_table(total_offload_bytes: Dict[str, int]):
torch.distributed.barrier()


def _tensor_allows_offloading(tensor):
"""Return whether a tensor's optional offload preference allows offloading."""
return not (hasattr(tensor, "offloading_activation") and not tensor.offloading_activation)


def _regular_tensor_needs_offloading(tensor, min_offloaded_tensor_size):
"""Return whether a regular tensor meets the offload policy."""
return tensor.numel() >= min_offloaded_tensor_size and _tensor_allows_offloading(tensor)


def _is_te_grouped_tensor_state(state):
"""Return whether state contains offloaded TE GroupedTensor buffers."""
return isinstance(state, tuple) and len(state) > 0 and state[0] == _TE_GROUPED_TENSOR_STATE


def _te_grouped_tensor_state_offload_nbytes(state):
"""Return bytes actually offloaded for a TE GroupedTensor state."""
return state[3]


def _record_tensor_stream(tensor, stream):
"""Record stream usage for regular tensors.

TE GroupedTensor buffers are recorded individually when the selected backing buffers are
offloaded.
"""
if is_te_grouped_tensor(tensor):
return
tensor.record_stream(stream)


def _release_tensor_storage(tensor):
"""Release storage for regular tensors.

TE GroupedTensor wrappers do not own a single storage to release here; their backing buffers
are either offloaded individually or left resident.
"""
if is_te_grouped_tensor(tensor):
return
tensor.untyped_storage().resize_(0)


class GPUTensorPool:
"""
GPU memory pool for efficient allocation and deallocation of tensors.
Expand Down Expand Up @@ -343,9 +393,8 @@ def __init__(self, name):
self.total_offload_bytes = 0
self.total_tensor_count = 0
# Using memory pool is for the compatibility with cuda graph.
# Shapes of tensors for expert_fc1 and moe_act are not known in advance,
# so we do not use CPU pool for them.
if name == "expert_fc1" or name == "moe_act":
# Shapes of MoE expert tensors depend on token routing, so we do not use CPU pool.
if name in {"expert_fc1", "moe_act", "group_mlp"}:
self.use_cpu_pool = False
else:
self.use_cpu_pool = True
Expand Down Expand Up @@ -374,9 +423,14 @@ def wait_reload_event(self, stream):
"""Wait for the reload event."""
stream.wait_event(self._reload_event)

def update_offload_info(self, tensor):
def update_offload_info(self, tensor_or_state):
"""Update the offload information."""
self.total_offload_bytes += tensor.numel() * tensor.element_size()
if _is_te_grouped_tensor_state(tensor_or_state):
self.total_offload_bytes += _te_grouped_tensor_state_offload_nbytes(tensor_or_state)
elif is_te_grouped_tensor(tensor_or_state):
self.total_offload_bytes += 0
else:
self.total_offload_bytes += tensor_or_state.numel() * tensor_or_state.element_size()
self.total_tensor_count += 1


Expand Down Expand Up @@ -731,8 +785,8 @@ class ChunkOffloadHandler:
Manages tensor groups, coordinates asynchronous GPU-CPU transfers, and handles synchronization.
"""

def offload(self, src_tensor, pin_memory=True, use_cpu_pool=True):
"""Offload."""
def _offload_tensor(self, src_tensor, pin_memory=True, use_cpu_pool=True):
"""Offload a regular torch.Tensor."""
debug_rank("--------offload")

if not src_tensor.is_contiguous():
Expand All @@ -749,8 +803,8 @@ def offload(self, src_tensor, pin_memory=True, use_cpu_pool=True):
state = (src_tensor.device, cpu_backup, use_cpu_pool)
return state

def reload(self, state, non_blocking=None):
"""Reload."""
def _reload_tensor(self, state, non_blocking=None):
"""Reload a regular torch.Tensor."""
debug_rank("------reload")
dev, cpu_backup, use_cpu_pool = state
if non_blocking is None:
Expand All @@ -763,6 +817,61 @@ def reload(self, state, non_blocking=None):
self.cpu_tensor_pool.free(cpu_backup)
return gpu_tensor

def _offload_te_grouped_tensor(self, src_tensor, pin_memory=True, use_cpu_pool=True):
"""Offload selected TE GroupedTensor backing buffers without stacking members."""
debug_rank("--------offload TE GroupedTensor")

saved_tensors, tensor_obj = te_grouped_tensor_prepare_for_saving(src_tensor)
buffer_states = []
offload_nbytes = 0
for buffer in saved_tensors:
if not isinstance(buffer, torch.Tensor):
buffer_states.append(None)
elif _regular_tensor_needs_offloading(buffer, self.min_offloaded_tensor_size):
buffer_states.append(
self._offload_tensor(buffer, pin_memory=pin_memory, use_cpu_pool=use_cpu_pool)
)
buffer.record_stream(self.d2h_stream)
offload_nbytes += buffer.numel() * buffer.element_size()
else:
buffer_states.append((_TE_GROUPED_TENSOR_RESIDENT_BUFFER_STATE, buffer))
Comment on lines +828 to +837

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: This implementation is correct, but we could make it a bit nicer by using a consistent structure for offloaded and non-offloaded buffers.

Suggested change
if not isinstance(buffer, torch.Tensor):
buffer_states.append(None)
elif _regular_tensor_needs_offloading(buffer, self.min_offloaded_tensor_size):
buffer_states.append(
self._offload_tensor(buffer, pin_memory=pin_memory, use_cpu_pool=use_cpu_pool)
)
buffer.record_stream(self.d2h_stream)
offload_nbytes += buffer.numel() * buffer.element_size()
else:
buffer_states.append((_TE_GROUPED_TENSOR_RESIDENT_BUFFER_STATE, buffer))
is_offloaded = (
buffer is not None
and _regular_tensor_needs_offloading(buffer, self.min_offloaded_tensor_size)
)
buffer_state = (
self._offload_tensor(buffer, pin_memory, use_cpu_pool=use_cpu_pool)
if is_offloaded
else buffer
)
buffer_states.append((is_offloaded, buffer_state))

The reloading function becomes:

    def _reload_te_grouped_tensor(self, state, non_blocking=None):
        """Reload TE GroupedTensor backing buffers and reconstruct the wrapper."""
        debug_rank("------reload TE GroupedTensor")
        _, tensor_obj, buffer_states, _ = state

        buffers = []
        for is_offloaded, buffer_state in buffer_states:
            if is_offloaded:
                buffers.append(self._reload_tensor(buffer_state, non_blocking=non_blocking))
            else:
                buffers.append(buffer_state)

        return te_grouped_tensor_restore_from_saved(tensor_obj, buffers)


if offload_nbytes == 0:
return src_tensor
Comment on lines +839 to +840

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure this is correct. GroupedTensor.prepare_for_saving removes all the buffers within the GroupedTensor, so we need to call restore_from_saved even if no buffers were moved to CPU.

return (_TE_GROUPED_TENSOR_STATE, tensor_obj, buffer_states, offload_nbytes)

def _reload_te_grouped_tensor(self, state, non_blocking=None):
"""Reload TE GroupedTensor backing buffers and reconstruct the wrapper."""
debug_rank("------reload TE GroupedTensor")
_, tensor_obj, buffer_states, _ = state

buffers = []
for buffer_state in buffer_states:
if buffer_state is None:
buffers.append(None)
elif (
isinstance(buffer_state, tuple)
and len(buffer_state) > 0
and buffer_state[0] == _TE_GROUPED_TENSOR_RESIDENT_BUFFER_STATE
):
buffers.append(buffer_state[1])
else:
buffers.append(self._reload_tensor(buffer_state, non_blocking=non_blocking))

return te_grouped_tensor_restore_from_saved(tensor_obj, buffers)

def offload(self, src_tensor, pin_memory=True, use_cpu_pool=True):
"""Offload a tensor-like saved activation."""
if is_te_grouped_tensor(src_tensor):
return self._offload_te_grouped_tensor(src_tensor, pin_memory, use_cpu_pool)
return self._offload_tensor(src_tensor, pin_memory, use_cpu_pool)
Comment on lines +863 to +867

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks correct, but I'm curious why we don't have similar logic for QuantizedTensor, which should have similar problems as GroupedTensor. If we already have TE infrastructure for CPU offloading for quantized tensors, then we should use that for grouped tensors instead of implementing custom logic in Mcore.


def reload(self, state, non_blocking=None):
"""Reload a tensor-like saved activation."""
if _is_te_grouped_tensor_state(state):
return self._reload_te_grouped_tensor(state, non_blocking)
return self._reload_tensor(state, non_blocking)

def __init__(self, min_offloaded_tensor_size, cpu_tensor_pool):
self.do_offload = True

Expand Down Expand Up @@ -863,12 +972,9 @@ def tensor_need_offloading_checker(self, tensor):
debug_rank(
f"tensor_need_offloading_checker {getattr(tensor, 'offloading_activation', None)}"
)
if tensor.numel() < self.min_offloaded_tensor_size:
return False
# Respect tensor's offload preference if specified
if hasattr(tensor, "offloading_activation") and not tensor.offloading_activation:
return False
return True
if is_te_grouped_tensor(tensor):
return _tensor_allows_offloading(tensor)
return _regular_tensor_needs_offloading(tensor, self.min_offloaded_tensor_size)

def bulk_offload_group(self):
"""offload a group of tensors recorded in tensor_push()."""
Expand All @@ -883,8 +989,10 @@ def bulk_offload_group(self):
tensor_on_device, use_cpu_pool=group_to_offload.use_cpu_pool
)
if self.is_warmup:
group_to_offload.update_offload_info(tensor_on_device)
tensor_on_device.record_stream(self.d2h_stream)
group_to_offload.update_offload_info(
state if is_te_grouped_tensor(tensor_on_device) else tensor_on_device
)
_record_tensor_stream(tensor_on_device, self.d2h_stream)
group_to_offload.push_tensor(tensor_tag, state)
group_to_offload.record_offload_event(self.d2h_stream)
self._groups_to_offload.pop()
Expand Down Expand Up @@ -964,8 +1072,8 @@ def bulk_offload(self, forced_released_tensors):
for release_tensor in forced_released_tensors:
if self.tensor_need_offloading_checker(release_tensor):
# Ensure tensor is not in use before freeing
release_tensor.record_stream(cur_stream)
release_tensor.untyped_storage().resize_(0)
_record_tensor_stream(release_tensor, cur_stream)
_release_tensor_storage(release_tensor)

def on_group_commit_forward(self, forced_released_tensors):
"""Called at the end of a layer group's forward pass to trigger offloading."""
Expand Down
18 changes: 15 additions & 3 deletions megatron/core/transformer/cuda_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,18 @@
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
from transformer_engine.pytorch.utils import make_weak_ref

def _set_skip_fp8_weight_update(value: bool):
"""Set TE's FP8 weight update skip flag across TE API versions."""
if hasattr(FP8GlobalStateManager, "set_skip_fp8_weight_update_tensor"):
FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(value)
return
qstate = FP8GlobalStateManager.quantization_state
if qstate.skip_fp8_weight_update_tensor is None:
qstate.skip_fp8_weight_update_tensor = torch.empty(
1, dtype=torch.float32, device="cuda"
)
qstate.skip_fp8_weight_update_tensor.fill_(value)

HAVE_TE_GRAPHS = True
except:
HAVE_TE_GRAPHS = False
Expand Down Expand Up @@ -608,7 +620,7 @@ def forward(ctx, runner, is_first_microbatch, *inputs):
# Note that FP8GlobalStateManager.is_first_fp8_module() is inacccurate as each
# layer may be in its own fp8 context, when the fp8 recipe != delayed_scaling
if runner.is_first_layer and (runner.fp8_param_cache_updated != is_first_microbatch):
FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(not is_first_microbatch)
_set_skip_fp8_weight_update(not is_first_microbatch)
runner.fp8_param_cache_updated = is_first_microbatch

runner.fwd_graph.replay()
Expand Down Expand Up @@ -738,13 +750,13 @@ def __init__(

if self.fp8_enabled:
self.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(False)
_set_skip_fp8_weight_update(False)

if self.fp4_enabled:
from megatron.core.fp4_utils import get_fp4_recipe # to avoid circular import

self.fp4_recipe = get_fp4_recipe(self.base_module.config)
FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(False)
_set_skip_fp8_weight_update(False)

def __str__(self):
return "%s; hid %s" % (
Expand Down
6 changes: 5 additions & 1 deletion megatron/core/transformer/moe/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -386,9 +386,13 @@ Unlike recomputation (which trades compute for memory), offloading trades **GPU-
**Usage**
```bash
--fine-grained-activation-offloading
--offload-modules expert_fc1 moe_act # Choices: attn_norm, core_attn, attn_proj, mlp_norm, expert_fc1, moe_act
--offload-modules expert_fc1 moe_act # Choices: attn_norm, core_attn, attn_proj, mlp_norm, expert_fc1, moe_act, group_mlp
```

Use `group_mlp` instead of `expert_fc1`/`moe_act` when `--use-transformer-engine-op-fuser`
is enabled for TE GroupedMLP; the fused path offloads saved activations for the whole
GroupedMLP as one group.

For more details, see `docs/user-guide/features/fine_grained_activation_offloading.md`

### Communication Optimization
Expand Down
30 changes: 23 additions & 7 deletions megatron/core/transformer/moe/experts.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,11 @@ def __init__(
and "moe_act" in self.config.offload_modules
)

self.offload_group_mlp = (
self.config.fine_grained_activation_offloading
and "group_mlp" in self.config.offload_modules
)

self.activation_recompute = (
self.config.recompute_granularity == 'selective'
and "moe_act" in self.config.recompute_modules
Expand Down Expand Up @@ -314,7 +319,7 @@ def _is_fused_impl_supported(self) -> bool:
if self.tp_group.size() > 1:
return False # Tensor parallelism is not supported
if self.offload_expert_fc1 or self.offload_moe_act:
return False # Fine-grained activation offloading is not supported
return False # expert_fc1/moe_act offloading needs non-fused boundaries
if self.config.moe_apply_probs_on_input:
return False # Pre-multiplying probs is not supported

Expand Down Expand Up @@ -528,13 +533,24 @@ def _fused_forward(
tokens_per_expert, dtype=torch.int, device=permuted_probs.device
)

# Call fused impl
output = ops(
permuted_local_hidden_states,
tokens_per_expert, # FC1
permuted_probs, # Scaled SwiGLU
tokens_per_expert, # FC2
group_mlp_manager = off_interface(
self.offload_group_mlp, permuted_local_hidden_states, "group_mlp"
)
with group_mlp_manager as permuted_local_hidden_states:
# Call fused impl. With group_mlp offload enabled, TE op-fuser saved tensors
# are captured by the active saved_tensors_hooks and offloaded as one group.
output = ops(
permuted_local_hidden_states,
tokens_per_expert, # FC1
permuted_probs, # Scaled SwiGLU/SReLU
tokens_per_expert, # FC2
)
if self.offload_group_mlp:
output = off_interface.group_commit(
output,
name="group_mlp",
forced_released_tensors=[permuted_local_hidden_states],
)

# Remove padding if needed
if unpadded_tokens_per_expert is not None:
Expand Down
Loading