Skip to content

Commit fef4f25

Browse files
committed
remove Tail-Free sampling, ggml-org/llama.cpp#10071
add top_n_sigma, xtc_threshold (float, default 0.1), and xtc_probability (float) params
1 parent d984742 commit fef4f25

File tree

5 files changed

+93
-57
lines changed

5 files changed

+93
-57
lines changed

Diff for: examples/low_level_api/common.py

+23-9
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,9 @@ class GptParams:
2121
ignore_eos: bool = False
2222
logit_bias: dict[int, float] = field(default_factory=dict)
2323
top_k: int = 40
24+
top_n_sigma: float = -1.00
2425
top_p: float = 0.95
25-
tfs_z: float = 1.00
26+
2627
typical_p: float = 1.00
2728
temp: float = 0.80
2829
repeat_penalty: float = 1.10
@@ -32,7 +33,8 @@ class GptParams:
3233
mirostat: int = 0
3334
mirostat_tau: float = 5.0
3435
mirostat_eta: float = 0.1
35-
36+
xtc_threshold: float = 0.1
37+
xtc_probability: float = 0.0
3638
model: str = "./models/llama-7B/ggml-model.bin"
3739
prompt: str = ""
3840
path_session: str = ""
@@ -147,14 +149,10 @@ def gpt_params_parse(argv=None):
147149
"--top_k", type=int, default=40, help="top-k sampling", dest="top_k"
148150
)
149151
parser.add_argument(
150-
"--top_p", type=float, default=0.95, help="top-p sampling", dest="top_p"
152+
"--top_n_sigma", type=float, default=-1.00, help="top-n-sigma sampling (-1.00 = disabled)", dest="top_n_sigma"
151153
)
152154
parser.add_argument(
153-
"--tfs",
154-
type=float,
155-
default=1.0,
156-
help="tail free sampling, parameter z (1.0 = disabled)",
157-
dest="tfs_z",
155+
"--top_p", type=float, default=0.95, help="top-p sampling", dest="top_p"
158156
)
159157
parser.add_argument(
160158
"--temp", type=float, default=0.80, help="temperature", dest="temp"
@@ -178,7 +176,7 @@ def gpt_params_parse(argv=None):
178176
type=float,
179177
default=0.0,
180178
help="repeat alpha frequency penalty (0.0 = disabled)",
181-
dest="tfs_z",
179+
dest="frequency_penalty",
182180
)
183181
parser.add_argument(
184182
"--presence_penalty",
@@ -209,6 +207,22 @@ def gpt_params_parse(argv=None):
209207
dest="mirostat_eta",
210208
)
211209

