Skip to content

Commit 444a0ed

Browse files
authored
Avoid one time clone to save memory peak (#17934)
### Avoid one more time clone to save memory peak
1 parent 009cd4e commit 444a0ed

1 file changed

Lines changed: 32 additions & 23 deletions

File tree

orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py

Lines changed: 32 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,8 @@ def _process_inplace_outputs(
245245

246246
if not copied:
247247
# Only need a copy once.
248+
# Inplace copy only happens for non-leaf variables, so we have to set requires_grad to False.
249+
raw_input_tensor.requires_grad = False
248250
raw_input_tensor.copy_(all_outputs_of_kernel_run[output_index])
249251
_log_warning(
250252
f"{log_prefix}Copy output tensor {output_index} to raw input tensor {raw_tensor_input_index}. "
@@ -449,7 +451,8 @@ def call_python_forward_function(
449451
try:
450452
func_name = func_name.decode("utf-8") if isinstance(func_name, bytes) else func_name
451453
# If this is the first time run, collect runtime tensor reuse mapping.
452-
if kernel_invoke_id not in _GlobalOpKernelInfoMap:
454+
is_first_time_run = kernel_invoke_id not in _GlobalOpKernelInfoMap
455+
if is_first_time_run:
453456
kernel_info = CustomFuncOpKernelInfo(kernel_invoke_id)
454457
_GlobalOpKernelInfoMap[kernel_invoke_id] = kernel_info
455458

@@ -473,36 +476,42 @@ def call_python_forward_function(
473476
if tensor_input_index in inplace_map:
474477
raw_input_tensors_used_inplace[tensor_input_index] = wrapped_arg
475478

479+
# Only requires gradient when running under training mode
480+
# and the associated tensor has grad_flag=True (i.e.,
481+
# "requires_grad=True" in the original PyTorch script).
482+
wrapped_arg.requires_grad = is_training_mode and grad_flag
483+
476484
# Note1:
477485
# If it's first-time kernel invocation, tensor_input_indices_to_save_in_ctx is None, we do the
478486
# copy for all tensors. Otherwise, we only copy the tensors whose indices are in
479487
# tensor_input_indices_to_save_in_ctx.
480488
# Note2:
481489
# For inference mode, we don't need to do the copy because ctx will be None,
482490
# so nothing will be saved for ctx.
483-
if is_training_mode and (
484-
tensor_input_indices_to_save_in_ctx is None
485-
or tensor_input_index in tensor_input_indices_to_save_in_ctx
486-
):
487-
wrapped_arg = wrapped_arg.detach().clone()
488-
489-
# Only requires gradient when running under training mode
490-
# and the associated tensor has grad_flag=True (i.e.,
491-
# "requires_grad=True" in the original PyTorch script).
492-
wrapped_arg.requires_grad = is_training_mode and grad_flag
493-
494491
# Note3:
495-
# If it's not first-time kernel invocation, tensor_input_indices_for_mark_dirty is None, we do the
496-
# mul for all tensors. Otherwise, we only mul by one for the tensors whose indices are in
497-
# tensor_input_indices_for_mark_dirty.
498-
if is_training_mode and (
499-
tensor_input_indices_for_mark_dirty is None
500-
or tensor_input_index in tensor_input_indices_for_mark_dirty
501-
):
502-
# To fix this issue:
503-
# "a leaf Variable that requires grad has been used in an in-place operation."
504-
with torch.set_grad_enabled(True):
505-
wrapped_arg = wrapped_arg.clone()
492+
# To fix this issue:
493+
# "a leaf Variable that requires grad has been used in an in-place operation."
494+
# If it's first-time kernel invocation, tensor_input_indices_for_mark_dirty is None, we do the
495+
# copy for all tensors to generate grad for it. Otherwise, we only clone (to generate grad) for
496+
# the tensors whose indices are in tensor_input_indices_for_mark_dirty.
497+
if is_training_mode:
498+
if is_first_time_run:
499+
with torch.set_grad_enabled(True):
500+
wrapped_arg = wrapped_arg.clone()
501+
else:
502+
is_input_index_saved_in_ctx = (
503+
tensor_input_indices_to_save_in_ctx is None
504+
or tensor_input_index in tensor_input_indices_to_save_in_ctx
505+
)
506+
is_input_index_marked_dirty = (
507+
tensor_input_indices_for_mark_dirty is None
508+
or tensor_input_index in tensor_input_indices_for_mark_dirty
509+
)
510+
if is_input_index_saved_in_ctx or is_input_index_marked_dirty:
511+
# when with grad, the leaf tensor after clone will not be leaf.
512+
with torch.set_grad_enabled(is_input_index_marked_dirty):
513+
wrapped_arg = wrapped_arg.clone()
514+
wrapped_arg.requires_grad = is_training_mode and grad_flag
506515

507516
wrapped_args.append(wrapped_arg)
508517
input_tensors_used_for_fw_run[tensor_input_index] = wrapped_arg

0 commit comments

Comments
 (0)