diff --git a/LLama.Unittest/BeamTests.cs b/LLama.Unittest/BeamTests.cs deleted file mode 100644 index 88b25672e..000000000 --- a/LLama.Unittest/BeamTests.cs +++ /dev/null @@ -1,73 +0,0 @@ -using System.Text; -using LLama.Common; -using LLama.Native; -using Xunit.Abstractions; - -namespace LLama.Unittest; - -public sealed class BeamTests - : IDisposable -{ - private readonly ITestOutputHelper _testOutputHelper; - private readonly ModelParams _params; - private readonly LLamaWeights _model; - - public BeamTests(ITestOutputHelper testOutputHelper) - { - _testOutputHelper = testOutputHelper; - _params = new ModelParams(Constants.GenerativeModelPath) - { - ContextSize = 2048, - GpuLayerCount = Constants.CIGpuLayerCount, - }; - _model = LLamaWeights.LoadFromFile(_params); - } - - public void Dispose() - { - _model.Dispose(); - } - - [Fact] - public void BasicBeam() - { - const int num_beams = 2; - const int n_predict = 3; - const string prompt = "The cat sat on"; - - var context = _model.CreateContext(_params); - - var initial_tokens = context.Tokenize(prompt); - var batch = new LLamaBatch(); - batch.AddRange(initial_tokens, 0, LLamaSeqId.Zero, true); - context.Decode(batch); - - var decoder = new StreamingTokenDecoder(context); - NativeApi.llama_beam_search(context.NativeHandle, (data, state) => - { - // Show the current state of every beam. - for (var i = 0; i < state.Beams.Length; i++) - { - ref var view = ref state.Beams[i]; - - var decoder = new StreamingTokenDecoder(context); - decoder.AddRange(view.Tokens); - var tokens = decoder.Read(); - - _testOutputHelper.WriteLine($"B{i} ({view.CumulativeProbability}) => '{tokens}'"); - } - - // Once all beams agree on some tokens read them and append them to the output decoder - if (state.CommonPrefixLength > 0) - { - var view = state.Beams[0]; - - decoder.AddRange(view.Tokens.Slice(0, (int)state.CommonPrefixLength)); - - } - - }, IntPtr.Zero, num_beams, initial_tokens.Length, n_predict, Math.Max(1, Environment.ProcessorCount / 2)); - - _testOutputHelper.WriteLine($"Final: {prompt}{decoder.Read()}"); - } -} \ No newline at end of file diff --git a/LLama/Native/LLamaBeamView.cs b/LLama/Native/LLamaBeamView.cs deleted file mode 100644 index dcd583ba3..000000000 --- a/LLama/Native/LLamaBeamView.cs +++ /dev/null @@ -1,40 +0,0 @@ -using System; -using System.Runtime.InteropServices; - -namespace LLama.Native; - -/// -/// Information about a single beam in a beam search -/// -[StructLayout(LayoutKind.Sequential)] -public struct LLamaBeamView -{ - private unsafe LLamaToken* tokens; - private nuint n_tokens; - - /// - /// Cumulative beam probability (renormalized relative to all beams) - /// - public float CumulativeProbability; - - /// - /// Callback should set this to true when a beam is at end-of-beam. - /// - public bool EndOfBeam; - - /// - /// Tokens in this beam - /// - public readonly Span Tokens - { - get - { - unsafe - { - if (n_tokens > int.MaxValue) - throw new InvalidOperationException("More than 2147483647 tokens is not supported"); - return new Span(tokens, (int)n_tokens); - } - } - } -} \ No newline at end of file diff --git a/LLama/Native/LLamaBeamsState.cs b/LLama/Native/LLamaBeamsState.cs deleted file mode 100644 index cb214aef3..000000000 --- a/LLama/Native/LLamaBeamsState.cs +++ /dev/null @@ -1,49 +0,0 @@ -using System; -using System.Runtime.InteropServices; - -namespace LLama.Native; - -/// -/// Passed to beam_search_callback function. -/// Whenever 0 < common_prefix_length, this number of tokens should be copied from any of the beams -/// (e.g. beams[0]) as they will be removed (shifted) from all beams in all subsequent callbacks. -/// -[StructLayout(LayoutKind.Sequential)] -public struct LLamaBeamsState -{ - /// - /// The state of each individual beam - /// - private unsafe LLamaBeamView* beam_views; - - /// - /// Number of elements in beam_views - /// - private nuint n_beams; - - /// - /// Current max length of prefix tokens shared by all beams. - /// - public ulong CommonPrefixLength; - - /// - /// True iff this is the last callback invocation. - /// - public bool LastCall; - - /// - /// The current state of each beam - /// - public Span Beams - { - get - { - unsafe - { - if (n_beams > int.MaxValue) - throw new InvalidOperationException("More than 2147483647 beams is not supported"); - return new Span(beam_views, (int)n_beams); - } - } - } -} \ No newline at end of file diff --git a/LLama/Native/NativeApi.BeamSearch.cs b/LLama/Native/NativeApi.BeamSearch.cs deleted file mode 100644 index 142b997bb..000000000 --- a/LLama/Native/NativeApi.BeamSearch.cs +++ /dev/null @@ -1,25 +0,0 @@ -using System; -using System.Runtime.InteropServices; - -namespace LLama.Native; - -public static partial class NativeApi -{ - /// - /// Type of pointer to the beam_search_callback function. - /// - /// callback_data is any custom data passed to llama_beam_search, that is subsequently passed back to beam_search_callbac - /// - public delegate void LLamaBeamSearchCallback(IntPtr callback_data, LLamaBeamsState state); - - /// Deterministically returns entire sentence constructed by a beam search. - /// Pointer to the llama_context. - /// Invoked for each iteration of the beam_search loop, passing in beams_state. - /// A pointer that is simply passed back to callback. - /// Number of beams to use. - /// Number of tokens already evaluated. - /// Maximum number of tokens to predict. EOS may occur earlier. - /// Number of threads. - [DllImport(libraryName, EntryPoint = "llama_beam_search", CallingConvention = CallingConvention.Cdecl)] - public static extern void llama_beam_search(SafeLLamaContextHandle ctx, LLamaBeamSearchCallback callback, IntPtr callback_data, ulong n_beams, int n_past, int n_predict, int n_threads); -} \ No newline at end of file