210+
parser.add_argument(
211+
"--xtc_threshold",
212+
type=float,
213+
default=0.1,
214+
help="Sets a minimum probability threshold for tokens to be removed (default: 0.1)",
215+
dest="xtc_threshold",
216+
)
217+
218+
parser.add_argument(
219+
"--xtc_probability",
220+
type=float,
221+
default=0.0,
222+
help="Sets the chance for token removal (checked once on sampler start) (default: 0.0)",
223+
dest="xtc_probability",
224+
)
225+
212226
parser.add_argument(
213227
"-m",
214228
"--model",

Diff for: examples/low_level_api/low_level_api_chat_cpp.py

+15-12
Original file line numberDiff line numberDiff line change
@@ -275,14 +275,17 @@ def __init__(self, params: GptParams) -> None:
275275
presence_penalty = {self.params.presence_penalty},\
276276
frequency_penalty = {self.params.frequency_penalty},\
277277
top_k = {self.params.top_k},\
278-
tfs_z = {self.params.tfs_z},\
278+
top_n_sigma = {self.params.top_n_sigma},\
279279
top_p = {self.params.top_p},\
280280
typical_p = {self.params.typical_p},\
281281
temp = {self.params.temp},\
282282
mirostat = {self.params.mirostat},\
283283
mirostat_lr = {self.params.mirostat_eta},\
284284
mirostat_ent = {self.params.mirostat_tau},\
285285
286+
xtc_threshold = {self.params.xtc_threshold},\
287+
xtc_probability = {self.params.xtc_probability},\
288+
286289
generate: n_ctx = {self.n_ctx},\
287290
n_batch = {self.params.n_batch},\
288291
n_predict = {self.params.n_predict},\
@@ -454,7 +457,7 @@ def generate(self):
454457
_arr = (llama_cpp.llama_token * last_n_repeat)(
455458
*self.last_n_tokens[len(self.last_n_tokens) - last_n_repeat :]
456459
)
457-
llama_cpp.llama_sample_repetition_penalties(
460+
llama_cpp.llama_sampler_init_penalties(
458461
ctx=self.ctx,
459462
candidates=candidates_p,
460463
last_tokens_data=_arr,
@@ -474,15 +477,15 @@ def generate(self):
474477

475478
if self.params.temp <= 0:
476479
# Greedy sampling
477-
id = llama_cpp.llama_sample_token_greedy(self.ctx, candidates_p)
480+
id = llama_cpp.llama_sampler_init_greedy(self.ctx, candidates_p)
478481
else:
479482
if self.params.mirostat == 1:
480483
mirostat_mu = 2.0 * self.params.mirostat_tau
481484
mirostat_m = 100
482-
llama_cpp.llama_sample_temperature(
485+
llama_cpp.llama_sampler_init_temp(
483486
self.ctx, candidates_p, llama_cpp.c_float(self.params.temp)
484487
)
485-
id = llama_cpp.llama_sample_token_mirostat(
488+
id = llama_cpp.llama_sampler_init_mirostat(
486489
self.ctx,
487490
candidates_p,
488491
llama_cpp.c_float(self.params.mirostat_tau),
@@ -495,7 +498,7 @@ def generate(self):
495498
llama_cpp.llama_sample_temperature(
496499
self.ctx, candidates_p, llama_cpp.c_float(self.params.temp)
497500
)
498-
id = llama_cpp.llama_sample_token_mirostat_v2(
501+
id = llama_cpp.llama_sampler_init_mirostat_v2(
499502
self.ctx,
500503
candidates_p,
501504
llama_cpp.c_float(self.params.mirostat_tau),
@@ -504,31 +507,31 @@ def generate(self):
504507
)
505508
else:
506509
# Temperature sampling
507-
llama_cpp.llama_sample_top_k(
510+
llama_cpp.llama_sampler_init_top_k(
508511
self.ctx,
509512
candidates_p,
510513
top_k,
511514
min_keep=llama_cpp.c_size_t(1),
512515
)
513-
llama_cpp.llama_sample_tail_free(
516+
llama_cpp.llama_sampler_init_top_n_sigma(
514517
self.ctx,
515518
candidates_p,
516-
llama_cpp.c_float(self.params.tfs_z),
519+
llama_cpp.c_float(self.params.top_n_sigma),
517520
min_keep=llama_cpp.c_size_t(1),
518521
)
519-
llama_cpp.llama_sample_typical(
522+
llama_cpp.llama_sampler_init_typical(
520523
self.ctx,
521524
candidates_p,
522525
llama_cpp.c_float(self.params.typical_p),
523526
min_keep=llama_cpp.c_size_t(1),
524527
)
525-
llama_cpp.llama_sample_top_p(
528+
llama_cpp.llama_sampler_init_top_p(
526529
self.ctx,
527530
candidates_p,
528531
llama_cpp.c_float(self.params.top_p),
529532
min_keep=llama_cpp.c_size_t(1),
530533
)
531-
llama_cpp.llama_sample_temperature(
534+
llama_cpp.llama_sampler_init_temp(
532535
self.ctx, candidates_p, llama_cpp.c_float(self.params.temp)
533536
)
534537
id = llama_cpp.llama_sample_token(self.ctx, candidates_p)

Diff for: llama_cpp/_internals.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -570,9 +570,9 @@ class LlamaSamplingParams:
570570
n_prev: int = 64
571571
n_probs: int = 0
572572
top_k: int = 40
573+
top_n_sigma: float = -1.00
573574
top_p: float = 0.95
574575
min_p: float = 0.05
575-
tfs_z: float = 1.00
576576
typical_p: float = 1.00
577577
temp: float = 0.80
578578
penalty_last_n: int = 64

Diff for: llama_cpp/llama.py

-17
Original file line numberDiff line numberDiff line change
@@ -677,7 +677,6 @@ def _init_sampler(
677677
repeat_penalty: float = 1.0,
678678
frequency_penalty: float = 0.0,
679679
presence_penalty: float = 0.0,
680-
tfs_z: float = 1.0,
681680
mirostat_mode: int = 0,
682681
mirostat_eta: float = 0.1,
683682
mirostat_tau: float = 5.0,
@@ -771,7 +770,6 @@ def sample(
771770
repeat_penalty: float = 1.0,
772771
frequency_penalty: float = 0.0,
773772
presence_penalty: float = 0.0,
774-
tfs_z: float = 1.0,
775773
mirostat_mode: int = 0,
776774
mirostat_eta: float = 0.1,
777775
mirostat_tau: float = 5.0,
@@ -809,7 +807,6 @@ def sample(
809807
repeat_penalty=repeat_penalty,
810808
frequency_penalty=frequency_penalty,
811809
presence_penalty=presence_penalty,
812-
tfs_z=tfs_z,
813810
mirostat_mode=mirostat_mode,
814811
mirostat_tau=mirostat_tau,
815812
mirostat_eta=mirostat_eta,
@@ -841,7 +838,6 @@ def generate(
841838
reset: bool = True,
842839
frequency_penalty: float = 0.0,
843840
presence_penalty: float = 0.0,
844-
tfs_z: float = 1.0,
845841
mirostat_mode: int = 0,
846842
mirostat_tau: float = 5.0,
847843
mirostat_eta: float = 0.1,
@@ -883,7 +879,6 @@ def generate(
883879
repeat_penalty=repeat_penalty,
884880
frequency_penalty=frequency_penalty,
885881
presence_penalty=presence_penalty,
886-
tfs_z=tfs_z,
887882
mirostat_mode=mirostat_mode,
888883
mirostat_tau=mirostat_tau,
889884
mirostat_eta=mirostat_eta,
@@ -938,7 +933,6 @@ def generate(
938933
repeat_penalty=repeat_penalty,
939934
frequency_penalty=frequency_penalty,
940935
presence_penalty=presence_penalty,
941-
tfs_z=tfs_z,
942936
mirostat_mode=mirostat_mode,
943937
mirostat_tau=mirostat_tau,
944938
mirostat_eta=mirostat_eta,
@@ -1157,7 +1151,6 @@ def _create_completion(
11571151
top_n_sigma: float = -1.00,
11581152
stream: bool = False,
11591153
seed: Optional[int] = None,
1160-
tfs_z: float = 1.0,
11611154
mirostat_mode: int = 0,
11621155
mirostat_tau: float = 5.0,
11631156
mirostat_eta: float = 0.1,
@@ -1348,7 +1341,6 @@ def logit_bias_processor(
13481341
min_p=min_p,
13491342
typical_p=typical_p,
13501343
temp=temperature,
1351-
tfs_z=tfs_z,
13521344
mirostat_mode=mirostat_mode,
13531345
mirostat_tau=mirostat_tau,
13541346
mirostat_eta=mirostat_eta,
@@ -1783,7 +1775,6 @@ def create_completion(
17831775
top_n_sigma: float = -1.00,
17841776
stream: bool = False,
17851777
seed: Optional[int] = None,
1786-
tfs_z: float = 1.0,
17871778
mirostat_mode: int = 0,
17881779
mirostat_tau: float = 5.0,
17891780
mirostat_eta: float = 0.1,
@@ -1815,7 +1806,6 @@ def create_completion(
18151806
top_n_sigma: Limit the next token selection to a subset of tokens with pre-softmax logits that are within n * σ less than the max logit (default: -1.00, -1.00 = disabled).
18161807
stream: Whether to stream the results.
18171808
seed: The seed to use for sampling.
1818-
tfs_z: The tail-free sampling parameter. Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
18191809
mirostat_mode: The mirostat sampling mode.
18201810
mirostat_tau: The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
18211811
mirostat_eta: The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
@@ -1852,7 +1842,6 @@ def create_completion(
18521842
top_n_sigma=top_n_sigma,
18531843
stream=stream,
18541844
seed=seed,
1855-
tfs_z=tfs_z,
18561845
mirostat_mode=mirostat_mode,
18571846
mirostat_tau=mirostat_tau,
18581847
mirostat_eta=mirostat_eta,
@@ -1889,7 +1878,6 @@ def __call__(
18891878
top_n_sigma: float = -1.00,
18901879
stream: bool = False,
18911880
seed: Optional[int] = None,
1892-
tfs_z: float = 1.0,
18931881
mirostat_mode: int = 0,
18941882
mirostat_tau: float = 5.0,
18951883
mirostat_eta: float = 0.1,
@@ -1921,7 +1909,6 @@ def __call__(
19211909
top_n_sigma: Limit the next token selection to a subset of tokens with pre-softmax logits that are within n * σ less than the max logit (default: -1.00, -1.00 = disabled).
19221910
stream: Whether to stream the results.
19231911
seed: The seed to use for sampling.
1924-
tfs_z: The tail-free sampling parameter. Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
19251912
mirostat_mode: The mirostat sampling mode.
19261913
mirostat_tau: The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
19271914
mirostat_eta: The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
@@ -1958,7 +1945,6 @@ def __call__(
19581945
top_n_sigma=top_n_sigma,
19591946
stream=stream,
19601947
seed=seed,
1961-
tfs_z=tfs_z,
19621948
mirostat_mode=mirostat_mode,
19631949
mirostat_tau=mirostat_tau,
19641950
mirostat_eta=mirostat_eta,
@@ -1992,7 +1978,6 @@ def create_chat_completion(
19921978
presence_penalty: float = 0.0,
19931979
frequency_penalty: float = 0.0,
19941980
repeat_penalty: float = 1.0,
1995-
tfs_z: float = 1.0,
19961981
mirostat_mode: int = 0,
19971982
mirostat_tau: float = 5.0,
19981983
mirostat_eta: float = 0.1,
@@ -2029,7 +2014,6 @@ def create_chat_completion(
20292014
presence_penalty: The penalty to apply to tokens based on their presence in the prompt.
20302015
frequency_penalty: The penalty to apply to tokens based on their frequency in the prompt.
20312016
repeat_penalty: The penalty to apply to repeated tokens.
2032-
tfs_z: The tail-free sampling parameter.
20332017
mirostat_mode: The mirostat sampling mode.
20342018
mirostat_tau: The mirostat sampling tau parameter.
20352019
mirostat_eta: The mirostat sampling eta parameter.
@@ -2071,7 +2055,6 @@ def create_chat_completion(
20712055
presence_penalty=presence_penalty,
20722056
frequency_penalty=frequency_penalty,
20732057
repeat_penalty=repeat_penalty,
2074-
tfs_z=tfs_z,
20752058
mirostat_mode=mirostat_mode,
20762059
mirostat_tau=mirostat_tau,
20772060
mirostat_eta=mirostat_eta,

0 commit comments

Comments
 (0)