Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 21 additions & 9 deletions data/cached_challenge_fineweb.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,17 @@
DATASETS_DIR = ROOT / "datasets"
TOKENIZERS_DIR = ROOT / "tokenizers"


def prefixed_remote_path(relative_path: str) -> str:
    """Return ``relative_path`` placed under the configured remote root prefix.

    Leading slashes are dropped from ``relative_path`` and surrounding
    slashes are dropped from ``REMOTE_ROOT_PREFIX`` before the two are joined
    with a single ``/``. When the cleaned prefix is empty, the cleaned
    relative path is returned on its own.
    """
    cleaned = relative_path.lstrip("/")
    root = REMOTE_ROOT_PREFIX.strip("/")
    if not root:
        return cleaned
    return f"{root}/{cleaned}"


def remote_root_parts() -> tuple[str, ...]:
    """Return the path components of ``REMOTE_ROOT_PREFIX``.

    Surrounding slashes are stripped first; an empty prefix yields ``()``.
    """
    root = REMOTE_ROOT_PREFIX.strip("/")
    if not root:
        return ()
    return Path(root).parts

def dataset_dir_for_variant(name: str) -> str:
if name == "byte260":
return "fineweb10B_byte260"
Expand All @@ -22,9 +33,10 @@ def dataset_dir_for_variant(name: str) -> str:


def local_path_for_remote(relative_path: str) -> Path:
remote_path = Path(relative_path)
if REMOTE_ROOT_PREFIX and remote_path.parts[:1] == (REMOTE_ROOT_PREFIX,):
remote_path = remote_path.relative_to(REMOTE_ROOT_PREFIX)
remote_path = Path(relative_path.lstrip("/"))
prefix_parts = remote_root_parts()
if prefix_parts and remote_path.parts[: len(prefix_parts)] == prefix_parts:
remote_path = Path(*remote_path.parts[len(prefix_parts) :])
if remote_path.parts[:1] == ("datasets",):
return DATASETS_DIR.joinpath(*remote_path.parts[1:])
if remote_path.parts[:1] == ("tokenizers",):
Expand Down Expand Up @@ -59,7 +71,7 @@ def get(relative_path: str) -> None:


def manifest_path() -> Path:
return local_path_for_remote(f"{REMOTE_ROOT_PREFIX}/manifest.json")
return local_path_for_remote(prefixed_remote_path("manifest.json"))


def load_manifest(*, skip_manifest_download: bool) -> dict:
Expand All @@ -69,7 +81,7 @@ def load_manifest(*, skip_manifest_download: bool) -> dict:
raise FileNotFoundError(
f"manifest.json is required for manifest-driven shard counts but is not present locally at {path}"
)
get(f"{REMOTE_ROOT_PREFIX}/manifest.json")
get(prefixed_remote_path("manifest.json"))
return json.loads(path.read_text(encoding="utf-8"))


Expand Down Expand Up @@ -140,17 +152,17 @@ def main() -> None:
raise ValueError(f"tokenizer {tokenizer_name} not found in {REMOTE_ROOT_PREFIX}/manifest.json")

if args.with_docs:
get(f"{REMOTE_ROOT_PREFIX}/docs_selected.jsonl")
get(f"{REMOTE_ROOT_PREFIX}/docs_selected.source_manifest.json")
get(prefixed_remote_path("docs_selected.jsonl"))
get(prefixed_remote_path("docs_selected.source_manifest.json"))

dataset_prefix = f"{REMOTE_ROOT_PREFIX}/datasets/{dataset_dir}"
dataset_prefix = prefixed_remote_path(f"datasets/{dataset_dir}")
for i in range(val_shards):
get(f"{dataset_prefix}/fineweb_val_{i:06d}.bin")
for i in range(train_shards):
get(f"{dataset_prefix}/fineweb_train_{i:06d}.bin")

for artifact_path in artifact_paths_for_tokenizer(tokenizer_entry):
get(f"{REMOTE_ROOT_PREFIX}/{artifact_path}")
get(prefixed_remote_path(artifact_path))


if __name__ == "__main__":
Expand Down
52 changes: 52 additions & 0 deletions tests/test_cached_challenge_fineweb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import importlib.util
import sys
import types
import unittest
from pathlib import Path


# Location of the script under test: data/cached_challenge_fineweb.py, resolved
# from the directory one level above the directory containing this test file.
MODULE_PATH = Path(__file__).resolve().parents[1] / "data" / "cached_challenge_fineweb.py"


def load_module():
    """Load the script at ``MODULE_PATH`` as a fresh module object.

    A stand-in ``huggingface_hub`` module, whose ``hf_hub_download`` accepts
    any arguments and returns ``None``, is placed in ``sys.modules`` while the
    script executes. Whatever was registered under that name beforehand is
    put back afterwards, even if loading fails.
    """
    stub = types.ModuleType("huggingface_hub")
    stub.hf_hub_download = lambda *args, **kwargs: None

    saved = sys.modules.get("huggingface_hub")
    sys.modules["huggingface_hub"] = stub
    try:
        spec = importlib.util.spec_from_file_location(
            "cached_challenge_fineweb_under_test", MODULE_PATH
        )
        loaded = importlib.util.module_from_spec(spec)
        assert spec.loader is not None
        spec.loader.exec_module(loaded)
        return loaded
    finally:
        # Undo the sys.modules substitution so other imports are unaffected.
        if saved is not None:
            sys.modules["huggingface_hub"] = saved
        else:
            sys.modules.pop("huggingface_hub", None)


class CachedChallengeFineWebPathTests(unittest.TestCase):
    """Checks of remote-to-local path mapping under different REMOTE_ROOT_PREFIX values."""

    def test_local_path_for_remote_strips_multi_segment_prefix(self):
        """A two-segment prefix is removed before mapping into DATASETS_DIR."""
        mod = load_module()
        mod.REMOTE_ROOT_PREFIX = "exports/v1"

        remote = "exports/v1/datasets/fineweb10B_sp1024/fineweb_train_000000.bin"
        resolved = mod.local_path_for_remote(remote)
        expected = mod.DATASETS_DIR / "fineweb10B_sp1024" / "fineweb_train_000000.bin"

        self.assertEqual(resolved, expected)

    def test_manifest_path_uses_repo_root_when_remote_prefix_is_empty(self):
        """With an empty prefix the manifest path is ROOT / "manifest.json"."""
        mod = load_module()
        mod.REMOTE_ROOT_PREFIX = ""

        resolved = mod.manifest_path()

        self.assertEqual(resolved, mod.ROOT / "manifest.json")


# Run the test cases when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()