Commit de66da1

Return tqdm to benchmark_dht
1 parent 455b035 commit de66da1

6 files changed: +146 -100 lines changed


benchmarks/benchmark_dht.py

Lines changed: 5 additions & 3 deletions

@@ -2,6 +2,8 @@
 import random
 import time

+from tqdm import trange
+
 import hivemind
 from hivemind.moe.server import declare_experts, get_experts
 from hivemind.utils.limits import increase_file_limit
@@ -31,7 +33,7 @@ def benchmark_dht(

     logger.info("Creating peers...")
     peers = []
-    for _ in range(num_peers):
+    for _ in trange(num_peers):
         neighbors = [f"0.0.0.0:{node.port}" for node in random.sample(peers, min(initial_peers, len(peers)))]
         peer = hivemind.DHT(initial_peers=neighbors, start=True, wait_timeout=wait_timeout, listen_on=f"0.0.0.0:*")
         peers.append(peer)
@@ -52,7 +54,7 @@ def benchmark_dht(
     benchmark_started = time.perf_counter()
     endpoints = []

-    for start in range(0, num_experts, expert_batch_size):
+    for start in trange(0, num_experts, expert_batch_size):
         store_start = time.perf_counter()
         endpoints.append(random_endpoint())
         store_ok = declare_experts(
@@ -76,7 +78,7 @@ def benchmark_dht(

     successful_gets = total_get_time = 0

-    for start in range(0, len(expert_uids), expert_batch_size):
+    for start in trange(0, len(expert_uids), expert_batch_size):
         get_start = time.perf_counter()
         get_result = get_experts(get_peer, expert_uids[start : start + expert_batch_size])
         total_get_time += time.perf_counter() - get_start
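
For context, tqdm's trange is a drop-in replacement for Python's range that draws a terminal progress bar while iterating, which is the behaviour this commit restores to the three benchmark loops above. A minimal sketch of the idiom, not taken from the repository (the loop body and the "demo" label are placeholders):

    import time

    from tqdm import trange

    # trange(n) behaves like tqdm(range(n)); trange(start, stop, step) also works,
    # matching the batched store/get loops in benchmark_dht.py.
    for step in trange(5, desc="demo"):
        time.sleep(0.1)  # stand-in for the timed DHT work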

examples/albert/arguments.py

Lines changed: 1 addition & 1 deletion

@@ -109,7 +109,7 @@ class AlbertTrainingArguments(TrainingArguments):
     gradient_accumulation_steps: int = 2
     seq_length: int = 512

-    max_steps: int = 125_000 # please note: this affects both number of steps and learning rate schedule
+    max_steps: int = 125_000  # please note: this affects both number of steps and learning rate schedule
     learning_rate: float = 0.00176
     warmup_steps: int = 5000
     adam_epsilon: float = 1e-6

examples/albert/run_trainer.py

Lines changed: 77 additions & 51 deletions

@@ -10,8 +10,15 @@
 import transformers
 from datasets import load_from_disk
 from torch.utils.data import DataLoader
-from transformers import (set_seed, HfArgumentParser, TrainingArguments,
-                          DataCollatorForLanguageModeling, AlbertTokenizerFast, AlbertConfig, AlbertForPreTraining)
+from transformers import (
+    set_seed,
+    HfArgumentParser,
+    TrainingArguments,
+    DataCollatorForLanguageModeling,
+    AlbertTokenizerFast,
+    AlbertConfig,
+    AlbertForPreTraining,
+)
 from transformers.optimization import get_linear_schedule_with_warmup
 from transformers.trainer_utils import is_main_process
 from transformers.trainer import Trainer
@@ -23,7 +30,7 @@


 logger = logging.getLogger(__name__)
-LRSchedulerBase = getattr(torch.optim.lr_scheduler, '_LRScheduler', None)
+LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)


 def setup_logging(training_args):
@@ -50,13 +57,13 @@ def get_model(training_args, config, tokenizer):
     # Find latest checkpoint in output_dir
     output_dir = Path(training_args.output_dir)
     logger.info(f'Checkpoint dir {output_dir}, contents {list(output_dir.glob("checkpoint*"))}')
-    latest_checkpoint_dir = max(output_dir.glob('checkpoint*'), default=None, key=os.path.getctime)
+    latest_checkpoint_dir = max(output_dir.glob("checkpoint*"), default=None, key=os.path.getctime)

     if latest_checkpoint_dir is not None:
-        logger.info(f'Loading model from {latest_checkpoint_dir}')
+        logger.info(f"Loading model from {latest_checkpoint_dir}")
         model = AlbertForPreTraining.from_pretrained(latest_checkpoint_dir)
     else:
-        logger.info(f'Training from scratch')
+        logger.info(f"Training from scratch")
         model = AlbertForPreTraining(config)
         model.resize_token_embeddings(len(tokenizer))

@@ -87,17 +94,21 @@ def get_optimizer_and_scheduler(training_args, model):
     )

     scheduler = get_linear_schedule_with_warmup(
-        opt,
-        num_warmup_steps=training_args.warmup_steps,
-        num_training_steps=training_args.max_steps
+        opt, num_warmup_steps=training_args.warmup_steps, num_training_steps=training_args.max_steps
     )

     return opt, scheduler


 class CollaborativeCallback(transformers.TrainerCallback):
-    def __init__(self, dht: hivemind.DHT, optimizer: hivemind.CollaborativeOptimizer,
-                 model: torch.nn.Module, local_public_key: bytes, statistics_expiration: float):
+    def __init__(
+        self,
+        dht: hivemind.DHT,
+        optimizer: hivemind.CollaborativeOptimizer,
+        model: torch.nn.Module,
+        local_public_key: bytes,
+        statistics_expiration: float,
+    ):
         super().__init__()
         self.model = model
         self.dht, self.collaborative_optimizer = dht, optimizer
@@ -110,21 +121,23 @@ def __init__(self, dht: hivemind.DHT, optimizer: hivemind.CollaborativeOptimizer
         self.loss = 0
         self.total_samples_processed = 0

-    def on_train_begin(self, args: TrainingArguments, state: transformers.TrainerState,
-                       control: transformers.TrainerControl, **kwargs):
-        logger.info('Loading state from peers')
+    def on_train_begin(
+        self, args: TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs
+    ):
+        logger.info("Loading state from peers")
         self.collaborative_optimizer.load_state_from_peers()

-    def on_step_end(self, args: TrainingArguments, state: transformers.TrainerState,
-                    control: transformers.TrainerControl, **kwargs):
+    def on_step_end(
+        self, args: TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs
+    ):
         control.should_log = True
         if not self.params_are_finite():
             self.load_from_state(self.previous_state)
             return control
         self.previous_state = self.get_current_state()

         if state.log_history:
-            self.loss += state.log_history[-1]['loss']
+            self.loss += state.log_history[-1]["loss"]
             self.steps += 1
             if self.collaborative_optimizer.local_step != self.last_reported_collaboration_step:
                 self.last_reported_collaboration_step = self.collaborative_optimizer.local_step
@@ -135,7 +148,8 @@ def on_step_end(self, args: TrainingArguments, state: transformers.TrainerState,
                     samples_per_second=samples_per_second,
                     samples_accumulated=self.samples,
                     loss=self.loss,
-                    mini_steps=self.steps)
+                    mini_steps=self.steps,
+                )
                 logger.info(f"Step {self.collaborative_optimizer.local_step}")
                 logger.info(f"Your current contribution: {self.total_samples_processed} samples")
                 if self.steps:
@@ -144,26 +158,26 @@ def on_step_end(self, args: TrainingArguments, state: transformers.TrainerState,
                 self.loss = 0
                 self.steps = 0
                 if self.collaborative_optimizer.is_synchronized:
-                    self.dht.store(key=self.collaborative_optimizer.prefix + "_metrics",
-                                   subkey=self.local_public_key, value=statistics.dict(),
-                                   expiration_time=hivemind.get_dht_time() + self.statistics_expiration,
-                                   return_future=True)
+                    self.dht.store(
+                        key=self.collaborative_optimizer.prefix + "_metrics",
+                        subkey=self.local_public_key,
+                        value=statistics.dict(),
+                        expiration_time=hivemind.get_dht_time() + self.statistics_expiration,
+                        return_future=True,
+                    )

             self.samples = self.collaborative_optimizer.local_samples_accumulated

         return control

     @torch.no_grad()
     def get_current_state(self) -> Dict[str, Any]:
-        return {
-            'model': self.model.state_dict(),
-            'opt': self.collaborative_optimizer.opt.state_dict()
-        }
+        return {"model": self.model.state_dict(), "opt": self.collaborative_optimizer.opt.state_dict()}

     @torch.no_grad()
     def load_from_state(self, state):
-        self.model.load_state_dict(state['model'])
-        self.collaborative_optimizer.opt.load_state_dict(state['opt'])
+        self.model.load_state_dict(state["model"])
+        self.collaborative_optimizer.opt.load_state_dict(state["opt"])

     @torch.no_grad()
     def params_are_finite(self):
@@ -174,10 +188,10 @@ def params_are_finite(self):


 class NoOpScheduler(LRSchedulerBase):
-    """ Dummy scheduler for transformers.Trainer. The real scheduler is defined in CollaborativeOptimizer.scheduler """
+    """Dummy scheduler for transformers.Trainer. The real scheduler is defined in CollaborativeOptimizer.scheduler"""

     def get_lr(self):
-        return [group['lr'] for group in self.optimizer.param_groups]
+        return [group["lr"] for group in self.optimizer.param_groups]

     def print_lr(self, *args, **kwargs):
         if self.optimizer.scheduler:
@@ -219,53 +233,65 @@ def main():

     opt, scheduler = get_optimizer_and_scheduler(training_args, model)

-    validators, local_public_key = metrics_utils.make_validators(
-        collaboration_args_dict['experiment_prefix'])
+    validators, local_public_key = metrics_utils.make_validators(collaboration_args_dict["experiment_prefix"])
     dht = hivemind.DHT(
-        start=True, initial_peers=collaboration_args_dict.pop('initial_peers'),
-        listen=not collaboration_args_dict['client_mode'],
-        listen_on=collaboration_args_dict.pop('dht_listen_on'),
-        endpoint=collaboration_args_dict.pop('endpoint'), record_validators=validators)
+        start=True,
+        initial_peers=collaboration_args_dict.pop("initial_peers"),
+        listen=not collaboration_args_dict["client_mode"],
+        listen_on=collaboration_args_dict.pop("dht_listen_on"),
+        endpoint=collaboration_args_dict.pop("endpoint"),
+        record_validators=validators,
+    )

     total_batch_size_per_step = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
     if torch.cuda.device_count() != 0:
         total_batch_size_per_step *= torch.cuda.device_count()

-    statistics_expiration = collaboration_args_dict.pop('statistics_expiration')
-    adjusted_target_batch_size = collaboration_args_dict.pop('target_batch_size') \
-                                 - collaboration_args_dict.pop('batch_size_lead')
+    statistics_expiration = collaboration_args_dict.pop("statistics_expiration")
+    adjusted_target_batch_size = collaboration_args_dict.pop("target_batch_size") - collaboration_args_dict.pop(
+        "batch_size_lead"
+    )

     collaborative_optimizer = hivemind.CollaborativeOptimizer(
-        opt=opt, dht=dht, scheduler=scheduler, prefix=collaboration_args_dict.pop('experiment_prefix'),
-        compression_type=hivemind.utils.CompressionType.Value(collaboration_args_dict.pop('compression')),
-        batch_size_per_step=total_batch_size_per_step, throughput=collaboration_args_dict.pop('bandwidth'),
-        target_batch_size=adjusted_target_batch_size, client_mode=collaboration_args_dict.pop('client_mode'),
-        verbose=True, start=True, **collaboration_args_dict
+        opt=opt,
+        dht=dht,
+        scheduler=scheduler,
+        prefix=collaboration_args_dict.pop("experiment_prefix"),
+        compression_type=hivemind.utils.CompressionType.Value(collaboration_args_dict.pop("compression")),
+        batch_size_per_step=total_batch_size_per_step,
+        throughput=collaboration_args_dict.pop("bandwidth"),
+        target_batch_size=adjusted_target_batch_size,
+        client_mode=collaboration_args_dict.pop("client_mode"),
+        verbose=True,
+        start=True,
+        **collaboration_args_dict,
     )

     class TrainerWithIndependentShuffling(Trainer):
         def get_train_dataloader(self) -> DataLoader:
-            """ Shuffle data independently for each peer to avoid duplicating batches [important for quality] """
+            """Shuffle data independently for each peer to avoid duplicating batches [important for quality]"""
             torch.manual_seed(hash(local_public_key))
             return super().get_train_dataloader()

     trainer = TrainerWithIndependentShuffling(
-        model=model, args=training_args, tokenizer=tokenizer, data_collator=data_collator,
+        model=model,
+        args=training_args,
+        tokenizer=tokenizer,
+        data_collator=data_collator,
         train_dataset=tokenized_datasets["train"] if training_args.do_train else None,
         eval_dataset=tokenized_datasets["validation"] if training_args.do_eval else None,
         optimizers=(collaborative_optimizer, NoOpScheduler(collaborative_optimizer)),
-        callbacks=[CollaborativeCallback(
-            dht, collaborative_optimizer, model, local_public_key, statistics_expiration)]
+        callbacks=[
+            CollaborativeCallback(dht, collaborative_optimizer, model, local_public_key, statistics_expiration)
+        ],
     )
     trainer.remove_callback(transformers.trainer_callback.PrinterCallback)
     trainer.remove_callback(transformers.trainer_callback.ProgressCallback)

     # Training
     if training_args.do_train:
         latest_checkpoint_dir = max(
-            Path(training_args.output_dir).glob('checkpoint*'),
-            default=None,
-            key=os.path.getctime
+            Path(training_args.output_dir).glob("checkpoint*"), default=None, key=os.path.getctime
         )

         trainer.train(model_path=latest_checkpoint_dir)
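
The edits in this file are formatting-only: single quotes become double quotes, long signatures and calls are exploded one argument per line with trailing commas, and docstring padding is trimmed, consistent with an automated formatter such as black (an assumption; the commit message itself only mentions tqdm). For reference, the hook signatures that the reformatted CollaborativeCallback overrides come from transformers.TrainerCallback; a minimal sketch with a hypothetical class name and placeholder prints:

    import transformers
    from transformers import TrainingArguments


    class MinimalCallback(transformers.TrainerCallback):  # hypothetical example, not from the commit
        def on_train_begin(self, args: TrainingArguments, state, control, **kwargs):
            print("training started")

        def on_step_end(self, args: TrainingArguments, state, control, **kwargs):
            # state.log_history is a list of logged dicts; the last entry holds the latest loss
            if state.log_history:
                print("latest loss:", state.log_history[-1].get("loss"))
            return control  # returning the (possibly modified) control object is optional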
