#!/usr/bin/env python3
# Usage:
# PYTHONPATH=src ./train --dataset <file|directory|glob>

import fire
import glob
import json
import os
import numpy as np
import tensorflow as tf
import random
import time

import horovod.tensorflow as hvd

import model, sample, encoder

CHECKPOINT_DIR = 'checkpoint'
SAMPLE_DIR = 'samples'
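
# Horovod: initialize the distributed runtime so hvd.rank() and
# hvd.local_rank() can be queried below.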
hvd.init()


def maketree(path):
    # Create the directory if it doesn't already exist.
    os.makedirs(path, exist_ok=True)


def load_dataset(enc, path):
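    """Read one plain-text or pre-encoded (.npz) file, a directory tree, or a
    glob of files, and return a list of token arrays."""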
    paths = []
    if os.path.isfile(path):
        # Simple file
        paths.append(path)
    elif os.path.isdir(path):
        # Directory
        for (dirpath, _, fnames) in os.walk(path):
            for fname in fnames:
                paths.append(os.path.join(dirpath, fname))
    else:
        # Assume glob
        paths = glob.glob(path)

    token_chunks = []
    for path in paths:
        print(str(hvd.local_rank()), 'Reading', path)
        if path.endswith('.npz'):
            # Pre-encoded
            with np.load(path) as npz:
                for item in npz.files:
                    token_chunks.append(npz[item])
        else:
            with open(path, 'r') as fp:
                raw_text = fp.read()
            tokens = np.stack(enc.encode(raw_text))
            token_chunks.append(tokens)
    return token_chunks


def binary_search(f, lo, hi):
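    """Return the smallest index in (lo, hi] for which the monotonic predicate
    f is true, or None if f(lo) is already true or f(hi) is false."""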
    if f(lo) or not f(hi):
        return None
    while hi > lo + 1:
        mid = (lo + hi) // 2
        if f(mid):
            hi = mid
        else:
            lo = mid
    return hi


class Sampler(object):
    """Fairly samples a slice from a set of variable sized chunks.

    'Fairly' means that the distribution is the same as sampling from one concatenated chunk,
    but without crossing chunk boundaries."""

    def __init__(self, chunks):
        self.chunks = chunks
        self.total_size = sum(chunk.shape[0] for chunk in chunks)
        self.boundaries = [0]
        for i in range(len(chunks)):
            self.boundaries.append(self.boundaries[-1] + chunks[i].shape[0])

    def sample(self, length):
        assert length < self.total_size // len(
            self.chunks
        ), "Dataset files are too small to sample {} tokens at a time".format(length)
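        # Rejection sampling: draw start positions until the requested slice
        # lies entirely within a single chunk.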
        while True:
            index = random.randint(0, self.total_size - length - 1)
            i = binary_search(lambda j: self.boundaries[j] > index, 0,
                              len(self.boundaries) - 1) - 1
            if self.boundaries[i + 1] > index + length:
                within_chunk = index - self.boundaries[i]
                return self.chunks[i][within_chunk:within_chunk + length]


def train_main(dataset,
               model_name='117M',
               seed=None,
               batch_size=2,
               sample_length=1023,
               sample_num=1,
               sample_every=4500,
               run_name='run1',
               restore_from='latest',
               save_every=2000):

    enc = encoder.get_encoder(model_name)
    hparams = model.default_hparams()
    with open(os.path.join('models', model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if sample_length is None:
        sample_length = hparams.n_ctx // 2
    elif sample_length > hparams.n_ctx:
        raise ValueError(
            "Can't get samples longer than window size: %s" % hparams.n_ctx)

    # TF session config: make only this process's GPU visible and let memory
    # grow on demand instead of pre-allocating the whole card.

    config = tf.ConfigProto()
    config.gpu_options.visible_device_list = str(hvd.local_rank())
    config.gpu_options.allow_growth = True

    with tf.Session(config=config) as sess:
        context = tf.placeholder(tf.int32, [batch_size, None])
        np.random.seed(seed)
        tf.set_random_seed(seed)
        output = model.model(hparams=hparams, X=context)
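        # Next-token prediction loss: the logit at position t is scored against
        # the token at position t + 1, so logits and labels are shifted by one.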
        loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=context[:, 1:], logits=output['logits'][:, :-1]))

        tf_sample = sample.sample_sequence(
            hparams=hparams,
            length=sample_length,
            context=context,
            batch_size=batch_size,
            temperature=0.8,
            top_k=40)

        train_vars = [v for v in tf.trainable_variables() if 'model' in v.name]
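
        # Horovod: wrap the optimizer so gradients are averaged across all
        # workers (allreduce) before each update is applied.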
        opt = tf.train.AdamOptimizer()
        opt = hvd.DistributedOptimizer(opt)
        train_op = opt.minimize(loss, var_list=train_vars)

        # Horovod: broadcast initial variable states from rank 0 to all other processes.
        # This is necessary to ensure consistent initialization of all workers when
        # training is started with random weights or restored from a checkpoint.
        bcast = hvd.broadcast_global_variables(0)

        saver = tf.train.Saver(
            var_list=train_vars,
            max_to_keep=5,
            keep_checkpoint_every_n_hours=2)

        sess.run(tf.global_variables_initializer())

        if restore_from == 'latest':
            ckpt = tf.train.latest_checkpoint(
                os.path.join(CHECKPOINT_DIR, run_name))
            if ckpt is None:
                # Get fresh GPT weights if new run.
                ckpt = tf.train.latest_checkpoint(
                    os.path.join('models', model_name))
        elif restore_from == 'fresh':
            ckpt = tf.train.latest_checkpoint(
                os.path.join('models', model_name))
        else:
            ckpt = tf.train.latest_checkpoint(restore_from)
        print(str(hvd.local_rank()), 'Loading checkpoint', ckpt)
        saver.restore(sess, ckpt)

        bcast.run()

        print(str(hvd.local_rank()), 'Loading dataset...')
        chunks = load_dataset(enc, dataset)
        data_sampler = Sampler(chunks)
        print(str(hvd.local_rank()), 'dataset has', data_sampler.total_size, 'tokens')
        print(str(hvd.local_rank()), 'Training...')

        counter = 1
        if os.path.exists(os.path.join(CHECKPOINT_DIR, run_name, 'counter')):
            # Load the step number if we're resuming a run
            # Add 1 so we don't immediately try to save again
            with open(os.path.join(CHECKPOINT_DIR, run_name, 'counter'),
                      'r') as fp:
                counter = int(fp.read()) + 1

        def save():
            maketree(os.path.join(CHECKPOINT_DIR, run_name))
            print(
                'Saving',
                os.path.join(CHECKPOINT_DIR, run_name,
                             'model-{}').format(counter))
            saver.save(
                sess,
                os.path.join(CHECKPOINT_DIR, run_name, 'model'),
                global_step=counter)
            with open(os.path.join(CHECKPOINT_DIR, run_name, 'counter'),
                      'w') as fp:
                fp.write(str(counter) + '\n')

        def generate_samples():
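            # Condition every generated sample on a single randomly drawn token.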
            context_tokens = data_sampler.sample(1)
            all_text = []
            index = 0
            while index < sample_num:
                out = sess.run(
                    tf_sample, feed_dict={context: batch_size * [context_tokens]})
                for i in range(min(sample_num - index, batch_size)):
                    text = enc.decode(out[i])
                    text = '======== SAMPLE {} ========\n{}\n'.format(index + 1, text)
                    all_text.append(text)
                    index += 1
                    print(text)
            maketree(os.path.join(SAMPLE_DIR, run_name))
            with open(
                    os.path.join(SAMPLE_DIR, run_name,
                                 'samples-{}').format(counter), 'w') as fp:
                fp.write('\n'.join(all_text))
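
        # Keep a decayed running average of the loss as (decayed sum, decayed
        # count); the printed average is their ratio.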
        avg_loss = (0.0, 0.0)
        start_time = time.time()

        try:
            while True:
                # Each training example is a 1024-token slice drawn from the
                # dataset (assumes the model's context window is >= 1024).
                batch = [data_sampler.sample(1024) for _ in range(batch_size)]

                _, lv = sess.run((train_op, loss), feed_dict={context: batch})

                avg_loss = (avg_loss[0] * 0.99 + lv, avg_loss[1] * 0.99 + 1.0)
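
                # Only rank 0 writes checkpoints and samples, so workers don't
                # clobber each other's files.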
                if hvd.rank() == 0:
                    if counter % save_every == 0:
                        save()
                    if counter % sample_every == 0:
                        generate_samples()

                print(
                    '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
                    .format(
                        counter=counter,
                        time=time.time() - start_time,
                        loss=lv,
                        avg=avg_loss[0] / avg_loss[1]))

                counter += 1

        except KeyboardInterrupt:
            print('interrupted')
            if hvd.rank() == 0:
                save()


if __name__ == '__main__':
    fire.Fire(train_main)