File tree Expand file tree Collapse file tree 1 file changed +18
-5
lines changed
Expand file tree Collapse file tree 1 file changed +18
-5
lines changed Original file line number Diff line number Diff line change 1+ import contextlib
12import copy
23import os
34import random
67import numpy as np
78import torch
89
9- from .utils import deprecate
10+ from .utils import deprecate , is_transformers_available
11+
12+
13+ if is_transformers_available ():
14+ import transformers
1015
1116
1217def enable_full_determinism (seed : int ):
@@ -197,11 +202,19 @@ def step(self, parameters: Iterable[torch.nn.Parameter]):
197202 self .cur_decay_value = decay
198203 one_minus_decay = 1 - decay
199204
205+ context_manager = contextlib .nullcontext
206+ if is_transformers_available () and transformers .deepspeed .is_deepspeed_zero3_enabled ():
207+ import deepspeed
208+
200209 for s_param , param in zip (self .shadow_params , parameters ):
201- if param .requires_grad :
202- s_param .sub_ (one_minus_decay * (s_param - param ))
203- else :
204- s_param .copy_ (param )
210+ if is_transformers_available () and transformers .deepspeed .is_deepspeed_zero3_enabled ():
211+ context_manager = deepspeed .zero .GatheredParameters (param , modifier_rank = None )
212+
213+ with context_manager ():
214+ if param .requires_grad :
215+ s_param .sub_ (one_minus_decay * (s_param - param ))
216+ else :
217+ s_param .copy_ (param )
205218
206219 def copy_to (self , parameters : Iterable [torch .nn .Parameter ]) -> None :
207220 """
You can’t perform that action at this time.
0 commit comments