Skip to content

Commit 57c79a9

Browse files
committed
cont : rework pooling
ggml-ci
1 parent 7eb5ad4 commit 57c79a9

File tree

7 files changed: +239 additions, -184 deletions

src/llama-batch.cpp

Lines changed: 113 additions & 38 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,8 @@ llama_batch_allocr::llama_batch_allocr() {
1818
for (auto & cur : seq_cpl) {
1919
cur.resize(LLAMA_MAX_SEQ);
2020
}
21+
22+
seq_idx.resize(LLAMA_MAX_SEQ, -1);
2123
}
2224

2325
bool llama_batch_allocr::init(
@@ -137,22 +139,23 @@ bool llama_batch_allocr::init(
137139
// compute stats
138140
//
139141

142+
this->n_embd = n_embd;
143+
140144
for (int32_t i = 0; i < batch.n_tokens; ++i) {
141145
n_outputs += batch.logits[i] != 0;
142146
}
143147

144-
this->n_embd = n_embd;
145-
146148
// determine coupled sequences
147149
// these are pairs of sequences that have at least one token in the input batch that is assigned to both of them
148150
for (int32_t i = 0; i < batch.n_tokens; ++i) {
151+
const llama_seq_id s0 = batch.seq_id[i][0];
152+
149153
for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
150-
seq_pos[batch.seq_id[i][s]].insert(batch.pos[i]);
154+
const llama_seq_id s1 = batch.seq_id[i][s];
151155

152-
if (s > 0) {
153-
const llama_seq_id s0 = batch.seq_id[i][0];
154-
const llama_seq_id s1 = batch.seq_id[i][s];
156+
seq_pos[s1].insert(batch.pos[i]);
155157

158+
if (s > 0) {
156159
// mark that sequence s1 is coupled to s0
157160
seq_cpl[s1][s0] = true;
158161

@@ -162,14 +165,28 @@ bool llama_batch_allocr::init(
162165
}
163166
}
164167

165-
for (int32_t i = 0; i < batch.n_tokens; ++i) {
166-
seq_set_t cur;
167-
for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
168-
cur.set(batch.seq_id[i][s]);
168+
{
169+
seq_set_t seq_set_unq;
170+
171+
for (int32_t i = 0; i < batch.n_tokens; ++i) {
172+
seq_set_t cur;
173+
for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
174+
const llama_seq_id s0 = batch.seq_id[i][s];
175+
176+
cur.set(s0);
177+
seq_set_unq.set(s0);
178+
}
179+
180+
seq_set.push_back(cur);
181+
seq_set_map[cur].push_back(i);
169182
}
170183

171-
seq_set.push_back(cur);
172-
seq_set_map[cur].push_back(i);
184+
for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
185+
if (seq_set_unq.test(s)) {
186+
seq_idx[s] = seq_id_unq.size();
187+
seq_id_unq.push_back(s);
188+
}
189+
}
173190
}
174191

175192
if (debug > 0) {
@@ -180,11 +197,14 @@ bool llama_batch_allocr::init(
180197
/*.n_tokens =*/ (uint32_t) batch.n_tokens,
181198
/*.n_seq_tokens =*/ (uint32_t) 1,
182199
/*.n_seqs =*/ (uint32_t) batch.n_tokens,
200+
/*.n_seqs_unq =*/ (uint32_t) this->seq_id_unq.size(),
183201
/*.token =*/ batch.token,
184202
/*.embd =*/ batch.embd,
185203
/*.pos =*/ batch.pos,
186204
/*.n_seq_id =*/ batch.n_seq_id,
187205
/*.seq_id =*/ batch.seq_id,
206+
/*.seq_id_unq =*/ this->seq_id_unq.data(),
207+
/*.seq_idx =*/ this->seq_idx.data(),
188208
/*.output =*/ batch.logits,
189209
};
190210

@@ -270,32 +290,44 @@ bool llama_batch_allocr::init(
270290
return true;
271291
}
272292

273-
llama_ubatch llama_batch_allocr::ubatch_reserve(uint32_t n_tokens) {
293+
llama_ubatch llama_batch_allocr::ubatch_reserve(uint32_t n_seq_tokens, uint32_t n_seqs) {
294+
const uint32_t n_tokens = n_seq_tokens*n_seqs;
295+
274296
clear();
275297
split_reset();
276298

277299
ubatches.emplace_back();
278300

279301
auto & ubatch = ubatches.back();
280302

281-
ubatch.token .resize(n_tokens);
282-
ubatch.embd .clear();
283-
ubatch.pos .resize(n_tokens);
284-
ubatch.n_seq_id.resize(n_tokens);
285-
ubatch.seq_id .resize(n_tokens);
286-
ubatch.output .resize(n_tokens);
303+
ubatch.token .resize(n_tokens);
304+
ubatch.embd .clear();
305+
ubatch.pos .resize(n_tokens);
306+
ubatch.n_seq_id .resize(n_tokens);
307+
ubatch.seq_id .resize(n_tokens);
308+
ubatch.seq_id_unq.resize(0);
309+
ubatch.seq_idx .resize(LLAMA_MAX_SEQ, -1);
310+
ubatch.output .resize(n_tokens);
311+
312+
for (uint32_t s = 0; s < n_seqs; ++s) {
313+
ubatch.seq_idx[s] = s;
314+
ubatch.seq_id_unq.push_back(s);
315+
}
287316

288317
llama_ubatch res {
289318
/*.equal_seqs =*/ true,
290319
/*.n_tokens =*/ n_tokens,
291-
/*.n_seq_tokens =*/ n_tokens,
292-
/*.n_seqs =*/ 1,
320+
/*.n_seq_tokens =*/ n_seq_tokens,
321+
/*.n_seqs =*/ n_seqs,
322+
/*.n_seqs_unq =*/ n_seqs,
293323

294324
/*.token =*/ ubatch.token.data(),
295325
/*.embd =*/ nullptr,
296326
/*.pos =*/ ubatch.pos.data(),
297327
/*.n_seq_id =*/ ubatch.n_seq_id.data(),
298328
/*.seq_id =*/ ubatch.seq_id.data(),
329+
/*.seq_id_unq =*/ ubatch.seq_id_unq.data(),
330+
/*.seq_idx =*/ ubatch.seq_idx.data(),
299331
/*.output =*/ ubatch.output.data(),
300332
};
301333

@@ -489,10 +521,11 @@ void llama_batch_allocr::clear() {
489521

490522
batch = {};
491523

492-
pos .clear();
493-
n_seq_id.clear();
494-
seq_id .clear();
495-
output .clear();
524+
pos .clear();
525+
n_seq_id .clear();
526+
seq_id .clear();
527+
seq_id_unq.clear();
528+
output .clear();
496529

497530
for (auto & cur : seq_pos) {
498531
cur.clear();
@@ -516,12 +549,16 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
516549

517550
auto & ubatch = ubatches.back();
518551

519-
ubatch.token .resize(n_tokens);
520-
ubatch.embd .resize((int64_t) n_tokens*n_embd);
521-
ubatch.pos .resize(n_tokens);
522-
ubatch.n_seq_id.resize(n_tokens);
523-
ubatch.seq_id .resize(n_tokens);
524-
ubatch.output .resize(n_tokens);
552+
ubatch.token .resize(n_tokens);
553+
ubatch.embd .resize((int64_t) n_tokens*n_embd);
554+
ubatch.pos .resize(n_tokens);
555+
ubatch.n_seq_id .resize(n_tokens);
556+
ubatch.seq_id .resize(n_tokens);
557+
ubatch.seq_id_unq.resize(0);
558+
ubatch.seq_idx .resize(LLAMA_MAX_SEQ, -1);
559+
ubatch.output .resize(n_tokens);
560+
561+
seq_set_t seq_set_unq;
525562

526563
for (size_t i = 0; i < idxs.size(); ++i) {
527564
if (batch.token) {
@@ -537,22 +574,36 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
537574
ubatch.seq_id[i] = batch.seq_id[idxs[i]];
538575
ubatch.output[i] = batch.logits[idxs[i]];
539576

577+
for (int s = 0; s < ubatch.n_seq_id[i]; ++s) {
578+
seq_set_unq.set(ubatch.seq_id[i][s]);
579+
}
580+
540581
if (ubatch.output[i]) {
541582
out_ids.push_back(idxs[i]);
542583
}
543584
}
544585

586+
for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
587+
if (seq_set_unq.test(s)) {
588+
ubatch.seq_idx[s] = ubatch.seq_id_unq.size();
589+
ubatch.seq_id_unq.push_back(s);
590+
}
591+
}
592+
545593
llama_ubatch res {
546594
/*.equal_seqs =*/ equal_seqs,
547595
/*.n_tokens =*/ n_tokens,
548596
/*.n_seq_tokens =*/ n_tokens/n_seqs,
549597
/*.n_seqs =*/ n_seqs,
598+
/*.n_seqs_unq =*/ (uint32_t) ubatch.seq_id_unq.size(),
550599

551600
/*.token =*/ batch.token ? ubatch.token.data() : nullptr,
552601
/*.embd =*/ batch.embd ? ubatch.embd.data() : nullptr,
553602
/*.pos =*/ ubatch.pos.data(),
554603
/*.n_seq_id =*/ ubatch.n_seq_id.data(),
555604
/*.seq_id =*/ ubatch.seq_id.data(),
605+
/*.seq_id_unq =*/ ubatch.seq_id_unq.data(),
606+
/*.seq_idx =*/ ubatch.seq_idx.data(),
556607
/*.output =*/ ubatch.output.data(),
557608
};
558609

@@ -571,14 +622,38 @@ void llama_batch_allocr::ubatch_print(const llama_ubatch & ubatch, int debug) {
571622
LLAMA_LOG_DEBUG("%s: n_tokens = %d\n", __func__, ubatch.n_tokens);
572623
LLAMA_LOG_DEBUG("%s: n_seq_tokens = %d\n", __func__, ubatch.n_seq_tokens);
573624
LLAMA_LOG_DEBUG("%s: n_seqs = %d\n", __func__, ubatch.n_seqs);
625+
LLAMA_LOG_DEBUG("%s: n_seqs_unq = %d\n", __func__, ubatch.n_seqs_unq);
626+
627+
std::stringstream ss_seq_id_unq;
628+
std::stringstream ss_seq_idx;
629+
630+
ss_seq_id_unq << "[ ";
631+
ss_seq_idx << "[";
632+
633+
for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
634+
ss_seq_id_unq << ubatch.seq_id_unq[s] << " ";
635+
}
636+
637+
for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
638+
if (ubatch.seq_idx[s] >= 0) {
639+
ss_seq_idx << ubatch.seq_idx[s]%10;
640+
} else {
641+
ss_seq_idx << ".";
642+
}
643+
}
644+
645+
ss_seq_id_unq << "]";
646+
ss_seq_idx << "]";
574647

575-
LLAMA_LOG_DEBUG("%s: token = %p\n", __func__, (void *) ubatch.token);
576-
LLAMA_LOG_DEBUG("%s: embd = %p\n", __func__, (void *) ubatch.embd);
577-
LLAMA_LOG_DEBUG("%s: pos = %p\n", __func__, (void *) ubatch.pos);
578-
LLAMA_LOG_DEBUG("%s: n_seq_id = %p\n", __func__, (void *) ubatch.n_seq_id);
579-
LLAMA_LOG_DEBUG("%s: seq_id = %p\n", __func__, (void *) ubatch.seq_id);
580-
LLAMA_LOG_DEBUG("%s: output = %p\n", __func__, (void *) ubatch.output);
581-
LLAMA_LOG_DEBUG("%s: n_outputs = %d\n", __func__, n_outputs);
648+
LLAMA_LOG_DEBUG("%s: token = %p\n", __func__, (void *) ubatch.token);
649+
LLAMA_LOG_DEBUG("%s: embd = %p\n", __func__, (void *) ubatch.embd);
650+
LLAMA_LOG_DEBUG("%s: pos = %p\n", __func__, (void *) ubatch.pos);
651+
LLAMA_LOG_DEBUG("%s: n_seq_id = %p\n", __func__, (void *) ubatch.n_seq_id);
652+
LLAMA_LOG_DEBUG("%s: seq_id = %p\n", __func__, (void *) ubatch.seq_id);
653+
LLAMA_LOG_DEBUG("%s: seq_id_unq = %s\n", __func__, ss_seq_id_unq.str().c_str());
654+
LLAMA_LOG_DEBUG("%s: seq_idx = %s\n", __func__, ss_seq_idx.str().c_str());
655+
LLAMA_LOG_DEBUG("%s: output = %p\n", __func__, (void *) ubatch.output);
656+
LLAMA_LOG_DEBUG("%s: n_outputs = %d\n", __func__, n_outputs);
582657

583658
if (debug > 1) {
584659
int seq_id_max = 0;

src/llama-batch.h

Lines changed: 18 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -15,15 +15,18 @@ struct llama_ubatch {
1515
// TODO: whole_seqs for embeddings?
1616

1717
uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs)
18-
uint32_t n_seq_tokens; // tokens per sequence
19-
uint32_t n_seqs;
20-
21-
llama_token * token; // [n_tokens]
22-
float * embd; // [n_embd, n_tokens]
23-
llama_pos * pos; // [n_tokens]
24-
int32_t * n_seq_id; // [n_tokens]
25-
llama_seq_id ** seq_id; // [n_tokens]
26-
int8_t * output; // [n_tokens]
18+
uint32_t n_seq_tokens; // tokens per sequence set
19+
uint32_t n_seqs; // sequence sets in the ubatch
20+
uint32_t n_seqs_unq; // unique sequence ids in the ubatch
21+
22+
llama_token * token; // [n_tokens]
23+
float * embd; // [n_embd, n_tokens]
24+
llama_pos * pos; // [n_tokens]
25+
int32_t * n_seq_id; // [n_tokens]
26+
llama_seq_id ** seq_id; // [n_tokens]
27+
llama_seq_id * seq_id_unq; // [n_seqs_unq]
28+
int32_t * seq_idx; // [LLAMA_MAX_SEQ]
29+
int8_t * output; // [n_tokens]
2730
};
2831

2932
// a helper for sanitizing, fulfilling and splitting a batch
@@ -61,7 +64,8 @@ class llama_batch_allocr {
6164
// sequence-wise split - each ubatch contains a single sequence
6265
llama_ubatch split_seq(uint32_t n_ubatch);
6366

64-
llama_ubatch ubatch_reserve(uint32_t n_tokens);
67+
llama_ubatch ubatch_reserve(uint32_t n_seq_tokens, uint32_t n_seqs);
68+
6569
private:
6670
void clear();
6771

@@ -82,6 +86,8 @@ class llama_batch_allocr {
8286
std::vector<llama_pos> pos;
8387
std::vector<int32_t> n_seq_id;
8488
std::vector<llama_seq_id *> seq_id;
89+
std::vector<llama_seq_id> seq_id_unq;
90+
std::vector<int32_t> seq_idx;
8591
std::vector<int8_t> output;
8692

8793
using pos_set_t = std::set<llama_pos>;
@@ -108,6 +114,8 @@ class llama_batch_allocr {
108114
std::vector<llama_pos> pos;
109115
std::vector<int32_t> n_seq_id;
110116
std::vector<llama_seq_id *> seq_id;
117+
std::vector<llama_seq_id> seq_id_unq;
118+
std::vector<int32_t> seq_idx;
111119
std::vector<int8_t> output;
112120
};
113121

0 commit comments

Comments
 (0)