diff --git a/LLama.Examples/Examples/BatchedDecoding.cs b/LLama.Examples/Examples/BatchedDecoding.cs
index 1d55ff12f..37a022015 100644
--- a/LLama.Examples/Examples/BatchedDecoding.cs
+++ b/LLama.Examples/Examples/BatchedDecoding.cs
@@ -105,12 +105,7 @@ public static async Task Run()
                 if (i_batch[i] < 0)
                     continue;
 
-                var n_vocab = model.VocabCount;
-                LLamaTokenDataArray candidates;
-                unsafe
-                {
-                    candidates = LLamaTokenDataArray.Create(new Span<float>(NativeApi.llama_get_logits_ith(context.NativeHandle, i_batch[i]), n_vocab));
-                }
+                var candidates = LLamaTokenDataArray.Create(context.NativeHandle.GetLogitsIth(i_batch[i]));
 
                 candidates.TopK(context.NativeHandle, top_k);
                 candidates.TopP(context.NativeHandle, top_p);
diff --git a/LLama.Unittest/StatelessExecutorTest.cs b/LLama.Unittest/StatelessExecutorTest.cs
index 8d4be20cb..cfe499734 100644
--- a/LLama.Unittest/StatelessExecutorTest.cs
+++ b/LLama.Unittest/StatelessExecutorTest.cs
@@ -19,6 +19,7 @@ public StatelessExecutorTest(ITestOutputHelper testOutputHelper)
         {
             ContextSize = 60,
             Seed = 1754,
+            BatchSize = 2,
         };
         _weights = LLamaWeights.LoadFromFile(_params);
     }
@@ -60,7 +61,7 @@ public async Task OutOfContext()
     {
         var executor = new StatelessExecutor(_weights, _params);
 
-        const string question = " Question. cats or dogs?\nAnswer: ";
+        const string question = " Question. cats or dogs?\nAnswer:";
 
         // The context size is set to 60. Generate more than that, forcing it to generate a coherent response
         // with a modified context
diff --git a/LLama/Exceptions/RuntimeError.cs b/LLama/Exceptions/RuntimeError.cs
index c56d78ffc..0feb53665 100644
--- a/LLama/Exceptions/RuntimeError.cs
+++ b/LLama/Exceptions/RuntimeError.cs
@@ -1,4 +1,5 @@
 using System;
+using LLama.Native;
 
 namespace LLama.Exceptions;
 
@@ -36,4 +37,23 @@ public LoadWeightsFailedException(string modelPath)
     {
         ModelPath = modelPath;
     }
+}
+
+/// <summary>
+/// `llama_decode` returned a non-zero status code
+/// </summary>
+public class LLamaDecodeError
+    : RuntimeError
+{
+    /// <summary>
+    /// The return status code
+    /// </summary>
+    public DecodeResult ReturnCode { get; }
+
+    /// <inheritdoc />
+    public LLamaDecodeError(DecodeResult returnCode)
+        : base($"llama_decode failed: '{returnCode}'")
+    {
+        ReturnCode = returnCode;
+    }
 }
\ No newline at end of file
diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs
index ea745d029..33c8d7260 100644
--- a/LLama/LLamaContext.cs
+++ b/LLama/LLamaContext.cs
@@ -293,6 +293,7 @@ public LLamaToken Sample(LLamaTokenDataArray candidates, ref float? mirostat_mu,
     /// <summary>
     /// Apply the penalty for the tokens. Please don't use it unless you fully know what it does.
     /// </summary>
+    /// <param name="logits_i"></param>
     /// <param name="lastTokens"></param>
     /// <param name="logitBias"></param>
     /// <param name="repeatLastTokensCount"></param>
@@ -301,11 +302,11 @@ public LLamaToken Sample(LLamaTokenDataArray candidates, ref float? mirostat_mu,
     /// <param name="alphaPresence"></param>
     /// <param name="penalizeNL"></param>
     /// <returns></returns>
-    public LLamaTokenDataArray ApplyPenalty(IEnumerable<LLamaToken> lastTokens, Dictionary<LLamaToken, float>? logitBias = null,
-        int repeatLastTokensCount = 64, float repeatPenalty = 1.1f, float alphaFrequency = .0f, float alphaPresence = .0f,
-        bool penalizeNL = true)
+    public LLamaTokenDataArray ApplyPenalty(int logits_i, IEnumerable<LLamaToken> lastTokens, Dictionary<LLamaToken, float>? logitBias = null,
+        int repeatLastTokensCount = 64, float repeatPenalty = 1.1f, float alphaFrequency = .0f, float alphaPresence = .0f,
+        bool penalizeNL = true)
     {
-        var logits = NativeHandle.GetLogits();
+        var logits = NativeHandle.GetLogitsIth(logits_i);
 
         // Apply params.logit_bias map
         if (logitBias is not null)
@@ -348,28 +349,23 @@ public LLamaTokenDataArray ApplyPenalty(IEnumerable<LLamaToken> lastTokens, Dict
     /// <summary>
     /// </summary>
     /// <param name="batch"></param>
-    /// <returns>Positive return values does not mean a fatal error, but rather a warning:<br/>
-    ///  - 0: success<br/>
-    ///  - 1: could not find a KV slot for the batch (try reducing the size of the batch or increase the context)<br/>
-    ///  - &lt; 0: error<br/>
-    /// </returns>
-    public int Decode(LLamaBatch batch)
+    public DecodeResult Decode(LLamaBatch batch)
     {
-        return NativeHandle.Decode(batch);
+        if (batch.TokenCount == 0)
+            return 0;
+        if (batch.TokenCount > Params.BatchSize)
+            throw new ArgumentException("Input contains more tokens than configured batch size", nameof(batch));
+
+        return (DecodeResult)NativeHandle.Decode(batch);
     }
 
     /// <summary>
     /// </summary>
     /// <param name="batch"></param>
     /// <param name="cancellationToken"></param>
-    /// <returns>Positive return values does not mean a fatal error, but rather a warning:<br/>
-    ///  - 0: success<br/>
-    ///  - 1: could not find a KV slot for the batch (try reducing the size of the batch or increase the context)<br/>
-    ///  - &lt; 0: error<br/>
-    /// </returns>
-    public Task<int> DecodeAsync(LLamaBatch batch, CancellationToken cancellationToken = default)
+    public Task<DecodeResult> DecodeAsync(LLamaBatch batch, CancellationToken cancellationToken = default)
     {
-        return Task.Run(() => NativeHandle.Decode(batch), cancellationToken);
+        return Task.Run(() => Decode(batch), cancellationToken);
     }
 
     /// <summary>
diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs
index b763145eb..993019f18 100644
--- a/LLama/LLamaInstructExecutor.cs
+++ b/LLama/LLamaInstructExecutor.cs
@@ -216,7 +216,7 @@ protected override Task InferInternal(IInferenceParams inferenceParams, InferSta
                 }
                 else
                 {
-                    var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
+                    var tokenDataArray = Context.ApplyPenalty(0, _last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
                         inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty,
                         inferenceParams.PenalizeNL);
                     var mu = MirostatMu;
diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs
index 11973a273..2e72c7ae8 100644
--- a/LLama/LLamaInteractExecutor.cs
+++ b/LLama/LLamaInteractExecutor.cs
@@ -195,7 +195,7 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In
                 }
                 else
                 {
-                    var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
+                    var tokenDataArray = Context.ApplyPenalty(0, _last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
                         inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty,
                         inferenceParams.PenalizeNL);
                     var mu = MirostatMu;
diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs
index e03fe7a1c..0587f148c 100644
--- a/LLama/LLamaStatelessExecutor.cs
+++ b/LLama/LLamaStatelessExecutor.cs
@@ -5,7 +5,7 @@
 using System.Linq;
 using System.Runtime.CompilerServices;
 using System.Threading;
-using System.Threading.Tasks;
+using LLama.Exceptions;
 using LLama.Native;
 using LLama.Sampling;
 using Microsoft.Extensions.Logging;
@@ -22,6 +22,7 @@ public class StatelessExecutor
     private readonly LLamaWeights _weights;
     private readonly IContextParams _params;
     private readonly ILogger? _logger;
+    private readonly LLamaBatch _batch;
     /// <summary>
     /// The context used by the executor when running the inference.
     /// </summary>
@@ -39,6 +40,7 @@ public StatelessExecutor(LLamaWeights weights, IContextParams @params, ILogger?
         _weights = weights;
         _params = @params;
         _logger = logger;
+        _batch = new LLamaBatch(1);
 
         Context = _weights.CreateContext(_params, logger);
         Context.Dispose();
@@ -71,16 +73,29 @@ public async IAsyncEnumerable<string> InferAsync(string prompt, IInferenceParams
         var repeat_last_n = Math.Max(0, inferenceParams.RepeatLastTokensCount < 0 ? _weights.ContextSize : inferenceParams.RepeatLastTokensCount);
         var lastTokens = new List<LLamaToken>(repeat_last_n);
         for (var i = 0; i < repeat_last_n; i++)
-            lastTokens.Add((LLamaToken)0);
+            lastTokens.Add(0);
 
         // Tokenize the prompt
         var tokens = Context.Tokenize(prompt).ToList();
         lastTokens.AddRange(tokens);
-        var n_past = 1 + tokens.Count;
 
-        // Evaluate the prompt
-        await Task.Run(() => { Context.Eval(tokens, 1); }, cancellationToken)
-            .ConfigureAwait(false);
+        // Evaluate the prompt, in chunks smaller than the max batch size
+        var n_past = 0;
+        var batchSize = (int)Context.Params.BatchSize;
+        for (var i = 0; i < tokens.Count; i += batchSize)
+        {
+            var n_eval = tokens.Count - i;
+            if (n_eval > batchSize)
+                n_eval = batchSize;
+
+            _batch.Clear();
+            for (var j = 0; j < n_eval; j++)
+                _batch.Add(tokens[i + j], n_past++, LLamaSeqId.Zero, (i + j) == tokens.Count - 1);
+
+            var returnCode = await Context.DecodeAsync(_batch, cancellationToken);
+            if (returnCode != 0)
+                throw new LLamaDecodeError(returnCode);
+        }
 
         // Begin loop, evaluating one token at a time
         var mu = (float?)null;
@@ -90,12 +105,12 @@ public async IAsyncEnumerable<string> InferAsync(string prompt, IInferenceParams
             LLamaToken id;
             if (inferenceParams.SamplingPipeline is not null)
             {
-                id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), lastTokens);
+                id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogitsIth(_batch.TokenCount - 1), lastTokens);
             }
             else
             {
                 // Penalize the generated tokens by various penalties
-                var tokenDataArray = Context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n,
+                var tokenDataArray = Context.ApplyPenalty(_batch.TokenCount - 1, lastTokens, inferenceParams.LogitBias, repeat_last_n,
                     inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty,
                     inferenceParams.PenalizeNL);
                 // Sample a single token
@@ -136,9 +151,12 @@ public async IAsyncEnumerable<string> InferAsync(string prompt, IInferenceParams
                     n_past -= n_discard;
                 }
 
-                // ReSharper disable once AccessToModifiedClosure (Justification: n_past is modified inside and outside the capture, but not concurrently)
-                n_past = await Task.Run(() => Context.Eval(tokens, n_past), cancellationToken)
-                    .ConfigureAwait(false);
+                // Evaluate with this new token
+                _batch.Clear();
+                _batch.Add(id, n_past++, LLamaSeqId.Zero, true);
+                var returnCode = await context.DecodeAsync(_batch, cancellationToken);
+                if (returnCode != 0)
+                    throw new LLamaDecodeError(returnCode);
             }
         }
     }
diff --git a/LLama/Native/DecodeResult.cs b/LLama/Native/DecodeResult.cs
new file mode 100644
index 000000000..61056dd9d
--- /dev/null
+++ b/LLama/Native/DecodeResult.cs
@@ -0,0 +1,22 @@
+namespace LLama.Native;
+
+/// <summary>
+/// Return codes from llama_decode
+/// </summary>
+public enum DecodeResult
+{
+    /// <summary>
+    /// An unspecified error
+    /// </summary>
+    Error = -1,
+
+    /// <summary>
+    /// Ok.
+    /// </summary>
+    Ok = 0,
+
+    /// <summary>
+    /// Could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
+    /// </summary>
+    NoKvSlot = 1,
+}
\ No newline at end of file
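
Usage sketch (editor's note, not part of the patch): the snippet below illustrates how the reworked surface in this diff might be driven by a caller — feeding a prompt to the context in chunks no larger than the configured batch size, the same way the stateless executor now does, checking the new DecodeResult, and reading logits positionally. It only relies on members that appear in the diff itself (LLamaBatch.Add/Clear, Context.Params.BatchSize, Decode returning DecodeResult, LLamaDecodeError, GetLogitsIth, LLamaTokenDataArray.Create); the ModelParams values and the model path are placeholders, and ModelParams itself is assumed to be the usual LLamaSharp parameter type rather than anything introduced here.

// Sketch only: not part of the diff above. Model path and parameter values are placeholders.
using System;
using System.Linq;
using LLama;
using LLama.Common;
using LLama.Exceptions;
using LLama.Native;

var parameters = new ModelParams("model.gguf") { ContextSize = 1024, BatchSize = 512 };
using var weights = LLamaWeights.LoadFromFile(parameters);
using var context = weights.CreateContext(parameters, null);

var tokens = context.Tokenize("Question. cats or dogs?\nAnswer:").ToList();
var batch = new LLamaBatch(1);

// Feed the prompt in chunks no larger than the configured batch size,
// requesting logits only for the very last prompt token.
var n_past = 0;
var batchSize = (int)context.Params.BatchSize;
for (var i = 0; i < tokens.Count; i += batchSize)
{
    batch.Clear();
    var n_eval = Math.Min(batchSize, tokens.Count - i);
    for (var j = 0; j < n_eval; j++)
        batch.Add(tokens[i + j], n_past++, LLamaSeqId.Zero, (i + j) == tokens.Count - 1);

    // Decode now reports status as a DecodeResult rather than a raw int.
    var result = context.Decode(batch);
    if (result != DecodeResult.Ok)
        throw new LLamaDecodeError(result);
}

// Logits for the last token of the batch are read by position, mirroring BatchedDecoding.cs.
var candidates = LLamaTokenDataArray.Create(context.NativeHandle.GetLogitsIth(batch.TokenCount - 1));

Compared with the old int-returning Decode, surfacing a DecodeResult (and throwing LLamaDecodeError on failure) lets callers distinguish NoKvSlot from a hard error without remembering llama_decode's numeric convention.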