
Commit 3acf47e

Merge pull request openai#2 from tlkh/finetuning
Added instructions and script for distributed training with Horovod
2 parents f645831 + 5ddd190 commit 3acf47e

File tree

2 files changed: +274 / -2 lines


README.md

Lines changed: 14 additions & 2 deletions
@@ -97,14 +97,26 @@ python3 src/interactive_conditional_samples.py -- --help
 To retrain GPT-2 117M model on a custom text dataset:

 ```
-PYTHONPATH=src ./train --dataset <file|directory|glob>
+PYTHONPATH=src ./train.py --dataset <file|directory|glob>
 ```

 If you want to precompute the dataset's encoding for multiple runs, you can instead use:

 ```
 PYTHONPATH=src ./encode.py <file|directory|glob> /path/to/encoded.npz
-PYTHONPATH=src ./train --dataset /path/to/encoded.npz
+PYTHONPATH=src ./train.py --dataset /path/to/encoded.npz
+```
+
+To do distributed training on multiple GPUs or machines using Horovod:
+
+```
+mpirun -np 4 \
+    -H localhost:4 \
+    -bind-to none -map-by slot \
+    -x NCCL_DEBUG=INFO -x LD_LIBRARY_PATH -x PATH \
+    -x PYTHONPATH=src \
+    -mca pml ob1 -mca btl ^openib \
+    /home/jovyan/gpt-2/train-horovod.py --dataset encoded.npz
 ```

 ## GPT-2 samples
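
The Horovod launch above reads a pre-encoded `encoded.npz` archive. For orientation only (this snippet is not part of the PR; `encode.py` is the supported tool), such an archive is just a NumPy `.npz` file holding one token array per document, which is exactly the layout `load_dataset` in `train-horovod.py` iterates over:

```
# Illustrative sketch only: build an .npz archive in the layout load_dataset() expects.
# Assumes PYTHONPATH=src so the repo's encoder module is importable and the 117M
# model files have been downloaded; in practice, just run encode.py instead.
import numpy as np
import encoder

enc = encoder.get_encoder('117M')

texts = ['First training document.', 'Second training document.']
chunks = [np.array(enc.encode(t), dtype=np.int32) for t in texts]

# One array per document; load_dataset() loops over npz.files and collects each array.
np.savez_compressed('encoded.npz', *chunks)

# Read it back the same way train-horovod.py does.
with np.load('encoded.npz') as npz:
    for name in npz.files:
        print(name, npz[name].shape)
```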

train-horovod.py

Lines changed: 260 additions & 0 deletions
#!/usr/bin/env python3
# Usage:
#   mpirun -np 4 -x PYTHONPATH=src ./train-horovod.py --dataset <file|directory|glob>
#   (see README.md for the full multi-GPU launch command)

import fire
import glob
import json
import os
import numpy as np
import tensorflow as tf
import random
import time

import horovod.tensorflow as hvd

import model, sample, encoder

CHECKPOINT_DIR = 'checkpoint'
SAMPLE_DIR = 'samples'

# Horovod: initialize MPI; must run before any other Horovod call.
hvd.init()


def maketree(path):
    try:
        os.makedirs(path)
    except:
        pass


def load_dataset(enc, path):
    paths = []
    if os.path.isfile(path):
        # Simple file
        paths.append(path)
    elif os.path.isdir(path):
        # Directory
        for (dirpath, _, fnames) in os.walk(path):
            for fname in fnames:
                paths.append(os.path.join(dirpath, fname))
    else:
        # Assume glob
        paths = glob.glob(path)

    token_chunks = []
    for path in paths:
        print(str(hvd.local_rank()), 'Reading', path)
        if path.endswith('.npz'):
            # Pre-encoded
            with np.load(path) as npz:
                for item in npz.files:
                    token_chunks.append(npz[item])
        else:
            with open(path, 'r') as fp:
                raw_text = fp.read()
            tokens = np.stack(enc.encode(raw_text))
            token_chunks.append(tokens)
    return token_chunks


def binary_search(f, lo, hi):
    if f(lo) or not f(hi):
        return None
    while hi > lo + 1:
        mid = (lo + hi) // 2
        if f(mid):
            hi = mid
        else:
            lo = mid
    return hi


class Sampler(object):
    """Fairly samples a slice from a set of variable sized chunks.

    'Fairly' means that the distribution is the same as sampling from one
    concatenated chunk, but without crossing chunk boundaries."""

    def __init__(self, chunks):
        self.chunks = chunks
        self.total_size = sum(chunk.shape[0] for chunk in chunks)
        self.boundaries = [0]
        for i in range(len(chunks)):
            self.boundaries.append(self.boundaries[-1] + chunks[i].shape[0])

    def sample(self, length):
        assert length < self.total_size // len(
            self.chunks
        ), "Dataset files are too small to sample {} tokens at a time".format(length)
        while True:
            index = random.randint(0, self.total_size - length - 1)
            i = binary_search(lambda j: self.boundaries[j] > index, 0,
                              len(self.boundaries) - 1) - 1
            if self.boundaries[i + 1] > index + length:
                within_chunk = index - self.boundaries[i]
                return self.chunks[i][within_chunk:within_chunk + length]


def train_main(dataset,
               model_name='117M',
               seed=None,
               batch_size=2,
               sample_length=1023,
               sample_num=1,
               sample_every=4500,
               run_name='run1',
               restore_from='latest',
               save_every=2000):

    enc = encoder.get_encoder(model_name)
    hparams = model.default_hparams()
    with open(os.path.join('models', model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if sample_length is None:
        sample_length = hparams.n_ctx // 2
    elif sample_length > hparams.n_ctx:
        raise ValueError(
            "Can't get samples longer than window size: %s" % hparams.n_ctx)

    # TF config
    # Horovod: pin each process to a single GPU, selected by its local rank.
    config = tf.ConfigProto()
    config.gpu_options.visible_device_list = str(hvd.local_rank())
    config.gpu_options.allow_growth = True

    with tf.Session(config=config) as sess:
        context = tf.placeholder(tf.int32, [batch_size, None])
        np.random.seed(seed)
        tf.set_random_seed(seed)
        output = model.model(hparams=hparams, X=context)
        loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=context[:, 1:], logits=output['logits'][:, :-1]))

        tf_sample = sample.sample_sequence(
            hparams=hparams,
            length=sample_length,
            context=context,
            batch_size=batch_size,
            temperature=0.8,
            top_k=40)

        train_vars = [v for v in tf.trainable_variables() if 'model' in v.name]

        # Horovod: wrap the optimizer so gradients are averaged across workers.
        opt = tf.train.AdamOptimizer()
        opt = hvd.DistributedOptimizer(opt)
        train_op = opt.minimize(loss, var_list=train_vars)

        # Horovod: broadcast initial variable states from rank 0 to all other processes.
        # This is necessary to ensure consistent initialization of all workers when
        # training is started with random weights or restored from a checkpoint.
        bcast = hvd.broadcast_global_variables(0)

        saver = tf.train.Saver(
            var_list=train_vars,
            max_to_keep=5,
            keep_checkpoint_every_n_hours=2)

        sess.run(tf.global_variables_initializer())

        if restore_from == 'latest':
            ckpt = tf.train.latest_checkpoint(
                os.path.join(CHECKPOINT_DIR, run_name))
            if ckpt is None:
                # Get fresh GPT weights if new run.
                ckpt = tf.train.latest_checkpoint(
                    os.path.join('models', model_name))
        elif restore_from == 'fresh':
            ckpt = tf.train.latest_checkpoint(
                os.path.join('models', model_name))
        else:
            ckpt = tf.train.latest_checkpoint(restore_from)
        print(str(hvd.local_rank()), 'Loading checkpoint', ckpt)
        saver.restore(sess, ckpt)

        bcast.run()

        print(str(hvd.local_rank()), 'Loading dataset...')
        chunks = load_dataset(enc, dataset)
        data_sampler = Sampler(chunks)
        print(str(hvd.local_rank()), 'dataset has', data_sampler.total_size, 'tokens')
        print(str(hvd.local_rank()), 'Training...')

        counter = 1
        if os.path.exists(os.path.join(CHECKPOINT_DIR, run_name, 'counter')):
            # Load the step number if we're resuming a run
            # Add 1 so we don't immediately try to save again
            with open(os.path.join(CHECKPOINT_DIR, run_name, 'counter'),
                      'r') as fp:
                counter = int(fp.read()) + 1

        def save():
            maketree(os.path.join(CHECKPOINT_DIR, run_name))
            print(
                'Saving',
                os.path.join(CHECKPOINT_DIR, run_name,
                             'model-{}').format(counter))
            saver.save(
                sess,
                os.path.join(CHECKPOINT_DIR, run_name, 'model'),
                global_step=counter)
            with open(os.path.join(CHECKPOINT_DIR, run_name, 'counter'),
                      'w') as fp:
                fp.write(str(counter) + '\n')

        def generate_samples():
            context_tokens = data_sampler.sample(1)
            all_text = []
            index = 0
            while index < sample_num:
                out = sess.run(
                    tf_sample, feed_dict={context: batch_size*[context_tokens]})
                for i in range(min(sample_num - index, batch_size)):
                    text = enc.decode(out[i])
                    text = '======== SAMPLE {} ========\n{}\n'.format(index + 1, text)
                    all_text.append(text)
                    index += 1
                    print(text)
            maketree(os.path.join(SAMPLE_DIR, run_name))
            with open(
                    os.path.join(SAMPLE_DIR, run_name,
                                 'samples-{}').format(counter), 'w') as fp:
                fp.write('\n'.join(all_text))

        # (loss_sum, weight_sum) for an exponentially decayed average of the loss.
        avg_loss = (0.0, 0.0)
        start_time = time.time()

        try:
            while True:

                batch = [data_sampler.sample(1024) for _ in range(batch_size)]

                _, lv = sess.run((train_op, loss), feed_dict={context: batch})

                avg_loss = (avg_loss[0] * 0.99 + lv, avg_loss[1] * 0.99 + 1.0)

                # Only rank 0 saves checkpoints, writes samples and logs progress.
                if hvd.rank() == 0:
                    if counter % save_every == 0:
                        save()
                    if counter % sample_every == 0:
                        generate_samples()

                    print(
                        '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
                        .format(
                            counter=counter,
                            time=time.time() - start_time,
                            loss=lv,
                            avg=avg_loss[0] / avg_loss[1]))

                counter += 1

        except KeyboardInterrupt:
            print('interrupted')
            if hvd.rank() == 0:
                save()


if __name__ == '__main__':
    fire.Fire(train_main)
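
A note on the least obvious part of the script: `Sampler` draws a starting offset uniformly over the concatenated dataset and retries whenever the requested window would cross a chunk boundary, so each chunk contributes in proportion to its valid windows. A standalone sketch of that idea on toy data (sizes and names invented for illustration, not taken from the PR):

```
# Toy illustration of the fair-sampling idea used by Sampler in train-horovod.py.
import bisect
import random

chunk_sizes = [10, 50, 100]            # token counts of three toy "documents"
boundaries = [0]
for size in chunk_sizes:
    boundaries.append(boundaries[-1] + size)
total = boundaries[-1]

def sample_window(length):
    """Pick a window of `length` tokens uniformly over the concatenation,
    rejecting any draw that would straddle a chunk boundary."""
    while True:
        index = random.randint(0, total - length - 1)
        i = bisect.bisect_right(boundaries, index) - 1   # chunk containing index
        if boundaries[i + 1] >= index + length:          # window fits inside chunk i
            return i, index - boundaries[i]              # (chunk index, offset in chunk)

random.seed(0)
counts = [0] * len(chunk_sizes)
for _ in range(10000):
    chunk, _ = sample_window(8)
    counts[chunk] += 1
print(counts)   # roughly proportional to each chunk's number of valid windows
```

The script's own `binary_search` plays the role of `bisect_right` here; the rejection step is what keeps the distribution the same as sampling from one concatenated chunk without ever returning a slice that spans two files.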
