diff --git a/data/cached_challenge_fineweb.py b/data/cached_challenge_fineweb.py index fa8029be42..3c13b399e8 100644 --- a/data/cached_challenge_fineweb.py +++ b/data/cached_challenge_fineweb.py @@ -13,6 +13,17 @@ DATASETS_DIR = ROOT / "datasets" TOKENIZERS_DIR = ROOT / "tokenizers" + +def prefixed_remote_path(relative_path: str) -> str: + relative_path = relative_path.lstrip("/") + prefix = REMOTE_ROOT_PREFIX.strip("/") + return f"{prefix}/{relative_path}" if prefix else relative_path + + +def remote_root_parts() -> tuple[str, ...]: + prefix = REMOTE_ROOT_PREFIX.strip("/") + return Path(prefix).parts if prefix else () + def dataset_dir_for_variant(name: str) -> str: if name == "byte260": return "fineweb10B_byte260" @@ -22,9 +33,10 @@ def dataset_dir_for_variant(name: str) -> str: def local_path_for_remote(relative_path: str) -> Path: - remote_path = Path(relative_path) - if REMOTE_ROOT_PREFIX and remote_path.parts[:1] == (REMOTE_ROOT_PREFIX,): - remote_path = remote_path.relative_to(REMOTE_ROOT_PREFIX) + remote_path = Path(relative_path.lstrip("/")) + prefix_parts = remote_root_parts() + if prefix_parts and remote_path.parts[: len(prefix_parts)] == prefix_parts: + remote_path = Path(*remote_path.parts[len(prefix_parts) :]) if remote_path.parts[:1] == ("datasets",): return DATASETS_DIR.joinpath(*remote_path.parts[1:]) if remote_path.parts[:1] == ("tokenizers",): @@ -59,7 +71,7 @@ def get(relative_path: str) -> None: def manifest_path() -> Path: - return local_path_for_remote(f"{REMOTE_ROOT_PREFIX}/manifest.json") + return local_path_for_remote(prefixed_remote_path("manifest.json")) def load_manifest(*, skip_manifest_download: bool) -> dict: @@ -69,7 +81,7 @@ def load_manifest(*, skip_manifest_download: bool) -> dict: raise FileNotFoundError( f"manifest.json is required for manifest-driven shard counts but is not present locally at {path}" ) - get(f"{REMOTE_ROOT_PREFIX}/manifest.json") + get(prefixed_remote_path("manifest.json")) return json.loads(path.read_text(encoding="utf-8")) @@ -140,17 +152,17 @@ def main() -> None: raise ValueError(f"tokenizer {tokenizer_name} not found in {REMOTE_ROOT_PREFIX}/manifest.json") if args.with_docs: - get(f"{REMOTE_ROOT_PREFIX}/docs_selected.jsonl") - get(f"{REMOTE_ROOT_PREFIX}/docs_selected.source_manifest.json") + get(prefixed_remote_path("docs_selected.jsonl")) + get(prefixed_remote_path("docs_selected.source_manifest.json")) - dataset_prefix = f"{REMOTE_ROOT_PREFIX}/datasets/{dataset_dir}" + dataset_prefix = prefixed_remote_path(f"datasets/{dataset_dir}") for i in range(val_shards): get(f"{dataset_prefix}/fineweb_val_{i:06d}.bin") for i in range(train_shards): get(f"{dataset_prefix}/fineweb_train_{i:06d}.bin") for artifact_path in artifact_paths_for_tokenizer(tokenizer_entry): - get(f"{REMOTE_ROOT_PREFIX}/{artifact_path}") + get(prefixed_remote_path(artifact_path)) if __name__ == "__main__": diff --git a/tests/test_cached_challenge_fineweb.py b/tests/test_cached_challenge_fineweb.py new file mode 100644 index 0000000000..e7a31a7690 --- /dev/null +++ b/tests/test_cached_challenge_fineweb.py @@ -0,0 +1,52 @@ +import importlib.util +import sys +import types +import unittest +from pathlib import Path + + +MODULE_PATH = Path(__file__).resolve().parents[1] / "data" / "cached_challenge_fineweb.py" + + +def load_module(): + fake_hf = types.ModuleType("huggingface_hub") + fake_hf.hf_hub_download = lambda *args, **kwargs: None + + previous = sys.modules.get("huggingface_hub") + sys.modules["huggingface_hub"] = fake_hf + try: + spec = importlib.util.spec_from_file_location("cached_challenge_fineweb_under_test", MODULE_PATH) + module = importlib.util.module_from_spec(spec) + assert spec.loader is not None + spec.loader.exec_module(module) + return module + finally: + if previous is None: + sys.modules.pop("huggingface_hub", None) + else: + sys.modules["huggingface_hub"] = previous + + +class CachedChallengeFineWebPathTests(unittest.TestCase): + def test_local_path_for_remote_strips_multi_segment_prefix(self): + module = load_module() + module.REMOTE_ROOT_PREFIX = "exports/v1" + + path = module.local_path_for_remote( + "exports/v1/datasets/fineweb10B_sp1024/fineweb_train_000000.bin" + ) + + self.assertEqual( + path, + module.DATASETS_DIR / "fineweb10B_sp1024" / "fineweb_train_000000.bin", + ) + + def test_manifest_path_uses_repo_root_when_remote_prefix_is_empty(self): + module = load_module() + module.REMOTE_ROOT_PREFIX = "" + + self.assertEqual(module.manifest_path(), module.ROOT / "manifest.json") + + +if __name__ == "__main__": + unittest.main()