We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 666a761 commit 00f4598Copy full SHA for 00f4598
orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py
@@ -508,8 +508,10 @@ def call_python_forward_function(
508
or tensor_input_index in tensor_input_indices_for_mark_dirty
509
)
510
if is_input_index_saved_in_ctx or is_input_index_marked_dirty:
511
+ # when with grad, the leaf tensor after clone will not be leaf.
512
with torch.set_grad_enabled(is_input_index_marked_dirty):
513
wrapped_arg = wrapped_arg.clone()
514
+ wrapped_arg.requires_grad = is_training_mode and grad_flag
515
516
wrapped_args.append(wrapped_arg)
517
input_tensors_used_for_fw_run[tensor_input_index] = wrapped_arg
0 commit comments