Skip to content

Commit 1792715

Browse files
author
Fahad Alghanim
committed
Fix unstable tokenizer fingerprinting
Tokenizers backed by `tokenizers` can mutate truncation/padding state when called, which made dataset transform fingerprints unstable and prevented `.map(load_from_cache_file=True)` from reusing cached results. This change makes tokenizer hashing stable by temporarily clearing backend truncation/padding during serialization for fingerprinting, then restoring it. Add a regression test and a simple benchmark to demonstrate cache-hit speedups. Fixes #3847
1 parent 224b4e6 commit 1792715

File tree

3 files changed

+133
-3
lines changed

3 files changed

+133
-3
lines changed
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import json
2+
import os
3+
import tempfile
4+
5+
from tokenizers import Tokenizer
6+
from tokenizers.models import WordLevel
7+
from tokenizers.pre_tokenizers import Whitespace
8+
from transformers import PreTrainedTokenizerFast
9+
10+
import datasets
11+
from utils import get_duration
12+
13+
14+
# Benchmark results are written next to this script, under ./results/<script_name>.json
RESULTS_BASEPATH, RESULTS_FILENAME = os.path.split(__file__)
RESULTS_FILE_PATH = os.path.join(RESULTS_BASEPATH, "results", RESULTS_FILENAME.replace(".py", ".json"))
16+
17+
18+
def _make_tokenizer() -> PreTrainedTokenizerFast:
    """Build a tiny word-level fast tokenizer entirely in memory (no network access)."""
    word_to_id = {"[UNK]": 0, "[PAD]": 1, "hello": 2, "world": 3}
    core = Tokenizer(WordLevel(vocab=word_to_id, unk_token="[UNK]"))
    core.pre_tokenizer = Whitespace()
    return PreTrainedTokenizerFast(tokenizer_object=core, unk_token="[UNK]", pad_token="[PAD]")
23+
24+
25+
@get_duration
def map_once(dataset: datasets.Dataset, tok: PreTrainedTokenizerFast):
    """Run one tokenizing `.map()` pass over `dataset`; @get_duration reports the elapsed time."""

    def tokenize(examples):
        return tok(examples["text"], truncation=True, padding="max_length", max_length=8)

    dataset.map(tokenize, batched=True, load_from_cache_file=True, remove_columns=["text"])
31+
32+
33+
def benchmark_map_cache_reuse():
    """Time two identical `.map()` passes: a cache miss, then what should be a cache hit.

    The second pass is only fast if the tokenizer's fingerprint is stable after
    it has been called once (i.e. calling it does not change its hash).
    """
    tok = _make_tokenizer()
    times = {}

    with tempfile.TemporaryDirectory() as tmp_dir:
        source = datasets.Dataset.from_dict({"text": ["hello world"] * 200_000})
        dataset_dir = os.path.join(tmp_dir, "stored")
        source.save_to_disk(dataset_dir)
        dataset = datasets.Dataset.load_from_disk(dataset_dir)

        # First pass writes the cache file.
        times["map tokenize (cache miss)"] = map_once(dataset, tok)
        # Second pass should reuse it — much faster when the fingerprint is stable.
        times["map tokenize (cache hit)"] = map_once(dataset, tok)

    with open(RESULTS_FILE_PATH, "wb") as f:
        f.write(json.dumps(times).encode("utf-8"))
50+
51+
52+
# Allow running the benchmark standalone: `python <this_script>.py`
if __name__ == "__main__":
    benchmark_map_cache_reuse()

src/datasets/utils/_dill.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -212,11 +212,42 @@ def create_spacyLanguage(config, bytes):
212212

