Skip to content

Commit 00f4598

Browse files
committed
fix
1 parent 666a761 commit 00f4598

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -508,8 +508,10 @@ def call_python_forward_function(
508508
or tensor_input_index in tensor_input_indices_for_mark_dirty
509509
)
510510
if is_input_index_saved_in_ctx or is_input_index_marked_dirty:
511+
# when with grad, the leaf tensor after clone will not be leaf.
511512
with torch.set_grad_enabled(is_input_index_marked_dirty):
512513
wrapped_arg = wrapped_arg.clone()
514+
wrapped_arg.requires_grad = is_training_mode and grad_flag
513515

514516
wrapped_args.append(wrapped_arg)
515517
input_tensors_used_for_fw_run[tensor_input_index] = wrapped_arg

0 commit comments

Comments
 (0)