import transformers
from datasets import load_from_disk
from torch.utils.data import DataLoader
-from transformers import (set_seed, HfArgumentParser, TrainingArguments,
-                          DataCollatorForLanguageModeling, AlbertTokenizerFast, AlbertConfig, AlbertForPreTraining)
+from transformers import (
+    set_seed,
+    HfArgumentParser,
+    TrainingArguments,
+    DataCollatorForLanguageModeling,
+    AlbertTokenizerFast,
+    AlbertConfig,
+    AlbertForPreTraining,
+)
from transformers.optimization import get_linear_schedule_with_warmup
from transformers.trainer_utils import is_main_process
from transformers.trainer import Trainer


logger = logging.getLogger(__name__)
-LRSchedulerBase = getattr(torch.optim.lr_scheduler, '_LRScheduler', None)
+LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)


def setup_logging(training_args):
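A side note on the `LRSchedulerBase` line above: the `getattr(..., None)` form tolerates PyTorch builds where the private `_LRScheduler` name is missing. A minimal version-tolerant sketch of the same idea (preferring the public `LRScheduler` name from newer PyTorch releases is my assumption, not part of this commit):

```python
import torch.optim.lr_scheduler as lr_scheduler

# Prefer the public name where it exists (newer PyTorch), fall back to the private one.
LRSchedulerBase = getattr(lr_scheduler, "LRScheduler", None) or getattr(lr_scheduler, "_LRScheduler", None)
```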
@@ -50,13 +57,13 @@ def get_model(training_args, config, tokenizer):
    # Find latest checkpoint in output_dir
    output_dir = Path(training_args.output_dir)
    logger.info(f'Checkpoint dir {output_dir}, contents {list(output_dir.glob("checkpoint*"))}')
-    latest_checkpoint_dir = max(output_dir.glob('checkpoint*'), default=None, key=os.path.getctime)
+    latest_checkpoint_dir = max(output_dir.glob("checkpoint*"), default=None, key=os.path.getctime)

    if latest_checkpoint_dir is not None:
-        logger.info(f'Loading model from {latest_checkpoint_dir}')
+        logger.info(f"Loading model from {latest_checkpoint_dir}")
        model = AlbertForPreTraining.from_pretrained(latest_checkpoint_dir)
    else:
-        logger.info(f'Training from scratch')
+        logger.info(f"Training from scratch")
        model = AlbertForPreTraining(config)
    model.resize_token_embeddings(len(tokenizer))

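The checkpoint lookup in `get_model` picks the `checkpoint*` directory with the newest ctime. As an illustration only, a hypothetical helper that instead sorts by the step number that `transformers.Trainer` embeds in checkpoint directory names (`checkpoint-<step>`), which survives copies and restores that reset ctime:

```python
from pathlib import Path


def latest_checkpoint_by_step(output_dir: Path):
    # Trainer names checkpoints "checkpoint-<global_step>"; pick the highest step.
    checkpoints = [p for p in output_dir.glob("checkpoint-*") if p.name.rsplit("-", 1)[-1].isdigit()]
    return max(checkpoints, default=None, key=lambda p: int(p.name.rsplit("-", 1)[-1]))
```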
@@ -87,17 +94,21 @@ def get_optimizer_and_scheduler(training_args, model):
    )

    scheduler = get_linear_schedule_with_warmup(
-        opt,
-        num_warmup_steps=training_args.warmup_steps,
-        num_training_steps=training_args.max_steps
+        opt, num_warmup_steps=training_args.warmup_steps, num_training_steps=training_args.max_steps
    )

    return opt, scheduler


class CollaborativeCallback(transformers.TrainerCallback):
-    def __init__(self, dht: hivemind.DHT, optimizer: hivemind.CollaborativeOptimizer,
-                 model: torch.nn.Module, local_public_key: bytes, statistics_expiration: float):
+    def __init__(
+        self,
+        dht: hivemind.DHT,
+        optimizer: hivemind.CollaborativeOptimizer,
+        model: torch.nn.Module,
+        local_public_key: bytes,
+        statistics_expiration: float,
+    ):
        super().__init__()
        self.model = model
        self.dht, self.collaborative_optimizer = dht, optimizer
@@ -110,21 +121,23 @@ def __init__(self, dht: hivemind.DHT, optimizer: hivemind.CollaborativeOptimizer
        self.loss = 0
        self.total_samples_processed = 0

-    def on_train_begin(self, args: TrainingArguments, state: transformers.TrainerState,
-                       control: transformers.TrainerControl, **kwargs):
-        logger.info('Loading state from peers')
+    def on_train_begin(
+        self, args: TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs
+    ):
+        logger.info("Loading state from peers")
        self.collaborative_optimizer.load_state_from_peers()

-    def on_step_end(self, args: TrainingArguments, state: transformers.TrainerState,
-                    control: transformers.TrainerControl, **kwargs):
+    def on_step_end(
+        self, args: TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs
+    ):
        control.should_log = True
        if not self.params_are_finite():
            self.load_from_state(self.previous_state)
            return control
        self.previous_state = self.get_current_state()

        if state.log_history:
-            self.loss += state.log_history[-1]['loss']
+            self.loss += state.log_history[-1]["loss"]
            self.steps += 1
            if self.collaborative_optimizer.local_step != self.last_reported_collaboration_step:
                self.last_reported_collaboration_step = self.collaborative_optimizer.local_step
@@ -135,7 +148,8 @@ def on_step_end(self, args: TrainingArguments, state: transformers.TrainerState,
                    samples_per_second=samples_per_second,
                    samples_accumulated=self.samples,
                    loss=self.loss,
-                    mini_steps=self.steps)
+                    mini_steps=self.steps,
+                )
                logger.info(f"Step {self.collaborative_optimizer.local_step}")
                logger.info(f"Your current contribution: {self.total_samples_processed} samples")
                if self.steps:
@@ -144,26 +158,26 @@ def on_step_end(self, args: TrainingArguments, state: transformers.TrainerState,
                self.loss = 0
                self.steps = 0
                if self.collaborative_optimizer.is_synchronized:
-                    self.dht.store(key=self.collaborative_optimizer.prefix + "_metrics",
-                                   subkey=self.local_public_key, value=statistics.dict(),
-                                   expiration_time=hivemind.get_dht_time() + self.statistics_expiration,
-                                   return_future=True)
+                    self.dht.store(
+                        key=self.collaborative_optimizer.prefix + "_metrics",
+                        subkey=self.local_public_key,
+                        value=statistics.dict(),
+                        expiration_time=hivemind.get_dht_time() + self.statistics_expiration,
+                        return_future=True,
+                    )

        self.samples = self.collaborative_optimizer.local_samples_accumulated

        return control

    @torch.no_grad()
    def get_current_state(self) -> Dict[str, Any]:
-        return {
-            'model': self.model.state_dict(),
-            'opt': self.collaborative_optimizer.opt.state_dict()
-        }
+        return {"model": self.model.state_dict(), "opt": self.collaborative_optimizer.opt.state_dict()}

    @torch.no_grad()
    def load_from_state(self, state):
-        self.model.load_state_dict(state['model'])
-        self.collaborative_optimizer.opt.load_state_dict(state['opt'])
+        self.model.load_state_dict(state["model"])
+        self.collaborative_optimizer.opt.load_state_dict(state["opt"])

    @torch.no_grad()
    def params_are_finite(self):
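The three helpers above form a snapshot-and-rollback guard: `on_step_end` snapshots the model and optimizer state after every healthy step and restores it whenever a parameter turns non-finite. A self-contained sketch of the finiteness check, written as a free function purely for illustration:

```python
import torch


@torch.no_grad()
def params_are_finite(model: torch.nn.Module) -> bool:
    # True only if no parameter contains inf or NaN
    return all(torch.isfinite(param).all() for param in model.parameters())
```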
@@ -174,10 +188,10 @@ def params_are_finite(self):


class NoOpScheduler(LRSchedulerBase):
-    """ Dummy scheduler for transformers.Trainer. The real scheduler is defined in CollaborativeOptimizer.scheduler """
+    """Dummy scheduler for transformers.Trainer. The real scheduler is defined in CollaborativeOptimizer.scheduler"""

    def get_lr(self):
-        return [group['lr'] for group in self.optimizer.param_groups]
+        return [group["lr"] for group in self.optimizer.param_groups]

    def print_lr(self, *args, **kwargs):
        if self.optimizer.scheduler:
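`NoOpScheduler` exists only because `transformers.Trainer` expects a scheduler object to drive; the real learning-rate schedule lives inside `CollaborativeOptimizer.scheduler`. A minimal sketch of the pattern (the method set beyond `get_lr`/`print_lr` is assumed, not taken from this diff):

```python
class NoOpSchedulerSketch(LRSchedulerBase):
    """Satisfies Trainer's scheduler interface without ever changing the learning rate."""

    def get_lr(self):
        return [group["lr"] for group in self.optimizer.param_groups]

    def step(self, *args, **kwargs):
        self._last_lr = self.get_lr()  # report current LRs, advance nothing

    def state_dict(self):
        return {}

    def load_state_dict(self, state_dict):
        pass
```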
@@ -219,53 +233,65 @@ def main():

    opt, scheduler = get_optimizer_and_scheduler(training_args, model)

-    validators, local_public_key = metrics_utils.make_validators(
-        collaboration_args_dict['experiment_prefix'])
+    validators, local_public_key = metrics_utils.make_validators(collaboration_args_dict["experiment_prefix"])
    dht = hivemind.DHT(
-        start=True, initial_peers=collaboration_args_dict.pop('initial_peers'),
-        listen=not collaboration_args_dict['client_mode'],
-        listen_on=collaboration_args_dict.pop('dht_listen_on'),
-        endpoint=collaboration_args_dict.pop('endpoint'), record_validators=validators)
+        start=True,
+        initial_peers=collaboration_args_dict.pop("initial_peers"),
+        listen=not collaboration_args_dict["client_mode"],
+        listen_on=collaboration_args_dict.pop("dht_listen_on"),
+        endpoint=collaboration_args_dict.pop("endpoint"),
+        record_validators=validators,
+    )

    total_batch_size_per_step = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
    if torch.cuda.device_count() != 0:
        total_batch_size_per_step *= torch.cuda.device_count()

-    statistics_expiration = collaboration_args_dict.pop('statistics_expiration')
-    adjusted_target_batch_size = collaboration_args_dict.pop('target_batch_size') \
-        - collaboration_args_dict.pop('batch_size_lead')
+    statistics_expiration = collaboration_args_dict.pop("statistics_expiration")
+    adjusted_target_batch_size = collaboration_args_dict.pop("target_batch_size") - collaboration_args_dict.pop(
+        "batch_size_lead"
+    )

    collaborative_optimizer = hivemind.CollaborativeOptimizer(
-        opt=opt, dht=dht, scheduler=scheduler, prefix=collaboration_args_dict.pop('experiment_prefix'),
-        compression_type=hivemind.utils.CompressionType.Value(collaboration_args_dict.pop('compression')),
-        batch_size_per_step=total_batch_size_per_step, throughput=collaboration_args_dict.pop('bandwidth'),
-        target_batch_size=adjusted_target_batch_size, client_mode=collaboration_args_dict.pop('client_mode'),
-        verbose=True, start=True, **collaboration_args_dict
+        opt=opt,
+        dht=dht,
+        scheduler=scheduler,
+        prefix=collaboration_args_dict.pop("experiment_prefix"),
+        compression_type=hivemind.utils.CompressionType.Value(collaboration_args_dict.pop("compression")),
+        batch_size_per_step=total_batch_size_per_step,
+        throughput=collaboration_args_dict.pop("bandwidth"),
+        target_batch_size=adjusted_target_batch_size,
+        client_mode=collaboration_args_dict.pop("client_mode"),
+        verbose=True,
+        start=True,
+        **collaboration_args_dict,
    )

    class TrainerWithIndependentShuffling(Trainer):
        def get_train_dataloader(self) -> DataLoader:
-            """ Shuffle data independently for each peer to avoid duplicating batches [important for quality] """
+            """Shuffle data independently for each peer to avoid duplicating batches [important for quality]"""
            torch.manual_seed(hash(local_public_key))
            return super().get_train_dataloader()

    trainer = TrainerWithIndependentShuffling(
-        model=model, args=training_args, tokenizer=tokenizer, data_collator=data_collator,
+        model=model,
+        args=training_args,
+        tokenizer=tokenizer,
+        data_collator=data_collator,
        train_dataset=tokenized_datasets["train"] if training_args.do_train else None,
        eval_dataset=tokenized_datasets["validation"] if training_args.do_eval else None,
        optimizers=(collaborative_optimizer, NoOpScheduler(collaborative_optimizer)),
-        callbacks=[CollaborativeCallback(
-            dht, collaborative_optimizer, model, local_public_key, statistics_expiration)]
+        callbacks=[
+            CollaborativeCallback(dht, collaborative_optimizer, model, local_public_key, statistics_expiration)
+        ],
    )
    trainer.remove_callback(transformers.trainer_callback.PrinterCallback)
    trainer.remove_callback(transformers.trainer_callback.ProgressCallback)

    # Training
    if training_args.do_train:
        latest_checkpoint_dir = max(
-            Path(training_args.output_dir).glob('checkpoint*'),
-            default=None,
-            key=os.path.getctime
+            Path(training_args.output_dir).glob("checkpoint*"), default=None, key=os.path.getctime
        )

        trainer.train(model_path=latest_checkpoint_dir)
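A detail worth noting in `TrainerWithIndependentShuffling`: seeding with `hash(local_public_key)` gives each peer its own shuffling order, so collaborating peers do not push near-duplicate batches into the shared optimizer. Python's built-in `hash()` of bytes is salted per process (`PYTHONHASHSEED`), so that seed is not reproducible across runs; a sketch of a stable alternative (hypothetical helper, not part of this commit):

```python
import hashlib

import torch


def stable_peer_seed(local_public_key: bytes) -> int:
    # Deterministic across processes, unlike the salted built-in hash() of bytes
    return int.from_bytes(hashlib.sha256(local_public_key).digest()[:8], "little")


torch.manual_seed(stable_peer_seed(b"example-public-key"))
```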