Skip to content

Commit c9e5348

Browse files
committed
Handle SymbolicStaticFunction signature check in fleet.recompute
1 parent 7b673ca commit c9e5348

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

python/paddle/distributed/fleet/recompute/recompute.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
get_rng_state_tracker,
3232
)
3333
from paddle.framework import core, in_dynamic_mode
34+
from paddle.jit.dy2static.program_translator import SymbolicStaticFunction
3435

3536
from ..utils.log_util import logger
3637

@@ -685,10 +686,14 @@ def recompute(function, *args, **kwargs):
685686
offload_indices = kwargs.pop('offload_indices', [])
686687
input_args = []
687688
# rearrange `position-args + keyword-args` into `position-args`
688-
if isinstance(function, paddle.nn.Layer):
689-
dyfunc_sig = inspect.signature(function.forward)
690-
else:
691-
dyfunc_sig = inspect.signature(function)
689+
target = (
690+
function.forward
691+
if isinstance(function, paddle.nn.Layer)
692+
else function
693+
)
694+
if isinstance(target, SymbolicStaticFunction):
695+
target = target.dygraph_function
696+
dyfunc_sig = inspect.signature(target)
692697

693698
bound_args = dyfunc_sig.bind(*args, **kwargs)
694699
bound_args.apply_defaults()

0 commit comments

Comments
 (0)