213213
def _save_transformersPreTrainedTokenizerBase(pickler, obj):
    """Pickle a `transformers` tokenizer with stable bytes for fingerprinting.

    Two sources of hash instability are neutralized:
    - the `cache` attribute is replaced by an empty dict in the pickled state, and
    - backend (`tokenizers.Tokenizer`) truncation/padding settings are temporarily
      cleared, because calling the tokenizer can enable them and thereby change
      the serialized bytes across runs, which prevents
      `.map(load_from_cache_file=True)` from reusing cache files.

    The original backend settings are restored after serialization.
    """
    log(pickler, f"Tok: {obj}")
    # Copy the state so clearing the cache does not mutate the live tokenizer.
    state = obj.__dict__.copy()
    if "cache" in state and isinstance(state["cache"], dict):
        state["cache"] = {}

    backend_tokenizer = obj.__dict__.get("_tokenizer")
    truncation = padding = None
    if backend_tokenizer is not None and hasattr(backend_tokenizer, "truncation") and hasattr(backend_tokenizer, "padding"):
        truncation = backend_tokenizer.truncation
        padding = backend_tokenizer.padding

    # Clear each setting independently (best effort): a failure on one must not
    # stop the other from being cleared, and — unlike the previous version,
    # which nulled both saved settings on any exception — must not discard the
    # saved values needed to restore a setting that WAS successfully cleared.
    if truncation is not None and hasattr(backend_tokenizer, "no_truncation"):
        try:
            backend_tokenizer.no_truncation()
        except Exception:
            pass
    if padding is not None and hasattr(backend_tokenizer, "no_padding"):
        try:
            backend_tokenizer.no_padding()
        except Exception:
            pass

    try:
        # `state` still references the (now-cleared) backend tokenizer object,
        # so the pickled bytes no longer depend on runtime truncation/padding.
        pickler.save_reduce(type(obj), (), state=state, obj=obj)
    finally:
        # Restore each setting independently so a failed restore of one does
        # not prevent restoring the other.
        if backend_tokenizer is not None:
            if truncation is not None and hasattr(backend_tokenizer, "enable_truncation"):
                try:
                    backend_tokenizer.enable_truncation(**truncation)
                except Exception:
                    pass
            if padding is not None and hasattr(backend_tokenizer, "enable_padding"):
                try:
                    backend_tokenizer.enable_padding(**padding)
                except Exception:
                    pass
    log(pickler, "# Tok")
221252

222253

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import pytest
2+
from tokenizers import Tokenizer
3+
from tokenizers.models import WordLevel
4+
from tokenizers.pre_tokenizers import Whitespace
5+
from transformers import PreTrainedTokenizerFast
6+
7+
from datasets import Dataset
8+
from datasets.fingerprint import Hasher
9+
10+
11+
def _make_mutable_backend_tokenizer() -> PreTrainedTokenizerFast:
    """Return a tiny fast tokenizer built fully in memory, backed by `tokenizers.Tokenizer`.

    No network access is required; the backend's truncation/padding state can be
    mutated by calling the tokenizer, which is exactly what these tests exercise.
    """
    word_to_id = {"[UNK]": 0, "[PAD]": 1, "hello": 2, "world": 3}
    core = Tokenizer(WordLevel(vocab=word_to_id, unk_token="[UNK]"))
    core.pre_tokenizer = Whitespace()
    return PreTrainedTokenizerFast(tokenizer_object=core, unk_token="[UNK]", pad_token="[PAD]")
17+
18+
19+
def test_hasher_hash_tokenizer_stable_after_call():
    """Calling a tokenizer (enabling truncation/padding) must not change its hash."""
    tokenizer = _make_mutable_backend_tokenizer()
    before = Hasher.hash(tokenizer)
    tokenizer(["hello world"], truncation=True, padding="max_length", max_length=8)
    after = Hasher.hash(tokenizer)
    assert before == after
25+
26+
27+
def test_map_cache_reused_with_tokenizer_after_call(tmp_path):
    """Regression test for https://github.com/huggingface/datasets/issues/3847.

    Tokenizers can mutate backend truncation/padding state when called, which
    used to destabilize the transform fingerprint and defeat `.map()` cache reuse.
    """
    tokenizer = _make_mutable_backend_tokenizer()

    dataset = Dataset.from_dict({"text": ["hello world"] * 1000})
    on_disk = tmp_path / "stored"
    dataset.save_to_disk(on_disk)
    dataset = Dataset.load_from_disk(on_disk)

    def tokenize(examples):
        return tokenizer(examples["text"], truncation=True, padding="max_length", max_length=8)

    first = dataset.map(tokenize, batched=True, load_from_cache_file=True, remove_columns=["text"])
    second = dataset.map(tokenize, batched=True, load_from_cache_file=True, remove_columns=["text"])

    # Both runs must produce cache files, and the second must reuse the first's.
    assert first.cache_files and second.cache_files
    assert first.cache_files[0]["filename"] == second.cache_files[0]["filename"]

0 commit comments

Comments
 (0)