 #include <algorithm>
 #include <sstream>

-llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) {
-    // clear empty sequences
-    // the previous ubatch is assumed to be gone,
-    // so nothing should refer to values in these sequences anymore.
-    for (size_t i = seq.size(); i-- > 0;) {
-        if (seq[i].length == 0) {
-            seq.pop_back();
-        } else {
-            break;
-        }
-    }
-
-    udatas.push_back({});
-
-    auto & udata = udatas.back();
-
-    udata.token.resize(!has_embd ? n_ubatch : 0);
-    udata.embd.resize(has_embd ? n_embd * n_ubatch : 0);
-    udata.pos.resize(n_ubatch);
-    udata.n_seq_id.resize(n_ubatch);
-    udata.seq_id.resize(n_ubatch);
-    udata.output.resize(n_ubatch);
-
-    llama_ubatch ubatch = {
-        /*equal_seqs   =*/ true,
-        /*n_tokens     =*/ 0,
-        /*n_seq_tokens =*/ 0,
-        /*n_seqs       =*/ 0,
-        /*token        =*/ !has_embd ? udata.token.data() : nullptr,
-        /*embd         =*/ has_embd  ? udata.embd.data()  : nullptr,
-        /*pos          =*/ udata.pos.data(),
-        /*n_seq_id     =*/ udata.n_seq_id.data(),
-        /*seq_id       =*/ udata.seq_id.data(),
-        /*output       =*/ udata.output.data(),
-    };
-
-    return ubatch;
-}
-
-void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length) {
-    GGML_ASSERT(batch != nullptr);
-    GGML_ASSERT(length <= seq.length);
-    // Can only add sequences of equal lengths to a batch,
-    // otherwise it isn't clear to which sequence a token belongs
-    GGML_ASSERT(seq.n_seq_id == 0 || ubatch.n_seqs == 0 || length == (size_t) ubatch.n_tokens / ubatch.n_seqs);
-    GGML_ASSERT((seq.n_seq_id != 0) == ubatch.equal_seqs);
-    // NOTE: loops are separated for cache-friendliness
-    if (batch->token) {
-        if (ubatch.equal_seqs) {
-            for (size_t i = 0; i < length; ++i) {
-                ubatch.token[ubatch.n_tokens + i] = batch->token[ids[seq.offset + i]];
-            }
-        } else {
-            // simple split
-            ubatch.token = batch->token + seq.offset;
-        }
-    } else {
-        ubatch.token = nullptr;
-    }
-    if (batch->embd) {
-        if (ubatch.equal_seqs) {
-            for (size_t i = 0; i < length; ++i) {
-                memcpy(
-                    ubatch.embd + (n_embd * (ubatch.n_tokens + i)),
-                    batch->embd + (n_embd * ids[seq.offset + i]),
-                    n_embd * sizeof(float)
-                );
-            }
-        } else {
-            // simple split
-            ubatch.embd = batch->embd + (n_embd * seq.offset);
-        }
-    } else {
-        ubatch.embd = nullptr;
-    }
-    if (ubatch.equal_seqs) {
-        for (size_t i = 0; i < length; ++i) {
-            ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]];
-        }
-    } else {
-        // simple split
-        ubatch.pos = batch->pos + seq.offset;
-    }
-    if (ubatch.equal_seqs) {
-        ubatch.n_seq_id[ubatch.n_seqs] = seq.n_seq_id;
-        if (seq.seq_id) {
-            ubatch.seq_id[ubatch.n_seqs] = seq.seq_id;
-        }
-    } else {
-        // simple split
-        if (batch->n_seq_id) {
-            ubatch.n_seq_id = batch->n_seq_id + seq.offset;
-        } else {
-            for (size_t i = 0; i < length; ++i) {
-                ubatch.n_seq_id[ubatch.n_seqs + i] = 1;
-            }
-        }
-        if (batch->seq_id) {
-            ubatch.seq_id = batch->seq_id + seq.offset;
-        }
-    }
-    if (batch->logits) {
-        if (ubatch.equal_seqs) {
-            for (size_t i = 0; i < length; ++i) {
-                size_t id = ids[seq.offset + i];
-                int8_t is_output = batch->logits[id];
-                ubatch.output[ubatch.n_tokens + i] = is_output;
-                if (is_output) { out_ids.push_back(id); }
-            }
-        } else {
-            // simple split
-            ubatch.output = batch->logits + seq.offset;
-            for (size_t i = 0; i < length; ++i) {
-                if (ubatch.output[i] != 0) { out_ids.push_back(seq.offset + i); }
-            }
-        }
-    } else {
-        // only get last output
-        for (size_t i = 0; i < length; ++i) {
-            size_t id = ids[seq.offset + i];
-            int8_t is_last = id == ids.size() - 1;
-            ubatch.output[ubatch.n_tokens + i] = is_last;
-            if (is_last) { out_ids.push_back(id); }
-        }
-    }
-    if (ubatch.n_tokens == 0 && ubatch.n_seqs == 0) {
-        ubatch.n_seq_tokens = ubatch.equal_seqs ? length : 1;
-    }
-    ubatch.n_tokens += length;
-    ubatch.n_seqs += ubatch.equal_seqs ? 1 : length; // virtual sequences for simple splits
-    seq.offset += length;
-    seq.length -= length;
-    n_tokens -= length;
-    GGML_ASSERT(ubatch.n_tokens == ubatch.n_seq_tokens * ubatch.n_seqs);
-}
-
-llama_ubatch llama_sbatch::split_simple(size_t n_ubatch) {
-    n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
-    llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
-    ubatch.equal_seqs = false;
-    if (!seq.empty()) {
-        llama_sbatch_seq & s = seq[0];
-        size_t length = s.length < n_ubatch ? s.length : n_ubatch;
-        GGML_ASSERT(seq.size() == 1 && s.n_seq_id == 0); // don't mix with other splits
-        add_seq_to_ubatch(ubatch, s, length);
-    }
-    return ubatch;
-}
-
-llama_ubatch llama_sbatch::split_equal(size_t n_ubatch) {
-    n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
-    llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
-    if (!seq.empty()) {
-        size_t length = 0;
-        size_t n_tokens_in_ubatch = 0;
-        GGML_ASSERT(seq[0].n_seq_id > 0); // should not be mixed with simple splits
-        // smallest first, because it's easier to split this way;
-        // starting from the end to pop in constant time.
-        for (size_t i = seq.size(); i-- > 0;) {
-            llama_sbatch_seq & s = seq[i];
-            GGML_ASSERT(s.length > 0);
-            if (length == 0) {
-                length = s.length < n_ubatch ? s.length : n_ubatch;
-            }
-            add_seq_to_ubatch(ubatch, s, length);
-            n_tokens_in_ubatch += length;
-            // shared prompts can't be mixed with any of their sequences,
-            // so it's safer to compute them in their own ubatch
-            if (s.n_seq_id > 1) { break; }
-            // stop when there isn't enough space for another sequence
-            if (length + n_tokens_in_ubatch > n_ubatch) { break; }
-        }
-    }
-    return ubatch;
-}
-
-llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
-    n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
-    llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
-    if (!seq.empty()) {
-        llama_sbatch_seq & s = seq[seq.size() - 1];
-        size_t length = s.length < n_ubatch ? s.length : n_ubatch;
-        GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with simple splits
-        add_seq_to_ubatch(ubatch, s, length);
-    }
-    return ubatch;
-}
-
-llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split) {
-    GGML_ASSERT(batch.n_tokens >= 0);
-    this->batch = &batch;
-    this->n_embd = n_embd;
-
-    n_tokens = batch.n_tokens;
-    ids.resize(n_tokens);
-    out_ids.clear();
-    // TODO: reserve out_ids and seq
-
-    for (size_t i = 0; i < n_tokens; ++i) {
-        ids[i] = i;
-    }
-
-    if (simple_split) {
-        seq.resize(1);
-        llama_sbatch_seq & s = seq[0];
-        s.n_seq_id = 0;
-        s.seq_id = nullptr;
-        s.offset = 0;
-        s.length = n_tokens;
-        return;
-    }
-
-    std::sort(ids.begin(), ids.end(),
-        [&batch](size_t a, size_t b) {
-            int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1;
-            int32_t n_seq_b = batch.n_seq_id ? batch.n_seq_id[b] : 1;
-            // sort by seq_id, then by pos
-            if (n_seq_a == n_seq_b) {
-                if (batch.seq_id) {
-                    for (int32_t i = 0; i < n_seq_a; ++i) {
-                        llama_seq_id seq_id_a = batch.seq_id[a][i];
-                        llama_seq_id seq_id_b = batch.seq_id[b][i];
-                        // smaller seq_ids go first
-                        if (seq_id_a != seq_id_b) {
-                            return seq_id_a < seq_id_b;
-                        }
-                    }
-                }
-                // when all else is equal, sort by pos
-                if (batch.pos) {
-                    return batch.pos[a] < batch.pos[b];
-                }
-                // no pos, sort by id
-                return a < b;
-            }
-            // shared prompts go first
-            return n_seq_a > n_seq_b;
-        }
-    );
-
-    // init seq
-    llama_sbatch_seq * last_seq = nullptr;
-
-    for (size_t i = 0; i < n_tokens; ++i) {
-        const size_t bi = ids[i];
-        const int32_t n_seqs = batch.n_seq_id[bi];
-        llama_seq_id * seq_ids = batch.seq_id[bi];
-        if (last_seq != nullptr) {
-            bool same = n_seqs == last_seq->n_seq_id;
-            for (int32_t j = 0; same && j < n_seqs; ++j) {
-                if (seq_ids[j] != last_seq->seq_id[j]) {
-                    same = false;
-                }
-            }
-            if (same) {
-                last_seq->length += 1;
-                continue;
-            }
-        }
-        llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1};
-        seq.push_back(new_seq);
-        last_seq = &seq.back();
-    }
-
-    // keep shared prompts first at the end, then sort by length descending.
-    std::sort(seq.begin(), seq.end(),
-        [](llama_sbatch_seq & a, llama_sbatch_seq & b) {
-            if (a.n_seq_id == b.n_seq_id) {
-                return a.length > b.length;
-            }
-            return a.n_seq_id < b.n_seq_id;
-        }
-    );
-}
-
 llama_batch_allocr::llama_batch_allocr() {
     const char * LLAMA_BATCH_DEBUG = getenv("LLAMA_BATCH_DEBUG");
     debug = LLAMA_BATCH_DEBUG ? atoi(LLAMA_BATCH_DEBUG) : 0;
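
For reference only (not part of the commit): a minimal sketch of how the removed llama_sbatch API was typically driven. It assumes the pre-removal internal header llama-batch.h that declared llama_batch, llama_sbatch, and llama_ubatch; process_ubatch() is a hypothetical stand-in for the decode step, not a real llama.cpp function.

#include "llama-batch.h" // pre-removal internal header (assumption)

#include <cstddef>

// hypothetical consumer: a real caller would run the model on the micro-batch
static void process_ubatch(const llama_ubatch & /*ubatch*/) {}

// batch order preserved: each split_simple() call takes up to n_ubatch tokens
// off the front of the batch (equal_seqs == false in the returned ubatch)
static void decode_simple(const llama_batch & batch, size_t n_embd, size_t n_ubatch) {
    llama_sbatch sbatch(batch, n_embd, /* simple_split */ true);
    while (sbatch.n_tokens > 0) {
        process_ubatch(sbatch.split_simple(n_ubatch));
    }
}

// tokens regrouped by seq_id: every sequence contributes the same number of
// tokens per ubatch, so n_tokens == n_seq_tokens * n_seqs (asserted above)
static void decode_equal(const llama_batch & batch, size_t n_embd, size_t n_ubatch) {
    llama_sbatch sbatch(batch, n_embd, /* simple_split */ false);
    while (sbatch.n_tokens > 0) {
        process_ubatch(sbatch.split_equal(n_ubatch));
    }
}

Each split_* call decrements sbatch.n_tokens through add_seq_to_ubatch(), so these loops terminate exactly when the batch is consumed.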
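The ordering established by the constructor's first std::sort can be checked in isolation. The toy program below is illustration only (stand-in types, not llama.cpp code); it replicates the comparator: shared prompts (larger n_seq_id) first, then smaller seq_ids, then increasing pos.

#include <algorithm>
#include <cstdio>
#include <vector>

struct tok {
    int              n_seq_id; // how many sequences share this token
    std::vector<int> seq_id;   // ids of those sequences
    int              pos;      // position within the sequence
};

int main() {
    std::vector<tok> toks = {
        {1, {1}, 0}, {2, {0, 1}, 1}, {1, {0}, 0}, {2, {0, 1}, 0}, {1, {0}, 1},
    };
    std::sort(toks.begin(), toks.end(), [](const tok & a, const tok & b) {
        if (a.n_seq_id == b.n_seq_id) {
            for (int i = 0; i < a.n_seq_id; ++i) {
                if (a.seq_id[i] != b.seq_id[i]) {
                    return a.seq_id[i] < b.seq_id[i]; // smaller seq_ids go first
                }
            }
            return a.pos < b.pos; // when all else is equal, sort by pos
        }
        return a.n_seq_id > b.n_seq_id; // shared prompts go first
    });
    for (const tok & t : toks) {
        printf("n_seq_id=%d seq_id[0]=%d pos=%d\n", t.n_seq_id, t.seq_id[0], t.pos);
    }
    // prints the shared (0,1) tokens first (pos 0, then 1), then seq 0's
    // tokens by pos, then seq 1's token
    return 0;
}

The constructor's second sort then places shared prompts at the end of seq and orders the single-sequence runs by length descending; because split_equal() walks seq from the back, shared prompts are computed first in their own ubatch, the shortest runs come next, and exhausted runs are popped from the back in constant time by reserve_ubatch().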