@@ -18,6 +18,8 @@ llama_batch_allocr::llama_batch_allocr() {
18
18
for (auto & cur : seq_cpl) {
19
19
cur.resize (LLAMA_MAX_SEQ);
20
20
}
21
+
22
+ seq_idx.resize (LLAMA_MAX_SEQ, -1 );
21
23
}
22
24
23
25
bool llama_batch_allocr::init (
@@ -137,22 +139,23 @@ bool llama_batch_allocr::init(
137
139
// compute stats
138
140
//
139
141
142
+ this ->n_embd = n_embd;
143
+
140
144
for (int32_t i = 0 ; i < batch.n_tokens ; ++i) {
141
145
n_outputs += batch.logits [i] != 0 ;
142
146
}
143
147
144
- this ->n_embd = n_embd;
145
-
146
148
// determine coupled sequences
147
149
// these are pairs of sequences that have at least one token in the input batch that is assigned to both of them
148
150
for (int32_t i = 0 ; i < batch.n_tokens ; ++i) {
151
+ const llama_seq_id s0 = batch.seq_id [i][0 ];
152
+
149
153
for (int32_t s = 0 ; s < batch.n_seq_id [i]; ++s) {
150
- seq_pos[ batch.seq_id [i][s]]. insert (batch. pos [i]) ;
154
+ const llama_seq_id s1 = batch.seq_id [i][s];
151
155
152
- if (s > 0 ) {
153
- const llama_seq_id s0 = batch.seq_id [i][0 ];
154
- const llama_seq_id s1 = batch.seq_id [i][s];
156
+ seq_pos[s1].insert (batch.pos [i]);
155
157
158
+ if (s > 0 ) {
156
159
// mark that sequence s1 is coupled to s0
157
160
seq_cpl[s1][s0] = true ;
158
161
@@ -162,14 +165,28 @@ bool llama_batch_allocr::init(
162
165
}
163
166
}
164
167
165
- for (int32_t i = 0 ; i < batch.n_tokens ; ++i) {
166
- seq_set_t cur;
167
- for (int32_t s = 0 ; s < batch.n_seq_id [i]; ++s) {
168
- cur.set (batch.seq_id [i][s]);
168
+ {
169
+ seq_set_t seq_set_unq;
170
+
171
+ for (int32_t i = 0 ; i < batch.n_tokens ; ++i) {
172
+ seq_set_t cur;
173
+ for (int32_t s = 0 ; s < batch.n_seq_id [i]; ++s) {
174
+ const llama_seq_id s0 = batch.seq_id [i][s];
175
+
176
+ cur.set (s0);
177
+ seq_set_unq.set (s0);
178
+ }
179
+
180
+ seq_set.push_back (cur);
181
+ seq_set_map[cur].push_back (i);
169
182
}
170
183
171
- seq_set.push_back (cur);
172
- seq_set_map[cur].push_back (i);
184
+ for (int32_t s = 0 ; s < LLAMA_MAX_SEQ; ++s) {
185
+ if (seq_set_unq.test (s)) {
186
+ seq_idx[s] = seq_id_unq.size ();
187
+ seq_id_unq.push_back (s);
188
+ }
189
+ }
173
190
}
174
191
175
192
if (debug > 0 ) {
@@ -180,11 +197,14 @@ bool llama_batch_allocr::init(
180
197
/* .n_tokens =*/ (uint32_t ) batch.n_tokens ,
181
198
/* .n_seq_tokens =*/ (uint32_t ) 1 ,
182
199
/* .n_seqs =*/ (uint32_t ) batch.n_tokens ,
200
+ /* .n_seqs_unq =*/ (uint32_t ) this ->seq_id_unq .size (),
183
201
/* .token =*/ batch.token ,
184
202
/* .embd =*/ batch.embd ,
185
203
/* .pos =*/ batch.pos ,
186
204
/* .n_seq_id =*/ batch.n_seq_id ,
187
205
/* .seq_id =*/ batch.seq_id ,
206
+ /* .seq_id_unq =*/ this ->seq_id_unq .data (),
207
+ /* .seq_idx =*/ this ->seq_idx .data (),
188
208
/* .output =*/ batch.logits ,
189
209
};
190
210
@@ -270,32 +290,44 @@ bool llama_batch_allocr::init(
270
290
return true ;
271
291
}
272
292
273
- llama_ubatch llama_batch_allocr::ubatch_reserve (uint32_t n_tokens) {
293
+ llama_ubatch llama_batch_allocr::ubatch_reserve (uint32_t n_seq_tokens, uint32_t n_seqs) {
294
+ const uint32_t n_tokens = n_seq_tokens*n_seqs;
295
+
274
296
clear ();
275
297
split_reset ();
276
298
277
299
ubatches.emplace_back ();
278
300
279
301
auto & ubatch = ubatches.back ();
280
302
281
- ubatch.token .resize (n_tokens);
282
- ubatch.embd .clear ();
283
- ubatch.pos .resize (n_tokens);
284
- ubatch.n_seq_id .resize (n_tokens);
285
- ubatch.seq_id .resize (n_tokens);
286
- ubatch.output .resize (n_tokens);
303
+ ubatch.token .resize (n_tokens);
304
+ ubatch.embd .clear ();
305
+ ubatch.pos .resize (n_tokens);
306
+ ubatch.n_seq_id .resize (n_tokens);
307
+ ubatch.seq_id .resize (n_tokens);
308
+ ubatch.seq_id_unq .resize (0 );
309
+ ubatch.seq_idx .resize (LLAMA_MAX_SEQ, -1 );
310
+ ubatch.output .resize (n_tokens);
311
+
312
+ for (uint32_t s = 0 ; s < n_seqs; ++s) {
313
+ ubatch.seq_idx [s] = s;
314
+ ubatch.seq_id_unq .push_back (s);
315
+ }
287
316
288
317
llama_ubatch res {
289
318
/* .equal_seqs =*/ true ,
290
319
/* .n_tokens =*/ n_tokens,
291
- /* .n_seq_tokens =*/ n_tokens,
292
- /* .n_seqs =*/ 1 ,
320
+ /* .n_seq_tokens =*/ n_seq_tokens,
321
+ /* .n_seqs =*/ n_seqs,
322
+ /* .n_seqs_unq =*/ n_seqs,
293
323
294
324
/* .token =*/ ubatch.token .data (),
295
325
/* .embd =*/ nullptr ,
296
326
/* .pos =*/ ubatch.pos .data (),
297
327
/* .n_seq_id =*/ ubatch.n_seq_id .data (),
298
328
/* .seq_id =*/ ubatch.seq_id .data (),
329
+ /* .seq_id_unq =*/ ubatch.seq_id_unq .data (),
330
+ /* .seq_idx =*/ ubatch.seq_idx .data (),
299
331
/* .output =*/ ubatch.output .data (),
300
332
};
301
333
@@ -489,10 +521,11 @@ void llama_batch_allocr::clear() {
489
521
490
522
batch = {};
491
523
492
- pos .clear ();
493
- n_seq_id.clear ();
494
- seq_id .clear ();
495
- output .clear ();
524
+ pos .clear ();
525
+ n_seq_id .clear ();
526
+ seq_id .clear ();
527
+ seq_id_unq.clear ();
528
+ output .clear ();
496
529
497
530
for (auto & cur : seq_pos) {
498
531
cur.clear ();
@@ -516,12 +549,16 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
516
549
517
550
auto & ubatch = ubatches.back ();
518
551
519
- ubatch.token .resize (n_tokens);
520
- ubatch.embd .resize ((int64_t ) n_tokens*n_embd);
521
- ubatch.pos .resize (n_tokens);
522
- ubatch.n_seq_id .resize (n_tokens);
523
- ubatch.seq_id .resize (n_tokens);
524
- ubatch.output .resize (n_tokens);
552
+ ubatch.token .resize (n_tokens);
553
+ ubatch.embd .resize ((int64_t ) n_tokens*n_embd);
554
+ ubatch.pos .resize (n_tokens);
555
+ ubatch.n_seq_id .resize (n_tokens);
556
+ ubatch.seq_id .resize (n_tokens);
557
+ ubatch.seq_id_unq .resize (0 );
558
+ ubatch.seq_idx .resize (LLAMA_MAX_SEQ, -1 );
559
+ ubatch.output .resize (n_tokens);
560
+
561
+ seq_set_t seq_set_unq;
525
562
526
563
for (size_t i = 0 ; i < idxs.size (); ++i) {
527
564
if (batch.token ) {
@@ -537,22 +574,36 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
537
574
ubatch.seq_id [i] = batch.seq_id [idxs[i]];
538
575
ubatch.output [i] = batch.logits [idxs[i]];
539
576
577
+ for (int s = 0 ; s < ubatch.n_seq_id [i]; ++s) {
578
+ seq_set_unq.set (ubatch.seq_id [i][s]);
579
+ }
580
+
540
581
if (ubatch.output [i]) {
541
582
out_ids.push_back (idxs[i]);
542
583
}
543
584
}
544
585
586
+ for (int32_t s = 0 ; s < LLAMA_MAX_SEQ; ++s) {
587
+ if (seq_set_unq.test (s)) {
588
+ ubatch.seq_idx [s] = ubatch.seq_id_unq .size ();
589
+ ubatch.seq_id_unq .push_back (s);
590
+ }
591
+ }
592
+
545
593
llama_ubatch res {
546
594
/* .equal_seqs =*/ equal_seqs,
547
595
/* .n_tokens =*/ n_tokens,
548
596
/* .n_seq_tokens =*/ n_tokens/n_seqs,
549
597
/* .n_seqs =*/ n_seqs,
598
+ /* .n_seqs_unq =*/ (uint32_t ) ubatch.seq_id_unq .size (),
550
599
551
600
/* .token =*/ batch.token ? ubatch.token .data () : nullptr ,
552
601
/* .embd =*/ batch.embd ? ubatch.embd .data () : nullptr ,
553
602
/* .pos =*/ ubatch.pos .data (),
554
603
/* .n_seq_id =*/ ubatch.n_seq_id .data (),
555
604
/* .seq_id =*/ ubatch.seq_id .data (),
605
+ /* .seq_id_unq =*/ ubatch.seq_id_unq .data (),
606
+ /* .seq_idx =*/ ubatch.seq_idx .data (),
556
607
/* .output =*/ ubatch.output .data (),
557
608
};
558
609
@@ -571,14 +622,38 @@ void llama_batch_allocr::ubatch_print(const llama_ubatch & ubatch, int debug) {
571
622
LLAMA_LOG_DEBUG (" %s: n_tokens = %d\n " , __func__, ubatch.n_tokens );
572
623
LLAMA_LOG_DEBUG (" %s: n_seq_tokens = %d\n " , __func__, ubatch.n_seq_tokens );
573
624
LLAMA_LOG_DEBUG (" %s: n_seqs = %d\n " , __func__, ubatch.n_seqs );
625
+ LLAMA_LOG_DEBUG (" %s: n_seqs_unq = %d\n " , __func__, ubatch.n_seqs_unq );
626
+
627
+ std::stringstream ss_seq_id_unq;
628
+ std::stringstream ss_seq_idx;
629
+
630
+ ss_seq_id_unq << " [ " ;
631
+ ss_seq_idx << " [" ;
632
+
633
+ for (uint32_t s = 0 ; s < ubatch.n_seqs_unq ; ++s) {
634
+ ss_seq_id_unq << ubatch.seq_id_unq [s] << " " ;
635
+ }
636
+
637
+ for (uint32_t s = 0 ; s < LLAMA_MAX_SEQ; ++s) {
638
+ if (ubatch.seq_idx [s] >= 0 ) {
639
+ ss_seq_idx << ubatch.seq_idx [s]%10 ;
640
+ } else {
641
+ ss_seq_idx << " ." ;
642
+ }
643
+ }
644
+
645
+ ss_seq_id_unq << " ]" ;
646
+ ss_seq_idx << " ]" ;
574
647
575
- LLAMA_LOG_DEBUG (" %s: token = %p\n " , __func__, (void *) ubatch.token );
576
- LLAMA_LOG_DEBUG (" %s: embd = %p\n " , __func__, (void *) ubatch.embd );
577
- LLAMA_LOG_DEBUG (" %s: pos = %p\n " , __func__, (void *) ubatch.pos );
578
- LLAMA_LOG_DEBUG (" %s: n_seq_id = %p\n " , __func__, (void *) ubatch.n_seq_id );
579
- LLAMA_LOG_DEBUG (" %s: seq_id = %p\n " , __func__, (void *) ubatch.seq_id );
580
- LLAMA_LOG_DEBUG (" %s: output = %p\n " , __func__, (void *) ubatch.output );
581
- LLAMA_LOG_DEBUG (" %s: n_outputs = %d\n " , __func__, n_outputs);
648
+ LLAMA_LOG_DEBUG (" %s: token = %p\n " , __func__, (void *) ubatch.token );
649
+ LLAMA_LOG_DEBUG (" %s: embd = %p\n " , __func__, (void *) ubatch.embd );
650
+ LLAMA_LOG_DEBUG (" %s: pos = %p\n " , __func__, (void *) ubatch.pos );
651
+ LLAMA_LOG_DEBUG (" %s: n_seq_id = %p\n " , __func__, (void *) ubatch.n_seq_id );
652
+ LLAMA_LOG_DEBUG (" %s: seq_id = %p\n " , __func__, (void *) ubatch.seq_id );
653
+ LLAMA_LOG_DEBUG (" %s: seq_id_unq = %s\n " , __func__, ss_seq_id_unq.str ().c_str ());
654
+ LLAMA_LOG_DEBUG (" %s: seq_idx = %s\n " , __func__, ss_seq_idx.str ().c_str ());
655
+ LLAMA_LOG_DEBUG (" %s: output = %p\n " , __func__, (void *) ubatch.output );
656
+ LLAMA_LOG_DEBUG (" %s: n_outputs = %d\n " , __func__, n_outputs);
582
657
583
658
if (debug > 1 ) {
584
659
int seq_id_max = 0 ;
0 commit comments