diff --git a/optimum/gptq/data.py b/optimum/gptq/data.py index 127e6676cd..baf6ee2397 100644 --- a/optimum/gptq/data.py +++ b/optimum/gptq/data.py @@ -125,17 +125,19 @@ def get_wikitext2(tokenizer: Any, seqlen: int, nsamples: int, split: str = "trai data = load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1", split="train") elif split == "validation": data = load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1", split="test") - # length of 288059 should be enough - text = "".join([" \n" if s == "" else s for s in data["text"][:1000]]) - - enc = tokenizer(text, return_tensors="pt") dataset = [] for _ in range(nsamples): + while True: + i = random.randint(0, len(data) - 1) + enc = tokenizer(data[i]["text"], return_tensors="pt") + if enc.input_ids.shape[1] >= seqlen: + break i = random.randint(0, enc.input_ids.shape[1] - seqlen - 1) j = i + seqlen inp = enc.input_ids[:, i:j] attention_mask = torch.ones_like(inp) dataset.append({"input_ids": inp, "attention_mask": attention_mask}) + return dataset