 from enum import Enum
 from functools import partial
 from typing import Callable
+import numpy as np
+import random

 import paddle
 from paddle import framework
-
+from paddle.distributed.fleet.meta_parallel.parallel_layers.random import (
+    get_rng_state_tracker,
+)
 from ..meta_optimizers.dygraph_optimizer import HybridParallelOptimizer
 from ..utils import timer_helper as timer
 from ..utils.hybrid_parallel_util import (
 from ..utils.log_util import get_sync_logger, logger
 from .meta_parallel_base import MetaParallelBase
 from .parallel_layers.pp_layers import PipelineLayer
+from ..recompute.recompute import (
+    switch_rng_state_tracker,
+    detach_variable,
+)

 _use_four_directions = os.environ.get(
     'PADDLE_USE_FOUR_DIRECTIONS_P2P', paddle.base.core.is_compiled_with_xpu()
@@ -495,6 +503,16 @@ def __init__(self, layers, hcg, strategy):
         # only support user hooks during training
         self.user_hooks_enabled = True

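+        # Recompute with overlap: run each micro-batch's forward under
+        # paddle.no_grad(), snapshot its RNG/AMP state, and replay the
+        # forward right before its backward, so one micro-batch's recompute
+        # forward can overlap with another's backward.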
+        # Overlap the next layer's recompute backward with this layer's
+        # recompute forward.
+        self.recompute_overlap = True
+        # Hard-coded for now; could come from kwargs.pop('preserve_rng_state', True).
+        self.preserve_rng_state = True
+        # Hard-coded for now; could come from kwargs.pop('offload_indices', []).
+        self.offload_indices = []
+        # No-op hooks for saving/restoring user-defined RNG state.
+        self.custom_get_state_func = lambda x=None: None
+        self.custom_set_state_func = lambda x=None: None
+
     def register_hook(
         self, location: PipelineParallelMicroStepLocations, hook: Callable
     ):
@@ -749,6 +767,90 @@ def _flush_records(self):
         ) as f:
             f.writelines(record + '\n' for record in self._records)
         self._records = []
+
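+    # Snapshot everything needed to replay this micro-batch's forward later:
+    # the Paddle/numpy/random RNG states (plus any user-defined state) and
+    # the tracer's AMP configuration at the time of the original forward.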
+    def save_state(self, state_buffers):
+        state = {}
+        if self.preserve_rng_state:
+            state["fw_rng_state"] = paddle.get_rng_state()
+            state["fwd_rng_state_tracker"] = (
+                get_rng_state_tracker().get_states_tracker()
+            )
+            state["fwd_numpy_state"] = np.random.get_state()
+            state["fwd_random_state"] = random.getstate()
+            state["fwd_custom_state"] = self.custom_get_state_func()
+        state["custom_get_state_func"] = self.custom_get_state_func
+        state["custom_set_state_func"] = self.custom_set_state_func
+        # load_state_and_forward reads these back from the state dict.
+        state["preserve_rng_state"] = self.preserve_rng_state
+        state["offload_indices"] = self.offload_indices
+        tracer = framework._dygraph_tracer()
+        state["is_fw_autocast"] = (
+            False if tracer._amp_level == framework.core.AmpLevel.O0 else True
+        )
+        if tracer._amp_level == framework.core.AmpLevel.O2:
+            state["amp_level"] = 'O2'
+        elif tracer._amp_level in (
+            framework.core.AmpLevel.O1,
+            framework.core.AmpLevel.O0,
+        ):
+            state["amp_level"] = 'O1'
+        else:
+            raise ValueError(f"unsupported amp level: {tracer._amp_level}")
+
+        if tracer._amp_dtype == 'float16':
+            state["amp_dtype"] = 'float16'
+        elif tracer._amp_dtype in ('bfloat16', 'float32'):
+            state["amp_dtype"] = 'bfloat16'
+        else:
+            raise ValueError(f"unsupported amp dtype: {tracer._amp_dtype}")
+        state["amp_white_list"], state["amp_black_list"] = (
+            tracer._get_amp_op_list()
+        )
+        state_buffers.append(state)
+    def load_state_and_forward(self, state, input_tensor):
+        inputs = list(input_tensor)
+        # state["tensor_indices"] and self.container are expected to be
+        # populated elsewhere, when the inputs are stashed/offloaded (not
+        # shown in this diff).
+        tensor_indices = state["tensor_indices"]
+        tensors = self.container
+        for i, idx in enumerate(tensor_indices):
+            # Move offloaded tensors back to the current device.
+            inputs[idx] = (
+                tensors[i].to(
+                    paddle.base.framework._current_expected_place()
+                )
+                if i in state["offload_indices"]
+                else tensors[i]
+            )
+            if i in state["offload_indices"]:
+                inputs[idx].stop_gradient = tensors[i].stop_gradient
+        tracer = framework._dygraph_tracer()
+        # The replayed forward must record gradients even when invoked from
+        # a no-grad context.
+        tracer._has_grad = True
+
+        if state["preserve_rng_state"]:
+            with (
+                switch_rng_state_tracker(
+                    state["fw_rng_state"],
+                    state["fwd_rng_state_tracker"],
+                    state["fwd_numpy_state"],
+                    state["fwd_random_state"],
+                    state["fwd_custom_state"],
+                    state["custom_get_state_func"],
+                    state["custom_set_state_func"],
+                ),
+                paddle.amp.auto_cast(
+                    enable=state["is_fw_autocast"],
+                    custom_white_list=state["amp_white_list"],
+                    custom_black_list=state["amp_black_list"],
+                    level=state["amp_level"],
+                    dtype=state["amp_dtype"],
+                ),
+            ):
+                detached_inputs = detach_variable(tuple(inputs))
+                outputs = self._layers.forward(*detached_inputs)
+        else:
+            with paddle.amp.auto_cast(
+                enable=state["is_fw_autocast"],
+                custom_white_list=state["amp_white_list"],
+                custom_black_list=state["amp_black_list"],
+                level=state["amp_level"],
+                dtype=state["amp_dtype"],
+            ):
+                detached_inputs = detach_variable(tuple(inputs))
+                outputs = self._layers.forward(*detached_inputs)
+        return outputs
+

     def forward_backward_pipeline(
         self,
@@ -796,6 +898,8 @@ def forward_backward_pipeline(

         input_buffers = []
         output_buffers = []
+        if self.recompute_overlap:
+            state_buffers = []
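+            # One saved state per in-flight micro-batch, consumed FIFO by
+            # the backward steps.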

         micro_dataset = self._wrap_data(data)

@@ -813,6 +917,8 @@ def forward_backward_pipeline(
             input_tensor_dict, use_dict = tuple_to_dict_helper(input_tensor)

             self._record_stamp("F", step_id, '"B"', self._forward_color)
+            if self.recompute_overlap:
+                self.save_state(state_buffers)
             output_tensor, _, _ = self._forward_step(
                 input_tensor=input_tensor_dict if use_dict else input_tensor,
                 micro_dataset=micro_dataset,
@@ -856,6 +962,8 @@ def forward_backward_pipeline(
             self._record_stamp(
                 "F", startup_steps + i, '"B"', self._forward_color
             )
+            if self.recompute_overlap:
+                self.save_state(state_buffers)
             output_tensor, _, _ = self._forward_step(
                 input_tensor=input_tensor_dict if use_dict else input_tensor,
                 micro_dataset=micro_dataset,
@@ -891,9 +999,16 @@ def forward_backward_pipeline(
             )

             self._record_stamp("B", i, '"B"', self._backward_color)
-            input_tensor_grad = self._backward_step(
-                input_tensor, output_tensor, output_tensor_grad, step_id=i
-            )
+            if self.recompute_overlap:
+                # Rebuild this micro-batch's graph, then run backward on the
+                # recomputed outputs.
+                state = state_buffers.pop(0)
+                output_tensor_recompute = self.load_state_and_forward(
+                    state, input_tensor
+                )
+                input_tensor_grad = self._backward_step(
+                    input_tensor,
+                    output_tensor_recompute,
+                    output_tensor_grad,
+                    step_id=i,
+                )
+            else:
+                input_tensor_grad = self._backward_step(
+                    input_tensor, output_tensor, output_tensor_grad, step_id=i
+                )
             self._record_stamp("B", i, '"E"', self._backward_color)

             if last_iter:
@@ -933,12 +1048,22 @@ def forward_backward_pipeline(
             self._record_stamp(
                 "B", steady_steps + i, '"B"', self._backward_color
             )
-            input_tensor_grad = self._backward_step(
-                input_tensor,
-                output_tensor,
-                output_tensor_grad,
-                step_id=steady_steps + i,
-            )
+            if self.recompute_overlap:
+                state = state_buffers.pop(0)
+                output_tensor_recompute = self.load_state_and_forward(
+                    state, input_tensor
+                )
+                input_tensor_grad = self._backward_step(
+                    input_tensor,
+                    output_tensor_recompute,
+                    output_tensor_grad,
+                    step_id=steady_steps + i,
+                )
+            else:
+                input_tensor_grad = self._backward_step(
+                    input_tensor,
+                    output_tensor,
+                    output_tensor_grad,
+                    step_id=steady_steps + i,
+                )
             self._record_stamp(
                 "B", steady_steps + i, '"E"', self._backward_color
             )
@@ -1254,11 +1379,21 @@ def _forward_step(
         schedule_chunk = None
         if overlap_schedule_mode:
            schedule_chunk = self._layers.get_schedule_chunk(chunk_id=chunk_id)
+            if self.recompute_overlap:
+                # Skip building the autograd graph here; the forward is
+                # replayed by load_state_and_forward right before backward.
+                with paddle.no_grad():
+                    output_tensor = schedule_chunk.forward(input_tensor)
+            else:
+                output_tensor = schedule_chunk.forward(input_tensor)
         else:
-            output_tensor = self._layers.forward(
-                input_tensor, chunk_id=chunk_id
-            )
+            if self.recompute_overlap:
+                with paddle.no_grad():
+                    output_tensor = self._layers.forward(
+                        input_tensor, chunk_id=chunk_id
+                    )
+            else:
+                output_tensor = self._layers.forward(
+                    input_tensor, chunk_id=chunk_id
+                )

         self.callbacks.on_location(
             PipelineParallelMicroStepLocations.FORWARD_END,