@@ -30,6 +30,8 @@ bool llama_batch_allocr::init(
 
     batch = batch_inp;
 
+    this->vocab = &vocab;
+
     GGML_ASSERT(batch.n_tokens > 0);
 
     //
@@ -172,67 +174,39 @@ bool llama_batch_allocr::init(
 
     if (debug > 0) {
         LLAMA_LOG_DEBUG("%s: input batch info:\n", __func__);
-        LLAMA_LOG_DEBUG("%s: n_tokens  = %d\n", __func__, batch.n_tokens);
-        LLAMA_LOG_DEBUG("%s: token     = %p\n", __func__, (void *) batch.token);
-        LLAMA_LOG_DEBUG("%s: embd      = %p\n", __func__, (void *) batch.embd);
-        LLAMA_LOG_DEBUG("%s: pos       = %p\n", __func__, (void *) batch.pos);
-        LLAMA_LOG_DEBUG("%s: n_seq_id  = %p\n", __func__, (void *) batch.n_seq_id);
-        LLAMA_LOG_DEBUG("%s: seq_id    = %p\n", __func__, (void *) batch.seq_id);
-        LLAMA_LOG_DEBUG("%s: logits    = %p\n", __func__, (void *) batch.logits);
-        LLAMA_LOG_DEBUG("%s: n_outputs = %d\n", __func__, n_outputs);
 
-        if (debug > 1) {
-            int seq_id_max = 0;
-            for (int32_t i = 0; i < batch.n_tokens; ++i) {
-                for (int s = 0; s < batch.n_seq_id[i]; ++s) {
-                    for (int s = 0; s < batch.n_seq_id[i]; ++s) {
-                        seq_id_max = std::max(seq_id_max, batch.seq_id[i][s]);
-                    }
-                }
+        llama_ubatch ubatch {
+            /*.equal_seqs   =*/ false,
+            /*.n_tokens     =*/ (uint32_t) batch.n_tokens,
+            /*.n_seq_tokens =*/ (uint32_t) 1,
+            /*.n_seqs       =*/ (uint32_t) batch.n_tokens,
+            /*.token        =*/ batch.token,
+            /*.embd         =*/ batch.embd,
+            /*.pos          =*/ batch.pos,
+            /*.n_seq_id     =*/ batch.n_seq_id,
+            /*.seq_id       =*/ batch.seq_id,
+            /*.output       =*/ batch.logits,
+        };
+
+        ubatch_print(ubatch, debug);
+
+        LLAMA_LOG_DEBUG("%s: seq = [\n", __func__);
+        for (int s0 = 0; s0 < (int) seq_pos.size(); ++s0) {
+            if (seq_pos[s0].empty()) {
+                continue;
             }
-            ++seq_id_max;
 
-            LLAMA_LOG_DEBUG("%s: token = [\n", __func__);
-            for (int32_t i = 0; i < batch.n_tokens; ++i) {
-                std::vector<int8_t> seq_id(seq_id_max);
-
-                for (int s = 0; s < batch.n_seq_id[i]; ++s) {
-                    seq_id[batch.seq_id[i][s]] = 1;
-                }
-
-                std::stringstream ss;
-                for (int s = 0; s < seq_id_max; ++s) {
-                    if (seq_id[s]) {
-                        ss << s%10;
-                    } else {
-                        ss << ".";
-                    }
+            std::stringstream ss;
+            for (int s1 = 0; s1 < (int) seq_cpl[s0].size(); ++s1) {
+                if (seq_cpl[s0][s1]) {
+                    ss << s1 << " ";
                 }
-
-                LLAMA_LOG_DEBUG("%s: %4d: id = %6d (%16s), pos = %4d, n_seq_id = %2d, seq_id = [%s], output = %d\n",
-                        __func__, i, batch.token[i], vocab.token_to_piece(batch.token[i]).c_str(),
-                        batch.pos[i], batch.n_seq_id[i], ss.str().c_str(), batch.logits[i]);
             }
-            LLAMA_LOG_DEBUG("%s: ]\n", __func__);
-
-            LLAMA_LOG_DEBUG("%s: seq = [\n", __func__);
-            for (int s0 = 0; s0 < (int) seq_pos.size(); ++s0) {
-                if (seq_pos[s0].empty()) {
-                    continue;
-                }
 
-                std::stringstream ss;
-                for (int s1 = 0; s1 < (int) seq_cpl[s0].size(); ++s1) {
-                    if (seq_cpl[s0][s1]) {
-                        ss << s1 << " ";
-                    }
-                }
-
-                LLAMA_LOG_DEBUG("%s: %4d: pos = [%4d, %4d], cpl = %s\n",
-                        __func__, s0, seq_pos_min(s0), seq_pos_max(s0), ss.str().empty() ? "-" : ss.str().c_str());
-            }
-            LLAMA_LOG_DEBUG("%s: ]\n", __func__);
+            LLAMA_LOG_DEBUG("%s: %4d: pos = [%4d, %4d], cpl = %s\n",
+                    __func__, s0, seq_pos_min(s0), seq_pos_max(s0), ss.str().empty() ? "-" : ss.str().c_str());
         }
+        LLAMA_LOG_DEBUG("%s: ]\n", __func__);
     }
 
     //
@@ -296,7 +270,7 @@ bool llama_batch_allocr::init(
     return true;
 }
 
-llama_ubatch llama_batch_allocr::reserve_one(uint32_t n_tokens) {
+llama_ubatch llama_batch_allocr::ubatch_reserve(uint32_t n_tokens) {
     clear();
     split_reset();
 
@@ -389,7 +363,7 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
         }
     }
 
-    return add_ubatch(idxs, idxs.size(), false);
+    return ubatch_add(idxs, idxs.size(), false);
 }
 
 llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
@@ -470,7 +444,7 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
         idxs.insert(idxs.end(), idxs_per_seq[s].begin(), idxs_per_seq[s].end());
     }
 
-    return add_ubatch(idxs, n_seqs, true);
+    return ubatch_add(idxs, n_seqs, true);
 }
 
 llama_ubatch llama_batch_allocr::split_seq(uint32_t n_ubatch) {
@@ -507,7 +481,7 @@ llama_ubatch llama_batch_allocr::split_seq(uint32_t n_ubatch) {
         cur_seq_set = seq_set[cur_idx];
     }
 
-    return add_ubatch(idxs, 1, true);
+    return ubatch_add(idxs, 1, true);
 }
 
 void llama_batch_allocr::clear() {
@@ -533,11 +507,9 @@ void llama_batch_allocr::clear() {
     seq_set_map.clear();
 }
 
-llama_ubatch llama_batch_allocr::add_ubatch(const std::vector<int32_t> & idxs, uint32_t n_seqs, bool equal_seqs) {
+llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, uint32_t n_seqs, bool equal_seqs) {
     const uint32_t n_tokens = idxs.size();
 
-    LLAMA_LOG_DEBUG("add_ubatch: n_tokens = %d, n_seqs = %d, equal_seqs = %d", n_tokens, n_seqs, equal_seqs);
-
     assert(n_tokens%n_seqs == 0);
 
     ubatches.emplace_back();
@@ -584,11 +556,67 @@ llama_ubatch llama_batch_allocr::add_ubatch(const std::vector<int32_t> & idxs, u
         /*.output       =*/ ubatch.output.data(),
     };
 
-    LLAMA_LOG_DEBUG("%s: added ubatch of size %d\n", __func__, res.n_tokens);
+    LLAMA_LOG_DEBUG("%s: added ubatch %d in split\n", __func__, (int) ubatches.size() - 1);
+
+    if (debug > 0) {
+        ubatch_print(res, debug);
+    }
 
     return res;
 }
 
+void llama_batch_allocr::ubatch_print(const llama_ubatch & ubatch, int debug) {
+    if (debug > 0) {
+        LLAMA_LOG_DEBUG("%s: equal_seqs   = %d\n", __func__, ubatch.equal_seqs);
+        LLAMA_LOG_DEBUG("%s: n_tokens     = %d\n", __func__, ubatch.n_tokens);
+        LLAMA_LOG_DEBUG("%s: n_seq_tokens = %d\n", __func__, ubatch.n_seq_tokens);
+        LLAMA_LOG_DEBUG("%s: n_seqs       = %d\n", __func__, ubatch.n_seqs);
+
+        LLAMA_LOG_DEBUG("%s: token     = %p\n", __func__, (void *) ubatch.token);
+        LLAMA_LOG_DEBUG("%s: embd      = %p\n", __func__, (void *) ubatch.embd);
+        LLAMA_LOG_DEBUG("%s: pos       = %p\n", __func__, (void *) ubatch.pos);
+        LLAMA_LOG_DEBUG("%s: n_seq_id  = %p\n", __func__, (void *) ubatch.n_seq_id);
+        LLAMA_LOG_DEBUG("%s: seq_id    = %p\n", __func__, (void *) ubatch.seq_id);
+        LLAMA_LOG_DEBUG("%s: output    = %p\n", __func__, (void *) ubatch.output);
+        LLAMA_LOG_DEBUG("%s: n_outputs = %d\n", __func__, n_outputs);
+
+        if (debug > 1) {
+            int seq_id_max = 0;
+            for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
+                for (int s = 0; s < ubatch.n_seq_id[i]; ++s) {
+                    for (int s = 0; s < ubatch.n_seq_id[i]; ++s) {
+                        seq_id_max = std::max(seq_id_max, ubatch.seq_id[i][s]);
+                    }
+                }
+            }
+            ++seq_id_max;
+
+            LLAMA_LOG_DEBUG("%s: token = [\n", __func__);
+            for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
+                std::vector<int8_t> seq_id(seq_id_max);
+
+                for (int s = 0; s < ubatch.n_seq_id[i]; ++s) {
+                    seq_id[ubatch.seq_id[i][s]] = 1;
+                }
+
+                std::stringstream ss;
+                for (int s = 0; s < seq_id_max; ++s) {
+                    if (seq_id[s]) {
+                        ss << s%10;
+                    } else {
+                        ss << ".";
+                    }
+                }
+
+                LLAMA_LOG_DEBUG("%s: %4d: id = %6d (%16s), pos = %4d, n_seq_id = %2d, seq_id = [%s], output = %d\n",
+                        __func__, i, ubatch.token[i], vocab->token_to_piece(ubatch.token[i]).c_str(),
+                        ubatch.pos[i], ubatch.n_seq_id[i], ss.str().c_str(), ubatch.output[i]);
+            }
+            LLAMA_LOG_DEBUG("%s: ]\n", __func__);
+        }
+    }
+}
+
 //
 // interface implementation
 //
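
Note: the following is a minimal, self-contained sketch, not part of the patch and not llama.cpp code. It only illustrates the seq_id formatting technique used by ubatch_print() above: for every token, the sequences it belongs to are marked with their id modulo 10 and absent sequences with '.'. The input container and variable names below are hypothetical.

#include <algorithm>
#include <cstdio>
#include <sstream>
#include <vector>

int main() {
    // hypothetical input: token i belongs to the sequences listed in seq_id[i]
    std::vector<std::vector<int>> seq_id = { {0}, {0, 2}, {1}, {0, 1, 2} };

    // find the largest sequence id, then size the per-token mask accordingly
    int seq_id_max = 0;
    for (const auto & ids : seq_id) {
        for (int s : ids) {
            seq_id_max = std::max(seq_id_max, s);
        }
    }
    ++seq_id_max;

    for (size_t i = 0; i < seq_id.size(); ++i) {
        // mark which sequences this token belongs to
        std::vector<int8_t> mask(seq_id_max, 0);
        for (int s : seq_id[i]) {
            mask[s] = 1;
        }

        // build the same compact visualization as ubatch_print(): digit or '.'
        std::stringstream ss;
        for (int s = 0; s < seq_id_max; ++s) {
            ss << (mask[s] ? char('0' + s % 10) : '.');
        }

        printf("token %zu: seq_id = [%s]\n", i, ss.str().c_str());
    }

    return 0;
}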