from tokenizer import ExLlamaTokenizer
from generator import ExLlamaGenerator
import torch
+ import torch.nn.functional as F
import os, glob
import cuda_ext

config = ExLlamaConfig(model_config_path)               # create config from config.json
config.model_path = model_path                          # supply path to model weights file
- config.max_input_len = 16

model = ExLlama(config)                                 # create ExLlama instance and load the weights
tokenizer = ExLlamaTokenizer(tokenizer_path)            # create tokenizer from tokenizer model file
# Configure generator

generator.settings.token_repetition_penalty_max = 1.15
- generator.settings.temperature = 0.75
+ generator.settings.temperature = 0.95
generator.settings.top_k = 40
- generator.settings.top_p = 0.65
- # generator.settings.typical = 0.5
+ generator.settings.top_p = 0.75
+ # generator.settings.typical = 0.95
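# (These sampler settings come into play when sample_current() draws from the
# mixed distribution in the loop below; the repetition penalty is applied to
# both batch rows before mixing.)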

# Prompts to mix

f2 = \
"""[INST] <<SYS>>
- You are a rude and obnoxious assistant. You hate everything and everyone.
<</SYS>>
+ You are a rude and obnoxious assistant. You hate everything and everyone.
{prompt}[/INST]"""

prompts = \
[
    f1.replace("{prompt}", "Tell me about Homer Simpson"),
    f2.replace("{prompt}", "Tell me about Homer Simpson"),
]
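
# (f1, defined earlier in the file, carries the regular persona; f2 carries the
# contrasting one. Both rows ask the same question, so the only difference the
# mixing below can amplify or suppress is the system prompt.)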

- def mixed_generation(prompts, alpha, max_new_tokens):
+ def generate_cfg(prompts, alpha, max_new_tokens):

    ids, mask = tokenizer.encode(prompts, return_mask = True)
    generator.gen_begin(ids, mask = mask)
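    # (The two prompts differ in length, so encode() also returns a mask over
    # the padded batch; the same mask is passed to every forward pass below.)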

    # Sampling loop

-     for i in range(max_new_tokens):
+     for _ in range(max_new_tokens):

        logits = model.forward(generator.sequence[:, -1:], cache, input_mask = mask)
        generator.apply_rep_penalty(logits)

+         logits = F.log_softmax(logits, dim = -1)
        logits_mixed = (1 - alpha) * logits[0] + alpha * logits[1]
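        # Linear interpolation in log space between the two rows implements
        # classifier-free guidance: alpha = 0 follows f1 alone, alpha = 1
        # follows f2 alone, and values outside [0, 1] extrapolate past one
        # prompt and away from the other.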

        sampled_token, _ = generator.sample_current(logits_mixed)
@@ -86,5 +88,5 @@ def mixed_generation(prompts, alpha, max_new_tokens):
    print(f"--------------------------------------")
    print(f"alpha = {alpha:.1f}")
    print(f"--------------------------------------")
-     output = mixed_generation(prompts, alpha, 200)
+     output = generate_cfg(prompts, alpha, 200)
    print(output[len(prompts[0]):].strip())
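
# For reference, a minimal standalone sketch of the mixing step above, with
# random dummy logits standing in for the model output (shapes and names here
# are illustrative assumptions, not part of the example):

import torch
import torch.nn.functional as F

dummy_logits = torch.randn(2, 1, 32000)          # two batch rows, one position, vocab
logp = F.log_softmax(dummy_logits, dim = -1)     # normalize each row in log space
alpha = 0.5
mixed = (1 - alpha) * logp[0] + alpha * logp[1]  # same interpolation as generate_cfg()
next_token = torch.argmax(mixed, dim = -1)       # greedy stand-in for sample_current()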