Avoid one time clone to save memory peak #17934

Merged
merged 4 commits on Oct 21, 2023

Changes from 2 commits
@@ -245,6 +245,8 @@ def _process_inplace_outputs(

if not copied:
# Only need a copy once.
# Inplace copy only happens for non-leaf variables, so we have to set requires_grad to False.
raw_input_tensor.requires_grad = False
raw_input_tensor.copy_(all_outputs_of_kernel_run[output_index])
_log_warning(
f"{log_prefix}Copy output tensor {output_index} to raw input tensor {raw_tensor_input_index}. "
@@ -449,7 +451,8 @@ def call_python_forward_function(
try:
func_name = func_name.decode("utf-8") if isinstance(func_name, bytes) else func_name
# If this is the first time run, collect runtime tensor reuse mapping.
if kernel_invoke_id not in _GlobalOpKernelInfoMap:
is_first_time_run = kernel_invoke_id not in _GlobalOpKernelInfoMap
if is_first_time_run:
kernel_info = CustomFuncOpKernelInfo(kernel_invoke_id)
_GlobalOpKernelInfoMap[kernel_invoke_id] = kernel_info
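
The is_first_time_run flag introduced above keys a per-kernel cache: runtime metadata is collected only on a kernel's first invocation and reused afterwards. A minimal sketch of that pattern, with hypothetical stand-ins KernelInfo and _kernel_info_map for CustomFuncOpKernelInfo and _GlobalOpKernelInfoMap:

from typing import Dict, List, Optional, Tuple

class KernelInfo:  # hypothetical stand-in for CustomFuncOpKernelInfo
    def __init__(self, invoke_id: str):
        self.invoke_id = invoke_id
        # Populated during the first run, then reused on later runs.
        self.tensor_input_indices_to_save_in_ctx: Optional[List[int]] = None
        self.tensor_input_indices_for_mark_dirty: Optional[List[int]] = None

_kernel_info_map: Dict[str, KernelInfo] = {}  # stand-in for _GlobalOpKernelInfoMap

def lookup_kernel_info(kernel_invoke_id: str) -> Tuple[KernelInfo, bool]:
    # On the first invocation of a kernel, create and register its info object;
    # later invocations reuse the cached entry and skip the first-run bookkeeping.
    is_first_time_run = kernel_invoke_id not in _kernel_info_map
    if is_first_time_run:
        _kernel_info_map[kernel_invoke_id] = KernelInfo(kernel_invoke_id)
    return _kernel_info_map[kernel_invoke_id], is_first_time_run
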

@@ -473,36 +476,40 @@
if tensor_input_index in inplace_map:
raw_input_tensors_used_inplace[tensor_input_index] = wrapped_arg

# Only requires gradient when running under training mode
# and the associated tensor has grad_flag=True (i.e.,
# "requires_grad=True" in the original PyTorch script).
wrapped_arg.requires_grad = is_training_mode and grad_flag

# Note1:
# If it's the first-time kernel invocation, tensor_input_indices_to_save_in_ctx is None, so we do the
# copy for all tensors. Otherwise, we only copy the tensors whose indices are in
# tensor_input_indices_to_save_in_ctx.
# Note2:
# For inference mode, we don't need to do the copy because ctx will be None,
# so nothing will be saved for ctx.
if is_training_mode and (
tensor_input_indices_to_save_in_ctx is None
or tensor_input_index in tensor_input_indices_to_save_in_ctx
):
wrapped_arg = wrapped_arg.detach().clone()

# Only requires gradient when running under training mode
# and the associated tensor has grad_flag=True (i.e.,
# "requires_grad=True" in the original PyTorch script).
wrapped_arg.requires_grad = is_training_mode and grad_flag

# Note3:
# If it's the first-time kernel invocation, tensor_input_indices_for_mark_dirty is None, so we do the
# mul for all tensors. Otherwise, we only mul by one for the tensors whose indices are in
# tensor_input_indices_for_mark_dirty.
if is_training_mode and (
tensor_input_indices_for_mark_dirty is None
or tensor_input_index in tensor_input_indices_for_mark_dirty
):
# To fix this issue:
# "a leaf Variable that requires grad has been used in an in-place operation."
with torch.set_grad_enabled(True):
wrapped_arg = wrapped_arg.clone()
# To fix this issue:
# "a leaf Variable that requires grad has been used in an in-place operation."
# If it's the first-time kernel invocation, tensor_input_indices_for_mark_dirty is None, so we
# clone all tensors to make grad generation possible. Otherwise, we only clone (to generate grad)
# the tensors whose indices are in tensor_input_indices_for_mark_dirty.
if is_training_mode:
if is_first_time_run:
with torch.set_grad_enabled(True):
wrapped_arg = wrapped_arg.clone()
else:
is_input_index_saved_in_ctx = (
tensor_input_indices_to_save_in_ctx is None
or tensor_input_index in tensor_input_indices_to_save_in_ctx
)
is_input_index_marked_dirty = (
tensor_input_indices_for_mark_dirty is None
or tensor_input_index in tensor_input_indices_for_mark_dirty
)
if is_input_index_saved_in_ctx or is_input_index_marked_dirty:
with torch.set_grad_enabled(is_input_index_marked_dirty):
wrapped_arg = wrapped_arg.clone()

wrapped_args.append(wrapped_arg)
input_tensors_used_for_fw_run[tensor_input_index] = wrapped_arg
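
Taken together, the new branch clones each training-mode tensor input at most once: on the first run everything is cloned under grad-enabled mode, and on later runs an input is cloned only if it is saved in ctx or marked dirty, with grad enabled only in the mark-dirty case. Below is a minimal before/after sketch of that wrapping logic, using hypothetical helper names wrap_input_old / wrap_input_new (not functions from the PR) and assuming t is a leaf tensor, as the ORT-provided raw inputs are; training mode is assumed throughout.

import torch

def wrap_input_old(t, requires_grad, save_in_ctx, mark_dirty):
    # Old path (sketch): a detached clone for ctx saving, then possibly a second
    # clone for mark_dirty. On the first run both branches fire for every tensor
    # input, so each input is copied twice and peak memory spikes.
    if save_in_ctx:
        t = t.detach().clone()
    t.requires_grad = requires_grad
    if mark_dirty:
        with torch.set_grad_enabled(True):
            t = t.clone()          # second copy of the same input
    return t

def wrap_input_new(t, requires_grad, save_in_ctx, mark_dirty):
    # New path (sketch): at most one clone per input; grad is enabled only when
    # the input will be marked dirty, so that clone is a non-leaf and
    # ctx.mark_dirty() does not hit the leaf-variable restriction.
    t.requires_grad = requires_grad
    if save_in_ctx or mark_dirty:
        with torch.set_grad_enabled(mark_dirty):
            t = t.clone()
    return t

On the first run the old path made two copies of every training-mode tensor input (the detach().clone() plus the mark-dirty clone); the new path makes one, which is the transient memory peak the PR title refers to.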