Commit cd257bb

Merge pull request openai#15 from minimaxir/cli
Cli
2 parents 15477cf + 636bd79 commit cd257bb

File tree: 4 files changed, +185 −13 lines changed

README.md

Lines changed: 25 additions & 5 deletions
@@ -61,6 +61,26 @@ print(single_text)
 
 You can pass a `run_name` parameter to `finetune` and `load_gpt2` if you want to store/load multiple models in a `checkpoint` folder.
 
+There is also a command-line interface for both finetuning and generation, with strong defaults for just running on a Cloud VM w/ GPU. For finetuning (which will also download the model if not present):
+
+```shell
+gpt_2_simple finetune shakespeare.txt
+```
+
+And for generation, which generates texts to files in a `gen` folder:
+
+```shell
+gpt_2_simple generate
+```
+
+Most of the same parameters available in the functions are available as CLI arguments, e.g.:
+
+```shell
+gpt_2_simple generate --temperature 1.0 --nsamples 20 --batch_size 20 --length 50 --prefix "<|startoftext|>" --truncate "<|endoftext|>" --include_prefix False --nfiles 5
+```
+
+See below for what some of the CLI arguments do.
+
 NB: *Restart the Python session first* if you want to finetune on another dataset or load another model.
 
 ## Differences Between gpt-2-simple And Other Text Generation Utilities
@@ -72,8 +92,9 @@ The method GPT-2 uses to generate text is slightly different than those like other
 * GPT-2 can only generate a maximum of 1024 tokens per request (about 3-4 paragraphs of English text).
 * GPT-2 cannot stop early upon reaching a specific end token. (Workaround: pass the `truncate` parameter to a `generate` function to only collect text until a specified end token. You may want to reduce `length` appropriately.)
 * Higher temperatures work better (e.g. 0.7 - 1.0) to generate more interesting text, while other frameworks work better between 0.2 - 0.5.
-* When finetuning GPT-2, it has no sense of the beginning or end of a document within a larger text. You'll need to use a bespoke character sequence to indicate the beginning and end of a document. Then while generating, you can specify a `prefix` targeting the beginning token sequences, and a `truncate` targeting the end token sequence.
+* When finetuning GPT-2, it has no sense of the beginning or end of a document within a larger text. You'll need to use a bespoke character sequence to indicate the beginning and end of a document. Then while generating, you can specify a `prefix` targeting the beginning token sequences, and a `truncate` targeting the end token sequence. You can also set `include_prefix=False` to discard the prefix token while generating (e.g. if it's something unwanted like `<|startoftext|>`).
 * GPT-2 allows you to generate texts in parallel by setting a `batch_size` that is divisible into `nsamples`, resulting in much faster generation. Works very well with a GPU (can set `batch_size` up to 20 on Colaboratory's K80)!
+* Due to GPT-2's architecture, it scales up nicely with more powerful GPUs. If you want to train for longer periods of time, GCP's P100 GPU is about 3x faster than a K80 for only 3x the price, making it comparable (the V100 is about 1.5x faster than the P100 but about 2x the price). The P100 uses 100% of the GPU even with `batch_size=1`, and about 88% of the V100 GPU.
 
 ## Planned Work
 
@@ -86,13 +107,12 @@ Note: this project is intended to have a very tight scope unless demand dictates
 
 ## Examples Using gpt-2-simple
 
-* [ResetEra](https://www.resetera.com/threads/i-trained-an-ai-on-thousands-of-resetera-thread-conversations-and-it-created-hot-gaming-shitposts.112167/) — Generated video game forum discussions
+* [ResetEra](https://www.resetera.com/threads/i-trained-an-ai-on-thousands-of-resetera-thread-conversations-and-it-created-hot-gaming-shitposts.112167/) — Generated video game forum discussions ([GitHub w/ dumps](https://github.com/minimaxir/resetera-gpt-2))
+* [/r/legaladvice](https://www.reddit.com/r/legaladviceofftopic/comments/bfqf22/i_trained_a_moreadvanced_ai_on_rlegaladvice/) — Title generation ([GitHub w/ dumps](https://github.com/minimaxir/legaladvice-gpt2))
 
 ## Maintainer/Creator
 
-Max Woolf ([@minimaxir](http://minimaxir.com))
-
-*Max's open-source projects are supported by his [Patreon](https://www.patreon.com/minimaxir). If you found this project helpful, any monetary contributions to the Patreon are appreciated and will be put to good creative use.*
+Max Woolf ([@minimaxir](https://minimaxir.com))
 
 ## License
 
gpt_2_simple/gpt_2.py

Lines changed: 155 additions & 6 deletions
@@ -4,11 +4,13 @@
 import sys
 import shutil
 import re
-from tqdm import tqdm
+from tqdm import tqdm, trange
 import numpy as np
 import tensorflow as tf
 import time
+from datetime import datetime
 import csv
+import argparse
 
 # if in Google Colaboratory
 try:
@@ -278,7 +280,8 @@ def generate(sess,
              length=1023,
              temperature=0.7,
              top_k=0,
-             run_name='run1'):
+             run_name='run1',
+             include_prefix=True):
     """Generates text from a model loaded into memory.
 
     Adapted from https://github.com/openai/gpt-2/blob/master/src/interactive_conditional_samples.py
@@ -334,8 +337,15 @@ def generate(sess,
             if prefix:
                 gen_text = prefix[0] + gen_text
             if truncate:
-                trunc_text = re.search(r'(.*?)(?:{})'.format(truncate),
-                                       gen_text, re.S)
+                truncate_esc = re.escape(truncate)
+                if prefix and not include_prefix:
+                    prefix_esc = re.escape(prefix)
+                    pattern = '(?:{})(.*?)(?:{})'.format(prefix_esc,
+                                                         truncate_esc)
+                else:
+                    pattern = '(.*?)(?:{})'.format(truncate_esc)
+
+                trunc_text = re.search(pattern, gen_text, re.S)
             if trunc_text:
                 gen_text = trunc_text.group(1)
             if destination_path:
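For illustration, the new truncation logic from the hunk above as a self-contained snippet (the sample text and boundary tokens are made up): when a `prefix` is set and `include_prefix=False`, the pattern captures only the text between the escaped prefix and end tokens.

```python
import re

# Made-up sample output and boundary tokens for demonstration.
gen_text = '<|startoftext|>A generated document.<|endoftext|>trailing junk'
prefix, truncate, include_prefix = '<|startoftext|>', '<|endoftext|>', False

truncate_esc = re.escape(truncate)
if prefix and not include_prefix:
    # Capture only the text between the prefix and the end token.
    prefix_esc = re.escape(prefix)
    pattern = '(?:{})(.*?)(?:{})'.format(prefix_esc, truncate_esc)
else:
    # Capture everything up to the end token (prefix included).
    pattern = '(.*?)(?:{})'.format(truncate_esc)

trunc_text = re.search(pattern, gen_text, re.S)
if trunc_text:
    gen_text = trunc_text.group(1)

print(gen_text)  # -> 'A generated document.'
```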
@@ -363,7 +373,8 @@ def generate_to_file(sess,
                      length=1023,
                      temperature=0.7,
                      top_k=0,
-                     run_name='run1'):
+                     run_name='run1',
+                     include_prefix=True):
     """Generates the texts to a file.
 
     sample_delim separates texts: set to '' if each text is a small document.
@@ -384,7 +395,8 @@
              length,
              temperature,
              top_k,
-             run_name)
+             run_name,
+             include_prefix)
 
 
 def mount_gdrive():
@@ -452,3 +464,140 @@ def encode_csv(csv_path, out_path='csv_encoded.txt', header=True,
         reader = csv.reader(f)
         for row in reader:
             w.write(start_token + row[0] + end_token + "\n")
+
+
+def cmd():
+    """Function called when invoking from the terminal."""
+
+    parser = argparse.ArgumentParser(
+        description="Easily retrain OpenAI's GPT-2 text-generating model on new texts. (https://github.com/minimaxir/gpt-2-simple)"
+    )
+
+    # Explicit arguments
+
+    parser.add_argument(
+        '--mode', help='Mode for using the CLI (either "finetune" or "generate") [Required]', nargs='?')
+    parser.add_argument(
+        '--run_name', help="[finetune/generate] Run number to save/load the model",
+        nargs='?', default='run1')
+    parser.add_argument(
+        '--dataset', help="[finetune] Path to the source text.",
+        nargs='?', default=None)
+    parser.add_argument(
+        '--steps', help="[finetune] Number of steps to train (-1 for infinite)",
+        nargs='?', default=-1)
+    parser.add_argument(
+        '--restore_from', help="[finetune] Whether to load model 'fresh' or from 'latest' checkpoint.",
+        nargs='?', default='latest')
+    parser.add_argument(
+        '--sample_every', help="[finetune] After how many steps to print sample",
+        nargs='?', default=1000000, type=int)
+    parser.add_argument(
+        '--save_every', help="[finetune] After how many steps to save checkpoint",
+        nargs='?', default=100, type=int)
+    parser.add_argument(
+        '--print_every', help="[finetune] After how many steps to print progress",
+        nargs='?', default=10, type=int)
+    parser.add_argument(
+        '--nfiles', help="[generate] How many files to generate.",
+        nargs='?', default=1, type=int)
+    parser.add_argument(
+        '--nsamples', help="[generate] How many texts to generate.",
+        nargs='?', default=1, type=int)
+    parser.add_argument(
+        '--folder', help="[generate] Folder to save the generated files",
+        nargs='?', default="gen", type=str)
+    parser.add_argument(
+        '--length', help="[generate] Length (tokens) of the generated texts",
+        nargs='?', default=1023, type=int)
+    parser.add_argument(
+        '--temperature', help="[generate] Temperature of the generated texts",
+        nargs='?', default=0.7, type=float)
+    parser.add_argument(
+        '--batch_size', help="[generate] Batch size for generation (increase for GPUs)",
+        nargs='?', default=1, type=int)
+    parser.add_argument(
+        '--prefix', help="[generate] Prefix for generated texts",
+        nargs='?', default=None)
+    parser.add_argument(
+        '--truncate', help="[generate] Truncation for generated texts",
+        nargs='?', default=None)
+    # https://stackoverflow.com/a/46951029
+    parser.add_argument(
+        '--include_prefix', help="[generate] Include prefix when truncating.",
+        nargs='?', default=True, type=lambda x: (str(x).lower() == 'true'))
+    parser.add_argument(
+        '--sample_delim', help="[generate] Delimiter between each generated sample.",
+        nargs='?', default='=' * 20 + '\n', type=str)
+
+    # Positional arguments
+    parser.add_argument('mode', nargs='?')
+    parser.add_argument('dataset', nargs='?')
+
+    args = parser.parse_args()
+    assert args.mode in ['finetune', 'generate'], "Mode must be 'finetune' or 'generate'"
+
+    if args.mode == 'finetune':
+        assert args.dataset is not None, "You need to provide a dataset."
+
+        cmd_finetune(dataset=args.dataset, run_name=args.run_name,
+                     steps=args.steps, restore_from=args.restore_from,
+                     sample_every=args.sample_every,
+                     save_every=args.save_every,
+                     print_every=args.print_every)
+    if args.mode == "generate":
+        cmd_generate(nfiles=args.nfiles, nsamples=args.nsamples,
+                     folder=args.folder, length=args.length,
+                     temperature=args.temperature, batch_size=args.batch_size,
+                     prefix=args.prefix, truncate=args.truncate,
+                     include_prefix=args.include_prefix,
+                     sample_delim=args.sample_delim)
+
+
+def cmd_finetune(dataset, run_name, steps, restore_from, sample_every,
+                 save_every, print_every):
+    """Wrapper script for finetuning the model via the CLI."""
+
+    if not is_gpt2_downloaded():
+        download_gpt2()
+
+    sess = start_tf_sess()
+    finetune(sess, dataset=dataset, run_name=run_name,
+             steps=steps, restore_from=restore_from,
+             sample_every=sample_every, save_every=save_every,
+             print_every=print_every)
+
+
+def cmd_generate(nfiles, nsamples, folder,
+                 length, temperature, batch_size,
+                 prefix, truncate, include_prefix,
+                 sample_delim):
+    """Wrapper script for generating text via the CLI.
+    The files are generated into a folder, which can be downloaded
+    recursively by downloading the entire folder.
+    """
+
+    sess = start_tf_sess()
+    load_gpt2(sess)
+
+    try:
+        os.mkdir(folder)
+    except:
+        shutil.rmtree(folder)
+        os.mkdir(folder)
+
+    for _ in trange(nfiles):
+        gen_file = os.path.join(folder,
+                                'gpt2_gentext_{:%Y%m%d_%H%M%S}.txt'.format(datetime.utcnow()))
+
+        generate_to_file(sess,
+                         destination_path=gen_file,
+                         length=length,
+                         temperature=temperature,
+                         nsamples=nsamples,
+                         batch_size=batch_size,
+                         prefix=prefix,
+                         truncate=truncate,
+                         include_prefix=include_prefix,
+                         sample_delim=sample_delim
+                         )
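A side note on the `--include_prefix` argument above: argparse has no built-in boolean type, hence the lambda from the linked Stack Overflow answer. A minimal self-contained sketch of the behavior:

```python
import argparse

parser = argparse.ArgumentParser()
# type=bool would treat any non-empty string (including "False") as True,
# so the lambda compares the lowercased string against 'true' instead.
parser.add_argument('--include_prefix', nargs='?', default=True,
                    type=lambda x: (str(x).lower() == 'true'))

print(parser.parse_args([]).include_prefix)                             # True
print(parser.parse_args(['--include_prefix', 'False']).include_prefix)  # False
```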

gpt_2_simple/src/load_dataset.py

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@ def load_dataset(enc, path, combine):
             token_chunks.append(npz[item])
         else:
             # Plain text
-            with open(path, 'r') as fp:
+            with open(path, 'r', encoding='utf8', errors='ignore') as fp:
                 raw_text += fp.read()
             if len(raw_text) >= combine:
                 tokens = np.stack(enc.encode(raw_text))
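The one-line change above guards against corpora that are not clean UTF-8: `errors='ignore'` silently drops undecodable bytes instead of raising `UnicodeDecodeError`. A small demonstration (the file name is made up):

```python
# Write a file containing a byte sequence that is not valid UTF-8.
with open('corpus.txt', 'wb') as f:
    f.write(b'hello \xff world')

# Without errors='ignore', fp.read() would raise UnicodeDecodeError here.
with open('corpus.txt', 'r', encoding='utf8', errors='ignore') as fp:
    print(fp.read())  # -> 'hello  world' (the invalid byte is dropped)
```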

setup.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
 setup(
     name='gpt_2_simple',
     packages=['gpt_2_simple'],  # this must be the same as the name above
-    version='0.2',
+    version='0.3',
     description="Python package to easily retrain OpenAI's GPT-2 " \
                 "text-generating model on new texts.",
     long_description=long_description,
@@ -58,6 +58,9 @@
     keywords=['deep learning', 'tensorflow', 'text generation'],
     classifiers=[],
     license='MIT',
+    entry_points={
+        'console_scripts': ['gpt_2_simple=gpt_2_simple.gpt_2:cmd'],
+    },
     python_requires='>=3.5',
     include_package_data=True,
     install_requires=['regex', 'requests', 'tqdm', 'numpy']
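The new `console_scripts` entry point is what makes the `gpt_2_simple` command in the README work: on installation, setuptools generates an executable wrapper that imports `gpt_2_simple.gpt_2` and calls its `cmd()` function.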
