-
Notifications
You must be signed in to change notification settings - Fork 182
/
Copy pathbatch_viewer.py
44 lines (40 loc) · 1.23 KB
/
batch_viewer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
from mmap_dataset import MMapIndexedDataset
from tqdm import trange
import numpy as np
import argparse
import os
# Multiplier applied to the train-step arguments to obtain dataset indices.
# NOTE(review): presumably the number of sequences per training batch --
# confirm against the training configuration.
_ITEMS_PER_ITERATION = 1024


def build_arg_parser():
    """Build the command-line parser for this script.

    Returns:
        argparse.ArgumentParser: parser exposing ``--start_iteration``,
        ``--end_iteration``, ``--save_path`` and the optional positional
        ``load_path``.
    """
    parser = argparse.ArgumentParser(
        description="",
    )
    parser.add_argument(
        "--start_iteration",
        type=int,
        default=0,
        help="What train step to start logging"
    )
    parser.add_argument(
        "--end_iteration",
        type=int,
        default=143000,
        help="Train step to end logging (inclusive)"
    )
    parser.add_argument(
        "load_path",
        # A positional argument is mandatory unless nargs="?" is given, so
        # without it the default below could never take effect.
        nargs="?",
        type=str,
        default='/mnt/ssd-1/pile_preshuffled/standard/document',
        help=("MMap dataset path with .bin and .idx files. Omit the .bin (or) .idx "
              "Extension while specifying the path")
    )
    parser.add_argument(
        "--save_path",
        type=str,
        default="token_indicies",
        help="Save path for files"
    )
    return parser


def main(argv=None):
    """Index the dataset over the requested step range and save the result.

    The selected range is written with ``np.save`` to ``indicies.npy`` inside
    ``--save_path``; the directory is created if it does not exist.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``.
    """
    # parse_known_args() ignores unrecognised arguments instead of exiting.
    args = build_arg_parser().parse_known_args(argv)[0]

    os.makedirs(args.save_path, exist_ok=True)
    filename = os.path.join(args.save_path, "indicies.npy")

    dataset = MMapIndexedDataset(args.load_path, skip_warmup=True)

    first = args.start_iteration * _ITEMS_PER_ITERATION
    # NOTE(review): the "+ 1" includes only the first item of `end_iteration`;
    # if the whole end step is meant to be included (the help text says
    # "inclusive"), the bound would be (end_iteration + 1) * 1024. Kept as-is
    # to preserve the existing output -- confirm the intended range.
    last = args.end_iteration * _ITEMS_PER_ITERATION + 1
    indicies = dataset[first:last]
    np.save(filename, indicies)


if __name__ == '__main__':
    main()