forked from openai/parameter-golf
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcached_challenge_fineweb.py
More file actions
169 lines (140 loc) · 6.11 KB
/
cached_challenge_fineweb.py
File metadata and controls
169 lines (140 loc) · 6.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
import argparse
import json
import os
import shutil
from pathlib import Path
from huggingface_hub import hf_hub_download
# Hugging Face dataset repo to download from; override with MATCHED_FINEWEB_REPO_ID.
REPO_ID = os.environ.get("MATCHED_FINEWEB_REPO_ID", "willdepueoai/parameter-golf")
# Folder inside the repo that holds manifest.json and the data files; override with
# MATCHED_FINEWEB_REMOTE_ROOT_PREFIX (an empty value means the repo root).
REMOTE_ROOT_PREFIX = os.environ.get("MATCHED_FINEWEB_REMOTE_ROOT_PREFIX", "datasets")
# All local files are laid out under the directory that contains this script.
ROOT = Path(__file__).resolve().parent
DATASETS_DIR = ROOT.joinpath("datasets")
TOKENIZERS_DIR = ROOT.joinpath("tokenizers")
def prefixed_remote_path(relative_path: str) -> str:
    """Join *relative_path* onto the configured remote root prefix.

    Leading slashes are dropped from *relative_path* and surrounding slashes from
    REMOTE_ROOT_PREFIX. When the prefix is empty, the slash-stripped path is
    returned unchanged.
    """
    cleaned_prefix = REMOTE_ROOT_PREFIX.strip("/")
    cleaned_path = relative_path.lstrip("/")
    if not cleaned_prefix:
        return cleaned_path
    return f"{cleaned_prefix}/{cleaned_path}"
def remote_root_parts() -> tuple[str, ...]:
    """Return REMOTE_ROOT_PREFIX split into path components.

    Surrounding slashes are ignored; an empty prefix yields an empty tuple.
    """
    trimmed = REMOTE_ROOT_PREFIX.strip("/")
    if not trimmed:
        return ()
    return Path(trimmed).parts
def dataset_dir_for_variant(name: str) -> str:
    """Map a CLI variant name to its dataset directory name.

    ``byte260`` and ``sp<VOCAB_SIZE>`` (for example ``sp1024``) both map to
    ``fineweb10B_<name>``.

    Raises:
        ValueError: *name* is neither ``byte260`` nor ``sp`` followed by digits.
    """
    is_sp_vocab = name.startswith("sp") and name[2:].isdigit()
    if name != "byte260" and not is_sp_vocab:
        raise ValueError(f"unsupported variant {name!r}; expected byte260 or sp<VOCAB_SIZE>")
    return f"fineweb10B_{name}"
def local_path_for_remote(relative_path: str) -> Path:
    """Translate a repo-relative path into the local path where the file is stored.

    The remote root prefix is removed first, if the path starts with it. Whatever
    remains under ``datasets/`` lands in DATASETS_DIR, under ``tokenizers/`` in
    TOKENIZERS_DIR, and anything else is placed relative to ROOT.
    """
    parts = Path(relative_path.lstrip("/")).parts
    prefix = remote_root_parts()
    if prefix and parts[: len(prefix)] == prefix:
        parts = parts[len(prefix) :]
    top_level_dirs = {"datasets": DATASETS_DIR, "tokenizers": TOKENIZERS_DIR}
    if parts and parts[0] in top_level_dirs:
        return top_level_dirs[parts[0]].joinpath(*parts[1:])
    return ROOT.joinpath(*parts)
