@@ -245,6 +245,8 @@ def _process_inplace_outputs(
245245
246246 if not copied :
247247 # Only need a copy once.
248+ # Inplace copy only happens for non-leaf variables, so we have to set requires_grad to False.
249+ raw_input_tensor .requires_grad = False
248250 raw_input_tensor .copy_ (all_outputs_of_kernel_run [output_index ])
249251 _log_warning (
250252 f"{ log_prefix } Copy output tensor { output_index } to raw input tensor { raw_tensor_input_index } . "
@@ -449,7 +451,8 @@ def call_python_forward_function(
449451 try :
450452 func_name = func_name .decode ("utf-8" ) if isinstance (func_name , bytes ) else func_name
451453 # If this is the first time run, collect runtime tensor reuse mapping.
452- if kernel_invoke_id not in _GlobalOpKernelInfoMap :
454+ is_first_time_run = kernel_invoke_id not in _GlobalOpKernelInfoMap
455+ if is_first_time_run :
453456 kernel_info = CustomFuncOpKernelInfo (kernel_invoke_id )
454457 _GlobalOpKernelInfoMap [kernel_invoke_id ] = kernel_info
455458
@@ -473,36 +476,42 @@ def call_python_forward_function(
473476 if tensor_input_index in inplace_map :
474477 raw_input_tensors_used_inplace [tensor_input_index ] = wrapped_arg
475478
479+ # Only requires gradient when running under training mode
480+ # and the associated tensor has grad_flag=True (i.e.,
481+ # "requires_grad=True" in the original PyTorch script).
482+ wrapped_arg .requires_grad = is_training_mode and grad_flag
483+
476484 # Note1:
477485 # If it's first-time kernel invocation, tensor_input_indices_to_save_in_ctx is None, we do the
478486 # copy for all tensors. Otherwise, we only copy the tensors whose indices are in
479487 # tensor_input_indices_to_save_in_ctx.
480488 # Note2:
481489 # For inference mode, we don't need to do the copy because ctx will be None,
482490 # so nothing will be saved for ctx.
483- if is_training_mode and (
484- tensor_input_indices_to_save_in_ctx is None
485- or tensor_input_index in tensor_input_indices_to_save_in_ctx
486- ):
487- wrapped_arg = wrapped_arg .detach ().clone ()
488-
489- # Only requires gradient when running under training mode
490- # and the associated tensor has grad_flag=True (i.e.,
491- # "requires_grad=True" in the original PyTorch script).
492- wrapped_arg .requires_grad = is_training_mode and grad_flag
493-
494491 # Note3:
495- # If it's not first-time kernel invocation, tensor_input_indices_for_mark_dirty is None, we do the
496- # mul for all tensors. Otherwise, we only mul by one for the tensors whose indices are in
497- # tensor_input_indices_for_mark_dirty.
498- if is_training_mode and (
499- tensor_input_indices_for_mark_dirty is None
500- or tensor_input_index in tensor_input_indices_for_mark_dirty
501- ):
502- # To fix this issue:
503- # "a leaf Variable that requires grad has been used in an in-place operation."
504- with torch .set_grad_enabled (True ):
505- wrapped_arg = wrapped_arg .clone ()
492+ # To fix this issue:
493+ # "a leaf Variable that requires grad has been used in an in-place operation."
494+ # If it's first-time kernel invocation, tensor_input_indices_for_mark_dirty is None, we do the
495+ # copy for all tensors to generate grad for it. Otherwise, we only clone (to generate grad) for
496+ # the tensors whose indices are in tensor_input_indices_for_mark_dirty.
497+ if is_training_mode :
498+ if is_first_time_run :
499+ with torch .set_grad_enabled (True ):
500+ wrapped_arg = wrapped_arg .clone ()
501+ else :
502+ is_input_index_saved_in_ctx = (
503+ tensor_input_indices_to_save_in_ctx is None
504+ or tensor_input_index in tensor_input_indices_to_save_in_ctx
505+ )
506+ is_input_index_marked_dirty = (
507+ tensor_input_indices_for_mark_dirty is None
508+ or tensor_input_index in tensor_input_indices_for_mark_dirty
509+ )
510+ if is_input_index_saved_in_ctx or is_input_index_marked_dirty :
511+ # when with grad, the leaf tensor after clone will not be leaf.
512+ with torch .set_grad_enabled (is_input_index_marked_dirty ):
513+ wrapped_arg = wrapped_arg .clone ()
514+ wrapped_arg .requires_grad = is_training_mode and grad_flag
506515
507516 wrapped_args .append (wrapped_arg )
508517 input_tensors_used_for_fw_run [tensor_input_index ] = wrapped_arg
0 commit comments