@@ -312,46 +312,42 @@ def static_pylayer(forward_fn, inputs, backward_fn=None, name=None):
312312 Examples:
313313 .. code-block:: python
314314
315- >>> import paddle
316- >>> import numpy as np
317-
318- >>> paddle.enable_static()
319-
320- >>> def forward_fn(x):
321- ... return paddle.exp(x)
322-
323- >>> def backward_fn(dy):
324- ... return 2 * paddle.exp(dy)
325-
326- >>> main_program = paddle.static.Program()
327- >>> start_program = paddle.static.Program()
328-
329- >>> place = paddle.CPUPlace()
330- >>> exe = paddle.static.Executor(place)
331- >>> with paddle.static.program_guard(main_program, start_program):
332- ... data = paddle.static.data(name="X", shape=[None, 5], dtype="float32")
333- ... data.stop_gradient = False
334- ... ret = paddle.static.nn.static_pylayer(forward_fn, [data], backward_fn)
335- ... data_grad = paddle.static.gradients([ret], data)[0]
336-
337- >>> exe.run(start_program)
338- >>> x = np.array([[1.0, 2.0, 3.0, 4.0, 5.0]], dtype=np.float32) # type: ignore[var-annotated]
339- >>> x, x_grad, y = exe.run(
340- ... main_program,
341- ... feed={"X": x},
342- ... fetch_list=[
343- ... data.name,
344- ... data_grad.name,
345- ... ret.name
346- ... ],
347- ... )
348-
349- >>> print(x)
350- [[1. 2. 3. 4. 5.]]
351- >>> print(x_grad)
352- [[5.4365635 5.4365635 5.4365635 5.4365635 5.4365635]]
353- >>> print(y)
354- [[ 2.7182817 7.389056 20.085537 54.59815 148.41316 ]]
315+ >>> import paddle
316+ >>> import numpy as np
317+
318+ >>> paddle.enable_static()
319+
320+ >>> def forward_fn(x):
321+ ... return paddle.exp(x)
322+
323+ >>> def backward_fn(dy):
324+ ... return 2 * paddle.exp(dy)
325+
326+ >>> main_program = paddle.static.Program()
327+ >>> start_program = paddle.static.Program()
328+
329+ >>> place = paddle.CPUPlace()
330+ >>> exe = paddle.static.Executor(place)
331+ >>> with paddle.static.program_guard(main_program, start_program):
332+ ... data = paddle.static.data(name="X", shape=[None, 5], dtype="float32")
333+ ... data.stop_gradient = False
334+ ... ret = paddle.static.nn.static_pylayer(forward_fn, [data], backward_fn)
335+ ... data_grad = paddle.static.gradients([ret], data)[0]
336+
337+ >>> exe.run(start_program)
338+ >>> x = np.array([[1.0, 2.0, 3.0, 4.0, 5.0]], dtype=np.float32)
339+ >>> x, x_grad, y = exe.run(
340+ ... main_program,
341+ ... feed={"X": x},
342+ ... fetch_list=[data, data_grad, ret],
343+ ... )
344+
345+ >>> print(x)
346+ [[1. 2. 3. 4. 5.]]
347+ >>> print(x_grad)
348+ [[5.4365635 5.4365635 5.4365635 5.4365635 5.4365635]]
349+ >>> print(y)
350+ [[ 2.7182817 7.389056 20.085537 54.59815 148.41316 ]]
355351 """
356352 assert (
357353 in_dygraph_mode () is False
0 commit comments