-
Notifications
You must be signed in to change notification settings - Fork 0
Support group_mlp offload for TE op fuser #1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: srelu-fused-grouped-mlp
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,9 +10,17 @@ | |
| DEBUG = False | ||
| DEBUG_RANK = 0 | ||
|
|
||
| from megatron.core.extensions.transformer_engine import ( | ||
| is_te_grouped_tensor, | ||
| te_grouped_tensor_prepare_for_saving, | ||
| te_grouped_tensor_restore_from_saved, | ||
| ) | ||
| from megatron.core.transformer.cuda_graphs import is_graph_capturing | ||
| from megatron.core.utils import nvtx_range_pop, nvtx_range_push | ||
|
|
||
| _TE_GROUPED_TENSOR_STATE = "te_grouped_tensor" | ||
| _TE_GROUPED_TENSOR_RESIDENT_BUFFER_STATE = "te_grouped_tensor_resident_buffer" | ||
|
|
||
|
|
||
| def debug_rank(message): | ||
| """Print debug message for a specific rank when DEBUG is enabled.""" | ||
|
|
@@ -95,6 +103,48 @@ def print_offload_summary_table(total_offload_bytes: Dict[str, int]): | |
| torch.distributed.barrier() | ||
|
|
||
|
|
||
| def _tensor_allows_offloading(tensor): | ||
| """Return whether a tensor's optional offload preference allows offloading.""" | ||
| return not (hasattr(tensor, "offloading_activation") and not tensor.offloading_activation) | ||
|
|
||
|
|
||
| def _regular_tensor_needs_offloading(tensor, min_offloaded_tensor_size): | ||
| """Return whether a regular tensor meets the offload policy.""" | ||
| return tensor.numel() >= min_offloaded_tensor_size and _tensor_allows_offloading(tensor) | ||
|
|
||
|
|
||
| def _is_te_grouped_tensor_state(state): | ||
| """Return whether state contains offloaded TE GroupedTensor buffers.""" | ||
| return isinstance(state, tuple) and len(state) > 0 and state[0] == _TE_GROUPED_TENSOR_STATE | ||
|
|
||
|
|
||
| def _te_grouped_tensor_state_offload_nbytes(state): | ||
| """Return bytes actually offloaded for a TE GroupedTensor state.""" | ||
| return state[3] | ||
|
|
||
|
|
||
| def _record_tensor_stream(tensor, stream): | ||
| """Record stream usage for regular tensors. | ||
|
|
||
| TE GroupedTensor buffers are recorded individually when the selected backing buffers are | ||
| offloaded. | ||
| """ | ||
| if is_te_grouped_tensor(tensor): | ||
| return | ||
| tensor.record_stream(stream) | ||
|
|
||
|
|
||
| def _release_tensor_storage(tensor): | ||
| """Release storage for regular tensors. | ||
|
|
||
| TE GroupedTensor wrappers do not own a single storage to release here; their backing buffers | ||
| are either offloaded individually or left resident. | ||
| """ | ||
| if is_te_grouped_tensor(tensor): | ||
| return | ||
| tensor.untyped_storage().resize_(0) | ||
|
|
||
|
|
||
| class GPUTensorPool: | ||
| """ | ||
| GPU memory pool for efficient allocation and deallocation of tensors. | ||
|
|
@@ -343,9 +393,8 @@ def __init__(self, name): | |
| self.total_offload_bytes = 0 | ||
| self.total_tensor_count = 0 | ||
| # Using memory pool is for the compatibility with cuda graph. | ||
| # Shapes of tensors for expert_fc1 and moe_act are not known in advance, | ||
| # so we do not use CPU pool for them. | ||
| if name == "expert_fc1" or name == "moe_act": | ||
| # Shapes of MoE expert tensors depend on token routing, so we do not use CPU pool. | ||
| if name in {"expert_fc1", "moe_act", "group_mlp"}: | ||
| self.use_cpu_pool = False | ||
| else: | ||
| self.use_cpu_pool = True | ||
|
|
@@ -374,9 +423,14 @@ def wait_reload_event(self, stream): | |
| """Wait for the reload event.""" | ||
| stream.wait_event(self._reload_event) | ||
|
|
||
| def update_offload_info(self, tensor): | ||
| def update_offload_info(self, tensor_or_state): | ||
| """Update the offload information.""" | ||
| self.total_offload_bytes += tensor.numel() * tensor.element_size() | ||
| if _is_te_grouped_tensor_state(tensor_or_state): | ||
| self.total_offload_bytes += _te_grouped_tensor_state_offload_nbytes(tensor_or_state) | ||
| elif is_te_grouped_tensor(tensor_or_state): | ||
| self.total_offload_bytes += 0 | ||
| else: | ||
| self.total_offload_bytes += tensor_or_state.numel() * tensor_or_state.element_size() | ||
| self.total_tensor_count += 1 | ||
|
|
||
|
|
||
|
|
@@ -731,8 +785,8 @@ class ChunkOffloadHandler: | |
| Manages tensor groups, coordinates asynchronous GPU-CPU transfers, and handles synchronization. | ||
| """ | ||
|
|
||
| def offload(self, src_tensor, pin_memory=True, use_cpu_pool=True): | ||
| """Offload.""" | ||
| def _offload_tensor(self, src_tensor, pin_memory=True, use_cpu_pool=True): | ||
| """Offload a regular torch.Tensor.""" | ||
| debug_rank("--------offload") | ||
|
|
||
| if not src_tensor.is_contiguous(): | ||
|
|
@@ -749,8 +803,8 @@ def offload(self, src_tensor, pin_memory=True, use_cpu_pool=True): | |
| state = (src_tensor.device, cpu_backup, use_cpu_pool) | ||
| return state | ||
|
|
||
| def reload(self, state, non_blocking=None): | ||
| """Reload.""" | ||
| def _reload_tensor(self, state, non_blocking=None): | ||
| """Reload a regular torch.Tensor.""" | ||
| debug_rank("------reload") | ||
| dev, cpu_backup, use_cpu_pool = state | ||
| if non_blocking is None: | ||
|
|
@@ -763,6 +817,61 @@ def reload(self, state, non_blocking=None): | |
| self.cpu_tensor_pool.free(cpu_backup) | ||
| return gpu_tensor | ||
|
|
||
| def _offload_te_grouped_tensor(self, src_tensor, pin_memory=True, use_cpu_pool=True): | ||
| """Offload selected TE GroupedTensor backing buffers without stacking members.""" | ||
| debug_rank("--------offload TE GroupedTensor") | ||
|
|
||
| saved_tensors, tensor_obj = te_grouped_tensor_prepare_for_saving(src_tensor) | ||
| buffer_states = [] | ||
| offload_nbytes = 0 | ||
| for buffer in saved_tensors: | ||
| if not isinstance(buffer, torch.Tensor): | ||
| buffer_states.append(None) | ||
| elif _regular_tensor_needs_offloading(buffer, self.min_offloaded_tensor_size): | ||
| buffer_states.append( | ||
| self._offload_tensor(buffer, pin_memory=pin_memory, use_cpu_pool=use_cpu_pool) | ||
| ) | ||
| buffer.record_stream(self.d2h_stream) | ||
| offload_nbytes += buffer.numel() * buffer.element_size() | ||
| else: | ||
| buffer_states.append((_TE_GROUPED_TENSOR_RESIDENT_BUFFER_STATE, buffer)) | ||
|
|
||
| if offload_nbytes == 0: | ||
| return src_tensor | ||
|
Comment on lines
+839
to
+840
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure this is correct. |
||
| return (_TE_GROUPED_TENSOR_STATE, tensor_obj, buffer_states, offload_nbytes) | ||
|
|
||
| def _reload_te_grouped_tensor(self, state, non_blocking=None): | ||
| """Reload TE GroupedTensor backing buffers and reconstruct the wrapper.""" | ||
| debug_rank("------reload TE GroupedTensor") | ||
| _, tensor_obj, buffer_states, _ = state | ||
|
|
||
| buffers = [] | ||
| for buffer_state in buffer_states: | ||
| if buffer_state is None: | ||
| buffers.append(None) | ||
| elif ( | ||
| isinstance(buffer_state, tuple) | ||
| and len(buffer_state) > 0 | ||
| and buffer_state[0] == _TE_GROUPED_TENSOR_RESIDENT_BUFFER_STATE | ||
| ): | ||
| buffers.append(buffer_state[1]) | ||
| else: | ||
| buffers.append(self._reload_tensor(buffer_state, non_blocking=non_blocking)) | ||
|
|
||
| return te_grouped_tensor_restore_from_saved(tensor_obj, buffers) | ||
|
|
||
| def offload(self, src_tensor, pin_memory=True, use_cpu_pool=True): | ||
| """Offload a tensor-like saved activation.""" | ||
| if is_te_grouped_tensor(src_tensor): | ||
| return self._offload_te_grouped_tensor(src_tensor, pin_memory, use_cpu_pool) | ||
| return self._offload_tensor(src_tensor, pin_memory, use_cpu_pool) | ||
|
Comment on lines
+863
to
+867
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This looks correct, but I'm curious why we don't have similar logic for |
||
|
|
||
| def reload(self, state, non_blocking=None): | ||
| """Reload a tensor-like saved activation.""" | ||
| if _is_te_grouped_tensor_state(state): | ||
| return self._reload_te_grouped_tensor(state, non_blocking) | ||
| return self._reload_tensor(state, non_blocking) | ||
|
|
||
| def __init__(self, min_offloaded_tensor_size, cpu_tensor_pool): | ||
| self.do_offload = True | ||
|
|
||
|
|
@@ -863,12 +972,9 @@ def tensor_need_offloading_checker(self, tensor): | |
| debug_rank( | ||
| f"tensor_need_offloading_checker {getattr(tensor, 'offloading_activation', None)}" | ||
| ) | ||
| if tensor.numel() < self.min_offloaded_tensor_size: | ||
| return False | ||
| # Respect tensor's offload preference if specified | ||
| if hasattr(tensor, "offloading_activation") and not tensor.offloading_activation: | ||
| return False | ||
| return True | ||
| if is_te_grouped_tensor(tensor): | ||
| return _tensor_allows_offloading(tensor) | ||
| return _regular_tensor_needs_offloading(tensor, self.min_offloaded_tensor_size) | ||
|
|
||
| def bulk_offload_group(self): | ||
| """offload a group of tensors recorded in tensor_push().""" | ||
|
|
@@ -883,8 +989,10 @@ def bulk_offload_group(self): | |
| tensor_on_device, use_cpu_pool=group_to_offload.use_cpu_pool | ||
| ) | ||
| if self.is_warmup: | ||
| group_to_offload.update_offload_info(tensor_on_device) | ||
| tensor_on_device.record_stream(self.d2h_stream) | ||
| group_to_offload.update_offload_info( | ||
| state if is_te_grouped_tensor(tensor_on_device) else tensor_on_device | ||
| ) | ||
| _record_tensor_stream(tensor_on_device, self.d2h_stream) | ||
| group_to_offload.push_tensor(tensor_tag, state) | ||
| group_to_offload.record_offload_event(self.d2h_stream) | ||
| self._groups_to_offload.pop() | ||
|
|
@@ -964,8 +1072,8 @@ def bulk_offload(self, forced_released_tensors): | |
| for release_tensor in forced_released_tensors: | ||
| if self.tensor_need_offloading_checker(release_tensor): | ||
| # Ensure tensor is not in use before freeing | ||
| release_tensor.record_stream(cur_stream) | ||
| release_tensor.untyped_storage().resize_(0) | ||
| _record_tensor_stream(release_tensor, cur_stream) | ||
| _release_tensor_storage(release_tensor) | ||
|
|
||
| def on_group_commit_forward(self, forced_released_tensors): | ||
| """Called at the end of a layer group's forward pass to trigger offloading.""" | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: This implementation is correct, but we could make it a bit nicer by using a consistent structure for offloaded and non-offloaded buffers.
The reloading function becomes: