Commit 7fd48d9

committed
support recompute's forward and backward in pipeline mode
1 parent 312fc19 commit 7fd48d9

File tree

1 file changed: +149 −14 lines changed


python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py

Lines changed: 149 additions & 14 deletions
@@ -22,10 +22,14 @@
 from enum import Enum
 from functools import partial
 from typing import Callable
+import numpy as np
+import random

 import paddle
 from paddle import framework
-
+from paddle.distributed.fleet.meta_parallel.parallel_layers.random import (
+    get_rng_state_tracker,
+)
 from ..meta_optimizers.dygraph_optimizer import HybridParallelOptimizer
 from ..utils import timer_helper as timer
 from ..utils.hybrid_parallel_util import (
@@ -38,6 +42,10 @@
 from ..utils.log_util import get_sync_logger, logger
 from .meta_parallel_base import MetaParallelBase
 from .parallel_layers.pp_layers import PipelineLayer
+from ..recompute.recompute import (
+    switch_rng_state_tracker,
+    detach_variable
+)

 _use_four_directions = os.environ.get(
     'PADDLE_USE_FOUR_DIRECTIONS_P2P', paddle.base.core.is_compiled_with_xpu()
@@ -495,6 +503,16 @@ def __init__(self, layers, hcg, strategy):
         # only support user hooks during training
         self.user_hooks_enabled = True

+        # overlap the next layer's recompute backward with this layer's recompute forward
+        self.recompute_overlap = True
+        # preserve = kwargs.pop('preserve_rng_state', True)
+        self.preserve_rng_state = True
+        # offload_indices = kwargs.pop('offload_indices', [])
+        self.offload_indices = []
+        self.custom_get_state_func = lambda x=None: None
+        self.custom_set_state_func = lambda x=None: None
+
+
     def register_hook(
         self, location: PipelineParallelMicroStepLocations, hook: Callable
     ):
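
As an editorial aside, not part of the commit: the new custom_get_state_func / custom_set_state_func attributes default to no-ops, and appear intended to let a model plug any extra randomness source that Paddle's own trackers do not cover into the save/replay cycle. A hypothetical wiring with a plain Python RNG could look like the sketch below; every name in it is illustrative.

import random

extra_rng = random.Random(1234)          # a randomness source Paddle knows nothing about

def custom_get_state_func():
    return extra_rng.getstate()

def custom_set_state_func(state):
    extra_rng.setstate(state)

saved = custom_get_state_func()          # captured alongside the forward pass
first_draw = extra_rng.random()          # training advances the stream ...
custom_set_state_func(saved)             # ... and the replay rewinds it
assert extra_rng.random() == first_draw  # the recompute sees the same draw
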
@@ -749,6 +767,90 @@ def _flush_records(self):
         ) as f:
             f.writelines(record + '\n' for record in self._records)
         self._records = []
+
+    def save_state(self, state_buffers):
+        state = {}
+        if self.preserve_rng_state:
+            state["fw_rng_state"] = paddle.get_rng_state()
+            state["fwd_rng_state_tracker"] = (
+                get_rng_state_tracker().get_states_tracker()
+            )
+            state["fwd_numpy_state"] = np.random.get_state()
+            state["fwd_random_state"] = random.getstate()
+            state["fwd_custom_state"] = self.custom_get_state_func()
+            state["custom_get_state_func"] = self.custom_get_state_func
+            state["custom_set_state_func"] = self.custom_set_state_func
+        tracer = framework._dygraph_tracer()
+        state["is_fw_autocast"] = (
+            False if tracer._amp_level == framework.core.AmpLevel.O0 else True
+        )
+        if tracer._amp_level == framework.core.AmpLevel.O2:
+            state["amp_level"] = 'O2'
+        elif tracer._amp_level in (framework.core.AmpLevel.O1, framework.core.AmpLevel.O0):
+            state["amp_level"] = 'O1'
+        else:
+            raise ValueError(f"unsupported amp level: {tracer._amp_level}")
+
+        if tracer._amp_dtype == 'float16':
+            state["amp_dtype"] = 'float16'
+        elif tracer._amp_dtype in ('bfloat16', 'float32'):
+            state["amp_dtype"] = 'bfloat16'
+        else:
+            raise ValueError(f"unsupported amp dtype: {tracer._amp_dtype}")
+        state["amp_white_list"], state["amp_black_list"] = tracer._get_amp_op_list()
+        state_buffers.append(state)
+
+    def load_state_and_forward(self, state, input_tensor):
+        inputs = list(input_tensor)
+        tensor_indices = state["tensor_indices"]
+        tensors = self.container
+        for i, idx in enumerate(tensor_indices):
+            inputs[idx] = (
+                tensors[i].to(
+                    paddle.base.framework._current_expected_place()
+                )
+                if i in state["offload_indices"]
+                else tensors[i]
+            )
+            if i in state["offload_indices"]:
+                inputs[idx].stop_gradient = tensors[i].stop_gradient
+        tracer = framework._dygraph_tracer()
+        tracer._has_grad = True
+
+        if state["preserve_rng_state"]:
+            with (
+                switch_rng_state_tracker(
+                    state["fw_rng_state"],
+                    state["fwd_rng_state_tracker"],
+                    state["fwd_numpy_state"],
+                    state["fwd_random_state"],
+                    state["fwd_custom_state"],
+                    state["custom_get_state_func"],
+                    state["custom_set_state_func"],
+                ),
+                paddle.amp.auto_cast(
+                    enable=state["is_fw_autocast"],
+                    custom_white_list=state["amp_white_list"],
+                    custom_black_list=state["amp_black_list"],
+                    level=state["amp_level"],
+                    dtype=state["amp_dtype"],
+                ),
+            ):
+                detached_inputs = detach_variable(tuple(inputs))
+                outputs = self._layers.forward(*detached_inputs)
+        else:
+            with paddle.amp.auto_cast(
+                enable=state["is_fw_autocast"],
+                custom_white_list=state["amp_white_list"],
+                custom_black_list=state["amp_black_list"],
+                level=state["amp_level"],
+                dtype=state["amp_dtype"],
+            ):
+                detached_inputs = detach_variable(tuple(inputs))
+                outputs = self._layers.forward(*detached_inputs)
+        return outputs
+
+

     def forward_backward_pipeline(
         self,
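
A minimal sketch, editorial and not part of the commit, of why save_state snapshots the RNG state before each pipeline forward: restoring that snapshot lets the later recompute pass reproduce the original forward exactly, even through stochastic layers such as dropout. Toy model and shapes; the real code also snapshots the numpy, python and tracker RNG streams plus the AMP configuration.

import paddle

net = paddle.nn.Sequential(
    paddle.nn.Linear(8, 8), paddle.nn.Dropout(0.5), paddle.nn.Linear(8, 1)
)
x = paddle.randn([4, 8])

rng_state = paddle.get_rng_state()       # what save_state stores as "fw_rng_state"
out_first = net(x)                       # original forward consumes random numbers

paddle.set_rng_state(rng_state)          # rewind the RNG before recomputing
out_replay = net(x)                      # recomputed forward sees the same dropout mask

assert paddle.allclose(out_first, out_replay)
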
@@ -796,6 +898,8 @@ def forward_backward_pipeline(

         input_buffers = []
         output_buffers = []
+        if self.recompute_overlap:
+            state_buffers = []

         micro_dataset = self._wrap_data(data)

@@ -813,6 +917,8 @@
             input_tensor_dict, use_dict = tuple_to_dict_helper(input_tensor)

             self._record_stamp("F", step_id, '"B"', self._forward_color)
+            if self.recompute_overlap:
+                self.save_state(state_buffers)
             output_tensor, _, _ = self._forward_step(
                 input_tensor=input_tensor_dict if use_dict else input_tensor,
                 micro_dataset=micro_dataset,
@@ -856,6 +962,8 @@
             self._record_stamp(
                 "F", startup_steps + i, '"B"', self._forward_color
             )
+            if self.recompute_overlap:
+                self.save_state(state_buffers)
             output_tensor, _, _ = self._forward_step(
                 input_tensor=input_tensor_dict if use_dict else input_tensor,
                 micro_dataset=micro_dataset,
@@ -891,9 +999,16 @@
             )

             self._record_stamp("B", i, '"B"', self._backward_color)
-            input_tensor_grad = self._backward_step(
-                input_tensor, output_tensor, output_tensor_grad, step_id=i
-            )
+            if self.recompute_overlap:
+                state = state_buffers.pop(0)
+                output_tensor_recompute = self.load_state_and_forward(state, input_tensor)
+                input_tensor_grad = self._backward_step(
+                    input_tensor, output_tensor_recompute, output_tensor_grad, step_id=i
+                )
+            else:
+                input_tensor_grad = self._backward_step(
+                    input_tensor, output_tensor, output_tensor_grad, step_id=i
+                )
             self._record_stamp("B", i, '"E"', self._backward_color)

             if last_iter:
@@ -933,12 +1048,22 @@
             self._record_stamp(
                 "B", steady_steps + i, '"B"', self._backward_color
             )
-            input_tensor_grad = self._backward_step(
-                input_tensor,
-                output_tensor,
-                output_tensor_grad,
-                step_id=steady_steps + i,
-            )
+            if self.recompute_overlap:
+                state = state_buffers.pop(0)
+                output_tensor_recompute = self.load_state_and_forward(state, input_tensor)
+                input_tensor_grad = self._backward_step(
+                    input_tensor,
+                    output_tensor_recompute,
+                    output_tensor_grad,
+                    step_id=steady_steps + i,
+                )
+            else:
+                input_tensor_grad = self._backward_step(
+                    input_tensor,
+                    output_tensor,
+                    output_tensor_grad,
+                    step_id=steady_steps + i,
+                )
             self._record_stamp(
                 "B", steady_steps + i, '"E"', self._backward_color
             )
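
A tiny editorial sketch of why popping state_buffers front-first is enough: in the 1F1B schedule, backward micro-steps retire in the same order their forwards ran, even though forwards and backwards interleave during the steady phase. The step counts below are arbitrary; only the FIFO pairing is the point.

startup_steps, steady_steps = 2, 4
state_buffers = []

for i in range(startup_steps):                      # warmup: forwards only
    state_buffers.append(i)                         # stand-in for self.save_state(...)

for i in range(steady_steps):                       # steady phase: one forward, then one backward
    state_buffers.append(startup_steps + i)
    assert state_buffers.pop(0) == i                # backward i replays forward i's state

for i in range(startup_steps):                      # cooldown: drain the remaining backwards
    assert state_buffers.pop(0) == steady_steps + i
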
@@ -1254,11 +1379,21 @@ def _forward_step(
         schedule_chunk = None
         if overlap_schedule_mode:
             schedule_chunk = self._layers.get_schedule_chunk(chunk_id=chunk_id)
-            output_tensor = schedule_chunk.forward(input_tensor)
+            if self.recompute_overlap:
+                with paddle.no_grad():
+                    output_tensor = schedule_chunk.forward(input_tensor)
+            else:
+                output_tensor = schedule_chunk.forward(input_tensor)
         else:
-            output_tensor = self._layers.forward(
-                input_tensor, chunk_id=chunk_id
-            )
+            if self.recompute_overlap:
+                with paddle.no_grad():
+                    output_tensor = self._layers.forward(
+                        input_tensor, chunk_id=chunk_id
+                    )
+            else:
+                output_tensor = self._layers.forward(
+                    input_tensor, chunk_id=chunk_id
+                )

         self.callbacks.on_location(
             PipelineParallelMicroStepLocations.FORWARD_END,
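
An editorial sketch of the other half of the scheme: with recompute_overlap on, _forward_step runs the layer under paddle.no_grad(), so the pipeline forward keeps no autograd graph, and the graph consumed by _backward_step is the one rebuilt later by load_state_and_forward on detached inputs. Toy layer and shapes, illustrative only.

import paddle

layer = paddle.nn.Linear(16, 16)
x = paddle.randn([2, 16])

with paddle.no_grad():
    y = layer(x)                 # forward keeps no graph behind y
assert y.stop_gradient           # nothing here to backpropagate through

x_replay = x.detach()            # per-input analogue of detach_variable(...)
x_replay.stop_gradient = False
y_replay = layer(x_replay)       # the replayed forward records the graph
y_replay.sum().backward()        # backward consumes the replayed graph
assert x_replay.grad is not None
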
