From 8d73845d44a822e92c0ce84784438d0ceb413b0d Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Tue, 6 May 2025 10:01:05 -0700 Subject: [PATCH] sampling: Don't consider -infinity values in top_n_sigma --- src/llama-sampling.cpp | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 0c9c6a3102929..2869f60d204a1 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -1757,20 +1757,28 @@ static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_t // find max logit and calculate mean float max = cur_p->data[0].logit; float logits_sum = 0; + size_t valid_count = 0; for (size_t i = 0; i < cur_p->size; ++i) { - if (cur_p->data[i].logit > max) { - max = cur_p->data[i].logit; + // Only count non-negative infinity values + if (cur_p->data[i].logit != -INFINITY) { + if (cur_p->data[i].logit > max) { + max = cur_p->data[i].logit; + } + logits_sum += cur_p->data[i].logit; + valid_count++; } - logits_sum += cur_p->data[i].logit; } - float mean = logits_sum/cur_p->size; + float mean = valid_count > 0 ? logits_sum/valid_count : 0; // calculate standard deviation float acc = 0; for (size_t i = 0; i < cur_p->size; ++i) { - acc += pow(cur_p->data[i].logit - mean, 2); + // Skip -infinity in std calculation + if (cur_p->data[i].logit != -INFINITY) { + acc += pow(cur_p->data[i].logit - mean, 2); + } } - float std = sqrt(acc/cur_p->size); + float std = valid_count > 0 ? sqrt(acc/valid_count) : 0; //apply mask for (size_t i = 0; i < cur_p->size; ++i) {