
Commit 47cc4da

Changing the test model in Quanto kv cache (#36670)
changing model
1 parent bc3d578 commit 47cc4da

File tree

1 file changed: +6 additions, -4 deletions


tests/quantization/quanto_integration/test_quanto.py

Lines changed: 6 additions & 4 deletions
@@ -448,17 +448,19 @@ class QuantoKVCacheQuantizationTest(unittest.TestCase):
     @require_read_token
     def test_quantized_cache(self):
         EXPECTED_TEXT_COMPLETION = [
-            "Simply put, the theory of relativity states that 1) the speed of light is the same for all observers, and 2) the laws of physics are the same for all observers.\nThe first part of the theory is the most",
-            "My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p",
+            "Simply put, the theory of relativity states that 1) time and space are not absolute, but are relative to the observer, and 2) the laws of physics are the same everywhere in the universe. This means that the speed of light is",
+            "My favorite all time favorite condiment is ketchup. I love how it adds a sweet and tangy flavor to my food. I also enjoy using it as a dip for fries, burgers, and grilled meats. It's a classic condiment that never",
         ]

         prompts = [
             "Simply put, the theory of relativity states that ",
             "My favorite all time favorite condiment is ketchup.",
         ]
-        tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token="</s>", padding_side="left")
+        tokenizer = LlamaTokenizer.from_pretrained(
+            "unsloth/Llama-3.2-1B-Instruct", pad_token="</s>", padding_side="left"
+        )
         model = LlamaForCausalLM.from_pretrained(
-            "meta-llama/Llama-2-7b-hf", device_map="sequential", torch_dtype=torch.float16
+            "unsloth/Llama-3.2-1B-Instruct", device_map="sequential", torch_dtype=torch.float16
         )
         inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(torch_device)
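For context, the test above exercises the Quanto-backed quantized KV cache in Transformers, which stores cached key/value tensors in low precision to save memory. The sketch below illustrates the general idea behind such a cache with a symmetric per-tensor int8 round trip in pure Python; it is an illustrative simplification under assumed parameters (per-tensor scale, 8-bit symmetric range), not Quanto's actual implementation.

```python
def quantize_int8(values):
    """Symmetric per-tensor int8 quantization: store int8 codes plus one float scale."""
    max_abs = max(abs(v) for v in values)
    scale = max_abs / 127 if max_abs > 0 else 1.0  # avoid division by zero for all-zero tensors
    codes = [round(v / scale) for v in values]  # each code fits in [-127, 127]
    return codes, scale


def dequantize_int8(codes, scale):
    """Recover approximate float values from int8 codes."""
    return [c * scale for c in codes]


# A toy "cache entry": quantize on write, dequantize on read.
kv = [0.12, -0.5, 0.33, 0.0]
codes, scale = quantize_int8(kv)
restored = dequantize_int8(codes, scale)

# Round-trip error is bounded by one quantization step (the scale).
assert all(abs(a - b) <= scale for a, b in zip(kv, restored))
```

The memory saving comes from storing one byte per element plus a single scale, instead of a 16- or 32-bit float per element; the price is the bounded round-trip error asserted above.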
