
Commit 166ad5e

committed
cont : remove llama_sbatch
ggml-ci
1 parent 1b6dfc5 commit 166ad5e

File tree

4 files changed, +30 -366 lines changed


src/llama-batch.cpp

Lines changed: 0 additions & 275 deletions
@@ -9,281 +9,6 @@
 #include <algorithm>
 #include <sstream>
 
-llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) {
-    // clear empty sequences
-    // the previous ubatch is assumed to be gone,
-    // so nothing should refer to values in these sequences anymore.
-    for (size_t i = seq.size(); i-- > 0;) {
-        if (seq[i].length == 0) {
-            seq.pop_back();
-        } else {
-            break;
-        }
-    }
-
-    udatas.push_back({});
-
-    auto & udata = udatas.back();
-
-    udata.token.resize(!has_embd ? n_ubatch : 0);
-    udata.embd.resize(has_embd ? n_embd * n_ubatch : 0);
-    udata.pos.resize(n_ubatch);
-    udata.n_seq_id.resize(n_ubatch);
-    udata.seq_id.resize(n_ubatch);
-    udata.output.resize(n_ubatch);
-
-    llama_ubatch ubatch = {
-        /*equal_seqs   =*/ true,
-        /*n_tokens     =*/ 0,
-        /*n_seq_tokens =*/ 0,
-        /*n_seqs       =*/ 0,
-        /*token        =*/ !has_embd ? udata.token.data() : nullptr,
-        /*embd         =*/ has_embd ? udata.embd.data() : nullptr,
-        /*pos          =*/ udata.pos.data(),
-        /*n_seq_id     =*/ udata.n_seq_id.data(),
-        /*seq_id       =*/ udata.seq_id.data(),
-        /*output       =*/ udata.output.data(),
-    };
-
-    return ubatch;
-}
-
-void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length) {
-    GGML_ASSERT(batch != nullptr);
-    GGML_ASSERT(length <= seq.length);
-    // Can only add sequences of equal lengths to a batch,
-    // otherwise it isn't clear to which sequence a token belongs
-    GGML_ASSERT(seq.n_seq_id == 0 || ubatch.n_seqs == 0 || length == (size_t) ubatch.n_tokens / ubatch.n_seqs);
-    GGML_ASSERT((seq.n_seq_id != 0) == ubatch.equal_seqs);
-    // NOTE: loops are separated for cache-friendliness
-    if (batch->token) {
-        if (ubatch.equal_seqs) {
-            for (size_t i = 0; i < length; ++i) {
-                ubatch.token[ubatch.n_tokens + i] = batch->token[ids[seq.offset + i]];
-            }
-        } else {
-            // simple split
-            ubatch.token = batch->token + seq.offset;
-        }
-    } else {
-        ubatch.token = nullptr;
-    }
-    if (batch->embd) {
-        if (ubatch.equal_seqs) {
-            for (size_t i = 0; i < length; ++i) {
-                memcpy(
-                    ubatch.embd + (n_embd * (ubatch.n_tokens + i)),
-                    batch->embd + (n_embd * ids[seq.offset + i]),
-                    n_embd * sizeof(float)
-                );
-            }
-        } else {
-            // simple split
-            ubatch.embd = batch->embd + (n_embd * seq.offset);
-        }
-    } else {
-        ubatch.embd = nullptr;
-    }
-    if (ubatch.equal_seqs) {
-        for (size_t i = 0; i < length; ++i) {
-            ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]];
-        }
-    } else {
-        // simple split
-        ubatch.pos = batch->pos + seq.offset;
-    }
-    if (ubatch.equal_seqs) {
-        ubatch.n_seq_id[ubatch.n_seqs] = seq.n_seq_id;
-        if (seq.seq_id) {
-            ubatch.seq_id[ubatch.n_seqs] = seq.seq_id;
-        }
-    } else {
-        // simple split
-        if (batch->n_seq_id) {
-            ubatch.n_seq_id = batch->n_seq_id + seq.offset;
-        } else {
-            for (size_t i = 0; i < length; ++i) {
-                ubatch.n_seq_id[ubatch.n_seqs + i] = 1;
-            }
-        }
-        if (batch->seq_id) {
-            ubatch.seq_id = batch->seq_id + seq.offset;
-        }
-    }
-    if (batch->logits) {
-        if (ubatch.equal_seqs) {
-            for (size_t i = 0; i < length; ++i) {
-                size_t id = ids[seq.offset + i];
-                int8_t is_output = batch->logits[id];
-                ubatch.output[ubatch.n_tokens + i] = is_output;
-                if (is_output) { out_ids.push_back(id); }
-            }
-        } else {
-            // simple split
-            ubatch.output = batch->logits + seq.offset;
-            for (size_t i = 0; i < length; ++i) {
-                if (ubatch.output[i] != 0) { out_ids.push_back(seq.offset + i); }
-            }
-        }
-    } else {
-        // only get last output
-        for (size_t i = 0; i < length; ++i) {
-            size_t id = ids[seq.offset + i];
-            int8_t is_last = id == ids.size() - 1;
-            ubatch.output[ubatch.n_tokens + i] = is_last;
-            if (is_last) { out_ids.push_back(id); }
-        }
-    }
-    if (ubatch.n_tokens == 0 && ubatch.n_seqs == 0) {
-        ubatch.n_seq_tokens = ubatch.equal_seqs ? length : 1;
-    }
-    ubatch.n_tokens += length;
-    ubatch.n_seqs += ubatch.equal_seqs ? 1 : length; // virtual sequences for simple splits
-    seq.offset += length;
-    seq.length -= length;
-    n_tokens -= length;
-    GGML_ASSERT(ubatch.n_tokens == ubatch.n_seq_tokens * ubatch.n_seqs);
-}
-
-llama_ubatch llama_sbatch::split_simple(size_t n_ubatch) {
-    n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
-    llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
-    ubatch.equal_seqs = false;
-    if (!seq.empty()) {
-        llama_sbatch_seq & s = seq[0];
-        size_t length = s.length < n_ubatch ? s.length : n_ubatch;
-        GGML_ASSERT(seq.size() == 1 && s.n_seq_id == 0); // don't mix with other splits
-        add_seq_to_ubatch(ubatch, s, length);
-    }
-    return ubatch;
-}
-
-llama_ubatch llama_sbatch::split_equal(size_t n_ubatch) {
-    n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
-    llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
-    if (!seq.empty()) {
-        size_t length = 0;
-        size_t n_tokens_in_ubatch = 0;
-        GGML_ASSERT(seq[0].n_seq_id > 0); // should not be mixed with simple splits
-        // smallest first, because it's easier to split this way;
-        // starting from the end to pop in constant time.
-        for (size_t i = seq.size(); i-- > 0;) {
-            llama_sbatch_seq & s = seq[i];
-            GGML_ASSERT(s.length > 0);
-            if (length == 0) {
-                length = s.length < n_ubatch ? s.length : n_ubatch;
-            }
-            add_seq_to_ubatch(ubatch, s, length);
-            n_tokens_in_ubatch += length;
-            // shared prompts can't be mixed with any of their sequences,
-            // so it's safer to compute them in their own ubatch
-            if (s.n_seq_id > 1) { break; }
-            // stop when there isn't enough space for another sequence
-            if (length + n_tokens_in_ubatch > n_ubatch) { break; }
-        }
-    }
-    return ubatch;
-}
-
-llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
-    n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
-    llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
-    if (!seq.empty()) {
-        llama_sbatch_seq & s = seq[seq.size() - 1];
-        size_t length = s.length < n_ubatch ? s.length : n_ubatch;
-        GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with simple splits
-        add_seq_to_ubatch(ubatch, s, length);
-    }
-    return ubatch;
-}
-
-llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split) {
-    GGML_ASSERT(batch.n_tokens >= 0);
-    this->batch = &batch;
-    this->n_embd = n_embd;
-
-    n_tokens = batch.n_tokens;
-    ids.resize(n_tokens);
-    out_ids.clear();
-    // TODO: reserve out_ids and seq
-
-    for (size_t i = 0; i < n_tokens; ++i) {
-        ids[i] = i;
-    }
-
-    if (simple_split) {
-        seq.resize(1);
-        llama_sbatch_seq & s = seq[0];
-        s.n_seq_id = 0;
-        s.seq_id = nullptr;
-        s.offset = 0;
-        s.length = n_tokens;
-        return;
-    }
-
-    std::sort(ids.begin(), ids.end(),
-        [&batch](size_t a, size_t b) {
-            int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1;
-            int32_t n_seq_b = batch.n_seq_id ? batch.n_seq_id[b] : 1;
-            // sort by seq_id, then by pos
-            if (n_seq_a == n_seq_b) {
-                if (batch.seq_id) {
-                    for (int32_t i = 0; i < n_seq_a; ++i) {
-                        llama_seq_id seq_id_a = batch.seq_id[a][i];
-                        llama_seq_id seq_id_b = batch.seq_id[b][i];
-                        // smaller seq_ids go first
-                        if (seq_id_a != seq_id_b) {
-                            return seq_id_a < seq_id_b;
-                        }
-                    }
-                }
-                // when all else is equal, sort by pos
-                if (batch.pos) {
-                    return batch.pos[a] < batch.pos[b];
-                }
-                // no pos, sort by id
-                return a < b;
-            }
-            // shared prompts go first
-            return n_seq_a > n_seq_b;
-        }
-    );
-
-    // init seq
-    llama_sbatch_seq * last_seq = nullptr;
-
-    for (size_t i = 0; i < n_tokens; ++i) {
-        const size_t bi = ids[i];
-        const int32_t n_seqs = batch.n_seq_id[bi];
-        llama_seq_id * seq_ids = batch.seq_id[bi];
-        if (last_seq != nullptr) {
-            bool same = n_seqs == last_seq->n_seq_id;
-            for (int32_t j = 0; same && j < n_seqs; ++j) {
-                if (seq_ids[j] != last_seq->seq_id[j]) {
-                    same = false;
-                }
-            }
-            if (same) {
-                last_seq->length += 1;
-                continue;
-            }
-        }
-        llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1};
-        seq.push_back(new_seq);
-        last_seq = &seq.back();
-    }
-
-    // keep shared prompts first at the end, then sort by length descending.
-    std::sort(seq.begin(), seq.end(),
-        [](llama_sbatch_seq & a, llama_sbatch_seq & b) {
-            if (a.n_seq_id == b.n_seq_id) {
-                return a.length > b.length;
-            }
-            return a.n_seq_id < b.n_seq_id;
-        }
-    );
-}
-
 llama_batch_allocr::llama_batch_allocr() {
     const char * LLAMA_BATCH_DEBUG = getenv("LLAMA_BATCH_DEBUG");
     debug = LLAMA_BATCH_DEBUG ? atoi(LLAMA_BATCH_DEBUG) : 0;
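
For context, a minimal caller-side sketch of how the splitting API removed above was typically driven: build an llama_sbatch over a llama_batch, then pull ubatches until the token budget is drained. Only the llama_sbatch declarations come from this commit; the wrapper function, the n_ubatch value and the process_ubatch consumer are illustrative placeholders, not actual llama.cpp call sites.

// illustrative only - drives the llama_sbatch API removed in this commit
// process_ubatch is a hypothetical consumer, not an actual llama.cpp function
static void consume_simple(const llama_batch & batch, size_t n_embd, size_t n_ubatch) {
    // simple_split = true: a single virtual sequence, tokens taken in batch order
    llama_sbatch sbatch(batch, n_embd, /* simple_split */ true);

    while (sbatch.n_tokens > 0) {
        // each call hands back at most n_ubatch tokens and shrinks sbatch.n_tokens
        llama_ubatch ubatch = sbatch.split_simple(n_ubatch);

        process_ubatch(ubatch);
    }
}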

src/llama-batch.h

Lines changed: 0 additions & 58 deletions
@@ -26,64 +26,6 @@ struct llama_ubatch {
     int8_t * output; // [n_tokens]
 };
 
-// TODO: remove
-struct llama_sbatch_seq {
-    int32_t n_seq_id;
-
-    llama_seq_id * seq_id;
-
-    size_t offset;
-    size_t length;
-};
-
-// sequence-length-aware batch splitting
-// TODO: remove
-struct llama_sbatch {
-    // tokens left in this batch
-    size_t n_tokens;
-
-    size_t n_embd;
-
-    // sorted indices into the batch
-    std::vector<int64_t> ids;
-    // batch indices of the output
-    std::vector<int64_t> out_ids;
-    std::vector<llama_sbatch_seq> seq;
-
-    const llama_batch * batch = nullptr;
-
-    // buffers for the ubatches
-    // TODO: very hacky, this needs a complete rework
-    struct ubatch_data {
-        std::vector<llama_token>    token;
-        std::vector<float>          embd;
-        std::vector<llama_pos>      pos;
-        std::vector<int32_t>        n_seq_id;
-        std::vector<llama_seq_id *> seq_id;
-        std::vector<int8_t>         output;
-    };
-
-    std::vector<ubatch_data> udatas;
-
-    llama_ubatch reserve_ubatch(size_t n_ubatch, bool has_embd = false);
-
-    void add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length);
-
-    // simple split, unknown number of sequences of unequal lengths
-    llama_ubatch split_simple(size_t n_ubatch);
-
-    // make batches of equal-length sequences
-    llama_ubatch split_equal(size_t n_ubatch);
-
-    // sequence-wise split
-    llama_ubatch split_seq(size_t n_ubatch);
-
-    llama_sbatch() = default;
-    llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false);
-};
-
-// ---------------------
-
 // a helper for sanitizing, fulfilling and splitting a batch
 class llama_batch_allocr {
 public:
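
The header above also declared the two sequence-aware split modes. A rough sketch of the non-simple path, assuming the removed declarations and the same hypothetical process_ubatch consumer; the sequence lengths in the comment are a made-up worked example of the split_equal packing rule seen in the removed implementation.

// illustrative only - non-simple path of the removed llama_sbatch API
// worked example: sequences of 5, 3 and 3 tokens with n_ubatch = 8:
//   split_equal walks from the shortest sequences (end of the sorted list),
//   packs the two length-3 sequences into one ubatch (6 tokens), then stops
//   because another 3-token chunk would exceed n_ubatch; the 5-token sequence
//   goes into the next ubatch. split_seq instead takes one sequence per ubatch.
static void consume_equal(const llama_batch & batch, size_t n_embd, size_t n_ubatch) {
    llama_sbatch sbatch(batch, n_embd, /* simple_split */ false);

    while (sbatch.n_tokens > 0) {
        llama_ubatch ubatch = sbatch.split_equal(n_ubatch);

        process_ubatch(ubatch); // hypothetical consumer
    }
}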
