Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,8 @@ def _process_inplace_outputs(

if not copied:
# Only need a copy once.
# Inplace copy only happens for non-leaf variables, so we have to set requires_grad to False.
raw_input_tensor.requires_grad = False
raw_input_tensor.copy_(all_outputs_of_kernel_run[output_index])
_log_warning(
f"{log_prefix}Copy output tensor {output_index} to raw input tensor {raw_tensor_input_index}. "
Expand Down Expand Up @@ -449,7 +451,8 @@ def call_python_forward_function(
try:
func_name = func_name.decode("utf-8") if isinstance(func_name, bytes) else func_name
# If this is the first time run, collect runtime tensor reuse mapping.
if kernel_invoke_id not in _GlobalOpKernelInfoMap:
is_first_time_run = kernel_invoke_id not in _GlobalOpKernelInfoMap
if is_first_time_run:
kernel_info = CustomFuncOpKernelInfo(kernel_invoke_id)
_GlobalOpKernelInfoMap[kernel_invoke_id] = kernel_info

Expand All @@ -473,36 +476,42 @@ def call_python_forward_function(
if tensor_input_index in inplace_map:
raw_input_tensors_used_inplace[tensor_input_index] = wrapped_arg

# Only requires gradient when running under training mode
# and the associated tensor has grad_flag=True (i.e.,
# "requires_grad=True" in the original PyTorch script).
wrapped_arg.requires_grad = is_training_mode and grad_flag

# Note1:
# If it's first-time kernel invocation, tensor_input_indices_to_save_in_ctx is None, we do the
# copy for all tensors. Otherwise, we only copy the tensors whose indices are in
# tensor_input_indices_to_save_in_ctx.
# Note2:
# For inference mode, we don't need to do the copy because ctx will be None,
# so nothing will be saved for ctx.
if is_training_mode and (
tensor_input_indices_to_save_in_ctx is None
or tensor_input_index in tensor_input_indices_to_save_in_ctx
):
wrapped_arg = wrapped_arg.detach().clone()

# Only requires gradient when running under training mode
# and the associated tensor has grad_flag=True (i.e.,
# "requires_grad=True" in the original PyTorch script).
wrapped_arg.requires_grad = is_training_mode and grad_flag

# Note3:
# If it's first-time kernel invocation, tensor_input_indices_for_mark_dirty is None, we do the
# clone for all tensors. Otherwise, we only clone the tensors whose indices are in
# tensor_input_indices_for_mark_dirty.
if is_training_mode and (
tensor_input_indices_for_mark_dirty is None
or tensor_input_index in tensor_input_indices_for_mark_dirty
):
# To fix this issue:
# "a leaf Variable that requires grad has been used in an in-place operation."
with torch.set_grad_enabled(True):
wrapped_arg = wrapped_arg.clone()
# To fix this issue:
# "a leaf Variable that requires grad has been used in an in-place operation."
# If it's the first-time kernel invocation, we clone every tensor with grad enabled so that
# the clone is tracked by autograd. Otherwise, we only clone the tensors whose indices are in
# tensor_input_indices_to_save_in_ctx or tensor_input_indices_for_mark_dirty (a None set is
# treated as containing every index), and grad is enabled for the clone only when the index
# is marked dirty.
if is_training_mode:
if is_first_time_run:
with torch.set_grad_enabled(True):
wrapped_arg = wrapped_arg.clone()
else:
is_input_index_saved_in_ctx = (
tensor_input_indices_to_save_in_ctx is None
or tensor_input_index in tensor_input_indices_to_save_in_ctx
)
is_input_index_marked_dirty = (
tensor_input_indices_for_mark_dirty is None
or tensor_input_index in tensor_input_indices_for_mark_dirty
)
if is_input_index_saved_in_ctx or is_input_index_marked_dirty:
# When grad is enabled, the clone of a leaf tensor is no longer a leaf.
with torch.set_grad_enabled(is_input_index_marked_dirty):
wrapped_arg = wrapped_arg.clone()
wrapped_arg.requires_grad = is_training_mode and grad_flag

wrapped_args.append(wrapped_arg)
input_tensors_used_for_fw_run[tensor_input_index] = wrapped_arg
Expand Down