def get(relative_path: str) -> None:
    """Make one file from the Hugging Face dataset repo available locally.

    The destination comes from ``local_path_for_remote``. If a file already exists
    there, nothing is downloaded. Otherwise the file is fetched with
    ``hf_hub_download``. It is then placed at the destination as a hard link to the
    resolved cache file, or as a copy when linking fails.

    Args:
        relative_path: Repo-relative path of the file (e.g. the result of
            ``prefixed_remote_path``); a leading "/" is ignored.
    """
    destination = local_path_for_remote(relative_path)
    # The presence check below is the only completeness check, so the destination
    # must never hold a partially written file (see the copy fallback).
    if destination.exists():
        return
    # exists() follows symlinks, so a dangling symlink reaches this point; remove
    # it so a real file can be created in its place.
    if destination.is_symlink():
        destination.unlink()
    # Strip a leading "/" as local_path_for_remote does, so the subfolder sent to
    # the Hub never starts with "/".
    remote_path = Path(relative_path.lstrip("/"))
    cached_path = Path(
        hf_hub_download(
            repo_id=REPO_ID,
            filename=remote_path.name,
            subfolder=remote_path.parent.as_posix() if remote_path.parent != Path(".") else None,
            repo_type="dataset",
        )
    )
    # HF cache entries may be snapshot symlinks. Resolve to the underlying blob so we
    # always materialize a real file in data/, not a broken relative symlink.
    cached_source = cached_path.resolve(strict=True)
    destination.parent.mkdir(parents=True, exist_ok=True)
    try:
        # A hard link appears atomically and uses no extra disk space.
        os.link(cached_source, destination)
    except OSError:
        # Linking failed (e.g. a different filesystem), so copy instead. Copy into a
        # sibling temp file and rename it into place: os.replace is atomic within
        # a filesystem, so an interrupted copy can no longer leave a truncated
        # file at `destination` that the exists() check would later accept.
        tmp_path = destination.with_name(f".{destination.name}.{os.getpid()}.tmp")
        try:
            shutil.copy2(cached_source, tmp_path)
            os.replace(tmp_path, destination)
        except BaseException:
            tmp_path.unlink(missing_ok=True)
            raise
def manifest_path() -> Path:
    """Return the local path of manifest.json, whether or not it has been downloaded yet."""
    remote_manifest = prefixed_remote_path("manifest.json")
    return local_path_for_remote(remote_manifest)
def load_manifest(*, skip_manifest_download: bool) -> dict:
    """Parse the dataset manifest, downloading it first when it is not on disk.

    Args:
        skip_manifest_download: When true, never download; a missing local
            manifest is then an error.

    Returns:
        The decoded JSON content of manifest.json.

    Raises:
        FileNotFoundError: The manifest is absent locally and downloading is skipped.
    """
    local_manifest = manifest_path()
    if not local_manifest.is_file():
        if skip_manifest_download:
            raise FileNotFoundError(
                f"manifest.json is required for manifest-driven shard counts but is not present locally at {local_manifest}"
            )
        get(prefixed_remote_path("manifest.json"))
    with local_manifest.open(encoding="utf-8") as handle:
        return json.load(handle)
def artifact_paths_for_tokenizer(tokenizer_entry: dict) -> list[str]:
    """Return the downloadable file paths listed in a manifest tokenizer entry.

    The keys ``model_path``, ``vocab_path`` and ``path`` are checked in that
    order. Every truthy value is kept and converted to ``str``.

    Raises:
        ValueError: None of those keys holds a truthy value.
    """
    candidate_keys = ("model_path", "vocab_path", "path")
    found = [str(value) for key in candidate_keys if (value := tokenizer_entry.get(key))]
    if found:
        return found
    raise ValueError(f"tokenizer entry is missing downloadable artifacts: {tokenizer_entry}")
