|
2 | 2 |
|
3 | 3 | #include "llama.h"
|
4 | 4 |
|
| 5 | +#include "llama-cparams.h"
| 6 | +
5 | 7 | #include <array>
|
6 | 8 | #include <vector>
|
7 | 9 | #include <set>
|
| 10 | +#include <bitset>
| 11 | +#include <unordered_map>
8 | 12 |
|
9 | | -// very similar to llama_batch,
10 | | -// but has more metadata about sequences
| 13 | +// keep this struct lightweight
| 14 | +// it points to data in `llama_batch_allocr`
11 | 15 | struct llama_ubatch {
|
12 | 16 | bool equal_seqs;
|
13 | 17 | // TODO: whole_seqs for embeddings?
|
14 | 18 |
|
15 | 19 | uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs)
|
16 | | - uint32_t n_seq_tokens; // tokens per sequence
17 | | - uint32_t n_seqs;
18 | | -
19 | | - llama_token * token; // [n_tokens]
20 | | - float * embd; // [n_embd, n_tokens]
21 | | - llama_pos * pos; // [n_tokens]
22 | | - int32_t * n_seq_id; // [n_seqs]
23 | | - llama_seq_id ** seq_id; // [n_seqs]
24 | | - int8_t * output; // [n_tokens]
25 | | -};
26 | | -
27 | | -struct llama_sbatch_seq {
28 | | - int32_t n_seq_id;
29 | | -
30 | | - llama_seq_id * seq_id;
31 | | -
32 | | - size_t offset;
33 | | - size_t length;
34 | | -};
35 | | -
36 | | -// sequence-length-aware batch splitting
37 | | -struct llama_sbatch {
38 | | - // tokens left in this batch
39 | | - size_t n_tokens;
40 | | -
41 | | - size_t n_embd;
42 | | -
43 | | - // sorted indices into the batch
44 | | - std::vector<int64_t> ids;
45 | | - // batch indices of the output
46 | | - std::vector<int64_t> out_ids;
47 | | - std::vector<llama_sbatch_seq> seq;
48 | | -
49 | | - const llama_batch * batch = nullptr;
50 | | -
51 | | - // buffers for the ubatches
52 | | - // TODO: very hacky, this needs a complete rework
53 | | - struct ubatch_data {
54 | | - std::vector<llama_token> token;
55 | | - std::vector<float> embd;
56 | | - std::vector<llama_pos> pos;
57 | | - std::vector<int32_t> n_seq_id;
58 | | - std::vector<llama_seq_id *> seq_id;
59 | | - std::vector<int8_t> output;
60 | | - };
61 | | -
62 | | - std::vector<ubatch_data> udatas;
63 | | -
64 | | - llama_ubatch reserve_ubatch(size_t n_ubatch, bool has_embd = false);
65 | | -
66 | | - void add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length);
67 | | -
68 | | - // simple split, unknown number of sequences of unequal lengths
69 | | - llama_ubatch split_simple(size_t n_ubatch);
70 | | -
71 | | - // make batches of equal-length sequences
72 | | - llama_ubatch split_equal(size_t n_ubatch);
73 | | -
74 | | - // sequence-wise split
75 | | - llama_ubatch split_seq(size_t n_ubatch);
76 | | -
77 | | - llama_sbatch() = default;
78 | | - llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false);
| 20 | + uint32_t n_seq_tokens; // tokens per sequence set
| 21 | + uint32_t n_seqs; // sequence sets in the ubatch
| 22 | + uint32_t n_seqs_unq; // unique sequence ids in the ubatch
| 23 | +
| 24 | + // seq_id_unq: unique sequence ids in the ubatch
| 25 | + // seq_idx: indices of the unique sequence ids in the ubatch in [0, n_seqs_unq)
| 26 | + // used for extracting sequence pooled embeddings
| 27 | +
| 28 | + // // size | idx | val
| 29 | + llama_token * token; // [n_tokens] | i | id, token
| 30 | + float * embd; // [n_embd, n_tokens] | i | embd
| 31 | + llama_pos * pos; // [n_tokens] | i | pos
| 32 | + int32_t * n_seq_id; // [n_tokens] | i | -
| 33 | + llama_seq_id ** seq_id; // [n_tokens] | s | s0, s1, seq_id
| 34 | + llama_seq_id * seq_id_unq; // [n_seqs_unq] | s | seq_id
| 35 | + int32_t * seq_idx; // [LLAMA_MAX_SEQ] | - | seq_idx
| 36 | + int8_t * output; // [n_tokens] | i | -
79 | 37 | };
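
For context, a hedged sketch of how a consumer might use the new seq_id_unq/seq_idx pair when reading back sequence-pooled embeddings. The output buffer `embd_out`, its row-per-unique-sequence layout, and the variable `n_embd` are assumptions for illustration, not something this header defines:

// sketch only: `ubatch` is a llama_ubatch produced by the splitting API below,
// `embd_out` is assumed to hold one row of `n_embd` floats per unique sequence
for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
    const llama_seq_id seq_id = ubatch.seq_id_unq[s];   // actual sequence id
    const int32_t      row    = ubatch.seq_idx[seq_id]; // its index in [0, n_seqs_unq)

    const float * src = embd_out + (size_t) row * n_embd; // pooled embedding for seq_id
    // ... copy or consume `src` for `seq_id` ...
}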
|
80 | 38 |
|
81 | | -// a helper for sanitizing and fulfilling a batch
| 39 | +// a helper for sanitizing, fulfilling and splitting a batch
82 | 40 | class llama_batch_allocr {
|
83 | 41 | public:
|
84 | | - llama_batch_allocr();
| 42 | + llama_batch_allocr(uint32_t n_pos_per_embd);
85 | 43 |
|
86 | 44 | // sanitize and auto-gen missing data in the input batch
|
87 | 45 | // memory is optional. if provided will be used to check for sequence continuity and to determine the positions
|
88 | 46 | bool init(
|
89 | 47 | const llama_batch & batch_inp,
|
90 | 48 | const llama_vocab & vocab,
|
91 | 49 | const llama_memory_i * memory,
|
92 | | - bool embd_all);
| 50 | + uint32_t n_embd,
| 51 | + bool output_all);
93 | 52 |
|
94 | 53 | const llama_batch & get_batch() const;
|
95 | 54 |
|
| 55 | + uint32_t get_n_tokens() const;
96 | 56 | uint32_t get_n_outputs() const;
|
97 | 57 |
|
| 58 | + // the array of output indices in the order they were encountered during the ubatch splitting
| 59 | + std::vector<int32_t> & get_out_ids();
| 60 | +
| 61 | + // min/max positions of each sequence in the current ubatch
98 | 62 | llama_pos seq_pos_min(llama_seq_id seq_id) const;
|
99 | 63 | llama_pos seq_pos_max(llama_seq_id seq_id) const;
|
100 | 64 |
|
| 65 | + // call once before splitting the batch to reset the internal state
| 66 | + void split_reset();
| 67 | +
| 68 | + // simple split, unknown number of sequence sets of unequal lengths
| 69 | + llama_ubatch split_simple(uint32_t n_ubatch);
| 70 | +
| 71 | + // make ubatches of equal-length sequence sets
| 72 | + llama_ubatch split_equal(uint32_t n_ubatch);
| 73 | +
| 74 | + // sequence-set-wise split - each ubatch contains a single sequence-set
| 75 | + llama_ubatch split_seq(uint32_t n_ubatch);
| 76 | +
| 77 | + // a helper method for creating a well-defined ubatch of tokens
| 78 | + // TODO: support embeddings if needed in the future
| 79 | + llama_ubatch ubatch_reserve(uint32_t n_seq_tokens, uint32_t n_seqs);
| 80 | +
101 | 81 | private:
|
102 | 82 | void clear();
|
103 | 83 |
|
| 84 | + // create the next ubatch based on the provided batch indices (idxs) and the number of sequence sets (n_seqs)
| 85 | + // return llama_ubatch.n_tokens == 0 if the entire batch was consumed
| 86 | + llama_ubatch ubatch_add(const std::vector<int32_t> & idxs, uint32_t n_seqs, bool equal_seqs);
| 87 | +
| 88 | + // for debugging, start with LLAMA_BATCH_DEBUG=2
| 89 | + void ubatch_print(const llama_ubatch & ubatch, int debug);
| 90 | +
104 | 91 | llama_batch batch;
|
105 | 92 |
|
| 93 | + // only for debugging purposes
| 94 | + const llama_vocab * vocab;
| 95 | +
| 96 | + // TODO: this is more of a temporary solution until we have a better way to handle multiple positions per token/embd
| 97 | + // ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
| 98 | + const uint32_t n_pos_per_embd;
| 99 | +
| 100 | + uint32_t n_embd;
106 | 101 | uint32_t n_outputs;
|
107 | 102 |
|
108 | 103 | std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id
|
109 | 104 |
|
110 | 105 | std::vector<llama_pos> pos;
|
111 | 106 | std::vector<int32_t> n_seq_id;
|
112 | 107 | std::vector<llama_seq_id *> seq_id;
|
| 108 | + std::vector<llama_seq_id> seq_id_unq;
| 109 | + std::vector<int32_t> seq_idx;
113 | 110 | std::vector<int8_t> output;
|
114 | 111 |
|
115 | | - std::vector<std::set<llama_pos>> seq_pos; // seq_pos[s]: the set of positions in sequence s
116 | | - std::vector<std::vector<bool>> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
| 112 | + using pos_set_t = std::set<llama_pos>;
| 113 | + using seq_cpl_t = std::vector<bool>;
| 114 | +
| 115 | + std::vector<pos_set_t> seq_pos; // seq_pos[s]: the set of positions in sequence s
| 116 | + std::vector<seq_cpl_t> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
| 117 | +
| 118 | + using idx_vec_t = std::vector<int32_t>;
| 119 | + using seq_set_t = std::bitset<LLAMA_MAX_SEQ>;
| 120 | +
| 121 | + std::vector<seq_set_t> seq_set; // seq_set[i]: the sequence set of token i
| 122 | +
| 123 | + std::unordered_map<seq_set_t, idx_vec_t> seq_set_map; // the indices at which the sequence set appears
| 124 | +
| 125 | + // batch indices of the output
| 126 | + std::vector<int32_t> out_ids;
| 127 | +
| 128 | + // used[i] indicates if token i has already been used in a previous ubatch
| 129 | + std::vector<bool> used;
| 130 | +
| 131 | + // llama_ubatch points to this data:
| 132 | + struct ubatch {
| 133 | + std::vector<llama_token> token;
| 134 | + std::vector<float> embd;
| 135 | + std::vector<llama_pos> pos;
| 136 | + std::vector<int32_t> n_seq_id;
| 137 | + std::vector<llama_seq_id *> seq_id;
| 138 | + std::vector<llama_seq_id> seq_id_unq;
| 139 | + std::vector<int32_t> seq_idx;
| 140 | + std::vector<int8_t> output;
| 141 | + };
| 142 | +
| 143 | + // current splitting state:
| 144 | + std::vector<ubatch> ubatches;
117 | 145 |
|
118 | 146 | int debug;
|
119 | 147 | };
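
To tie the pieces together, a minimal, hedged sketch of the intended call sequence based only on the declarations above (the variables `batch`, `vocab`, `mem`, `n_embd` and `n_ubatch` are assumed to come from the surrounding context): init() sanitizes the input batch, split_reset() resets the internal splitting state, and one of the split_* methods is then called in a loop until it returns an empty ubatch.

// sketch, not the actual call sites in llama.cpp
llama_batch_allocr balloc(/*n_pos_per_embd =*/ 1);

if (!balloc.init(batch, vocab, mem, n_embd, /*output_all =*/ false)) {
    return; // the input batch failed validation
}

balloc.split_reset();

while (true) {
    // simple split: ubatches of up to n_ubatch tokens, sequence sets of unequal lengths
    llama_ubatch ubatch = balloc.split_simple(n_ubatch);
    if (ubatch.n_tokens == 0) {
        break; // the entire batch has been consumed
    }

    // ... build the graph / process the ubatch ...
}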
|