def build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for this downloader.

    Besides the documented options, it accepts an optional positional integer
    (``train_shards_positional``, hidden from ``--help``) for the training shard
    count; ``main`` prefers it over ``--train-shards`` when both are given.
    """
    parser = argparse.ArgumentParser(description="Download challenge FineWeb shards from Hugging Face")
    # Hidden positional form of the training shard count.
    parser.add_argument("train_shards_positional", nargs="?", type=int, default=None, help=argparse.SUPPRESS)
    parser.add_argument(
        "--train-shards",
        type=int,
        default=80,
        help="Number of training shards to download for the selected variant. Defaults to 80.",
    )
    parser.add_argument(
        "--variant",
        default="sp1024",
        help="Tokenizer family to download, for example sp1024, sp4096, or byte260.",
    )
    boolean_flags = (
        ("--skip-manifest", "Skip downloading manifest.json."),
        (
            "--with-docs",
            "Also download docs_selected.jsonl and its sidecar for tokenizer retraining or dataset re-export.",
        ),
    )
    for flag, help_text in boolean_flags:
        parser.add_argument(flag, action="store_true", help=help_text)
    return parser
def _manifest_stat(dataset_entry: dict, key: str, dataset_dir: str, manifest_name: str) -> int:
    """Return ``dataset_entry["stats"][key]`` as an int.

    Raises:
        ValueError: The entry has no ``stats`` mapping or the mapping lacks *key*.
            Without this check, ``int(None)`` would surface as an unhelpful TypeError.
    """
    value = (dataset_entry.get("stats") or {}).get(key)
    if value is None:
        raise ValueError(f"dataset {dataset_dir} in {manifest_name} is missing stats.{key}")
    return int(value)


def main() -> None:
    """Download one variant of the challenge FineWeb data as described by the manifest.

    Parses the command line (see ``build_parser``) and loads the manifest. It then
    fetches every validation shard, the first ``train_shards`` training shards and
    the tokenizer artifacts for the variant. With ``--with-docs`` it also fetches
    docs_selected.jsonl and its source-manifest sidecar. ``get`` skips files that
    already exist locally.

    Raises:
        ValueError: The shard count is negative, the dataset or tokenizer is absent
            from the manifest, shard statistics are missing, or more training shards
            are requested than the manifest lists.
    """
    args = build_parser().parse_args()
    dataset_dir = dataset_dir_for_variant(args.variant)
    # The hidden positional shard count, when given, takes precedence over --train-shards.
    train_shards = args.train_shards_positional if args.train_shards_positional is not None else args.train_shards
    if train_shards < 0:
        raise ValueError("train_shards must be non-negative")
    manifest = load_manifest(skip_manifest_download=args.skip_manifest)
    # Repo-relative manifest location for error messages. Unlike
    # f"{REMOTE_ROOT_PREFIX}/manifest.json", it has no stray "/" when the prefix is
    # empty or slash-padded.
    manifest_name = prefixed_remote_path("manifest.json")
    dataset_entry = next((x for x in manifest.get("datasets", []) if x.get("name") == dataset_dir), None)
    if dataset_entry is None:
        raise ValueError(f"dataset {dataset_dir} not found in {manifest_name}")
    max_train_shards = _manifest_stat(dataset_entry, "files_train", dataset_dir, manifest_name)
    val_shards = _manifest_stat(dataset_entry, "files_val", dataset_dir, manifest_name)
    if train_shards > max_train_shards:
        raise ValueError(
            f"{args.variant} only has {max_train_shards} training shards on {REPO_ID}, requested {train_shards}"
        )
    tokenizer_name = dataset_entry.get("tokenizer_name")
    tokenizer_entry = next((x for x in manifest.get("tokenizers", []) if x.get("name") == tokenizer_name), None)
    if tokenizer_entry is None:
        raise ValueError(f"tokenizer {tokenizer_name} not found in {manifest_name}")
    if args.with_docs:
        get(prefixed_remote_path("docs_selected.jsonl"))
        get(prefixed_remote_path("docs_selected.source_manifest.json"))
    dataset_prefix = prefixed_remote_path(f"datasets/{dataset_dir}")
    for i in range(val_shards):
        get(f"{dataset_prefix}/fineweb_val_{i:06d}.bin")
    for i in range(train_shards):
        get(f"{dataset_prefix}/fineweb_train_{i:06d}.bin")
    for artifact_path in artifact_paths_for_tokenizer(tokenizer_entry):
        get(prefixed_remote_path(artifact_path))


if __name__ == "__main__":
    main()