
Commit 7eb5ad4

cont : log ubatches

ggml-ci

1 parent 060b5b2 commit 7eb5ad4

4 files changed: +99 −66 lines
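The ubatch logging added here is gated on the allocator's `debug` member; nothing in this commit sets that level. As a minimal sketch, assuming the level comes from a LLAMA_BATCH_DEBUG environment variable as elsewhere in this series (illustrative, not part of the diff):

    #include <cstdlib>

    // Sketch of the assumed debug gate: a value of 1 prints ubatch summaries,
    // a value of 2 (debug > 1) additionally prints the per-token dump.
    static int batch_debug_level() {
        const char * val = std::getenv("LLAMA_BATCH_DEBUG");
        return val ? std::atoi(val) : 0;
    }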

src/llama-batch.cpp (90 additions & 62 deletions)

@@ -30,6 +30,8 @@ bool llama_batch_allocr::init(

     batch = batch_inp;

+    this->vocab = &vocab;
+
     GGML_ASSERT(batch.n_tokens > 0);

     //
@@ -172,67 +174,39 @@ bool llama_batch_allocr::init(

     if (debug > 0) {
         LLAMA_LOG_DEBUG("%s: input batch info:\n", __func__);
-        LLAMA_LOG_DEBUG("%s: n_tokens  = %d\n", __func__, batch.n_tokens);
-        LLAMA_LOG_DEBUG("%s: token     = %p\n", __func__, (void *) batch.token);
-        LLAMA_LOG_DEBUG("%s: embd      = %p\n", __func__, (void *) batch.embd);
-        LLAMA_LOG_DEBUG("%s: pos       = %p\n", __func__, (void *) batch.pos);
-        LLAMA_LOG_DEBUG("%s: n_seq_id  = %p\n", __func__, (void *) batch.n_seq_id);
-        LLAMA_LOG_DEBUG("%s: seq_id    = %p\n", __func__, (void *) batch.seq_id);
-        LLAMA_LOG_DEBUG("%s: logits    = %p\n", __func__, (void *) batch.logits);
-        LLAMA_LOG_DEBUG("%s: n_outputs = %d\n", __func__, n_outputs);

-        if (debug > 1) {
-            int seq_id_max = 0;
-            for (int32_t i = 0; i < batch.n_tokens; ++i) {
-                for (int s = 0; s < batch.n_seq_id[i]; ++s) {
-                    for (int s = 0; s < batch.n_seq_id[i]; ++s) {
-                        seq_id_max = std::max(seq_id_max, batch.seq_id[i][s]);
-                    }
-                }
+        llama_ubatch ubatch {
+            /*.equal_seqs   =*/ false,
+            /*.n_tokens     =*/ (uint32_t) batch.n_tokens,
+            /*.n_seq_tokens =*/ (uint32_t) 1,
+            /*.n_seqs       =*/ (uint32_t) batch.n_tokens,
+            /*.token        =*/ batch.token,
+            /*.embd         =*/ batch.embd,
+            /*.pos          =*/ batch.pos,
+            /*.n_seq_id     =*/ batch.n_seq_id,
+            /*.seq_id       =*/ batch.seq_id,
+            /*.output       =*/ batch.logits,
+        };
+
+        ubatch_print(ubatch, debug);
+
+        LLAMA_LOG_DEBUG("%s: seq = [\n", __func__);
+        for (int s0 = 0; s0 < (int) seq_pos.size(); ++s0) {
+            if (seq_pos[s0].empty()) {
+                continue;
             }
-            ++seq_id_max;

-            LLAMA_LOG_DEBUG("%s: token = [\n", __func__);
-            for (int32_t i = 0; i < batch.n_tokens; ++i) {
-                std::vector<int8_t> seq_id(seq_id_max);
-
-                for (int s = 0; s < batch.n_seq_id[i]; ++s) {
-                    seq_id[batch.seq_id[i][s]] = 1;
-                }
-
-                std::stringstream ss;
-                for (int s = 0; s < seq_id_max; ++s) {
-                    if (seq_id[s]) {
-                        ss << s%10;
-                    } else {
-                        ss << ".";
-                    }
+            std::stringstream ss;
+            for (int s1 = 0; s1 < (int) seq_cpl[s0].size(); ++s1) {
+                if (seq_cpl[s0][s1]) {
+                    ss << s1 << " ";
                 }
-
-                LLAMA_LOG_DEBUG("%s: %4d: id = %6d (%16s), pos = %4d, n_seq_id = %2d, seq_id = [%s], output = %d\n",
-                        __func__, i, batch.token[i], vocab.token_to_piece(batch.token[i]).c_str(),
-                        batch.pos[i], batch.n_seq_id[i], ss.str().c_str(), batch.logits[i]);
             }
-            LLAMA_LOG_DEBUG("%s: ]\n", __func__);
-
-            LLAMA_LOG_DEBUG("%s: seq = [\n", __func__);
-            for (int s0 = 0; s0 < (int) seq_pos.size(); ++s0) {
-                if (seq_pos[s0].empty()) {
-                    continue;
-                }

-                std::stringstream ss;
-                for (int s1 = 0; s1 < (int) seq_cpl[s0].size(); ++s1) {
-                    if (seq_cpl[s0][s1]) {
-                        ss << s1 << " ";
-                    }
-                }
-
-                LLAMA_LOG_DEBUG("%s: %4d: pos = [%4d, %4d], cpl = %s\n",
-                        __func__, s0, seq_pos_min(s0), seq_pos_max(s0), ss.str().empty() ? "-" : ss.str().c_str());
-            }
-            LLAMA_LOG_DEBUG("%s: ]\n", __func__);
+            LLAMA_LOG_DEBUG("%s: %4d: pos = [%4d, %4d], cpl = %s\n",
+                    __func__, s0, seq_pos_min(s0), seq_pos_max(s0), ss.str().empty() ? "-" : ss.str().c_str());
         }
+        LLAMA_LOG_DEBUG("%s: ]\n", __func__);
     }

     //
@@ -296,7 +270,7 @@ bool llama_batch_allocr::init(
     return true;
 }

-llama_ubatch llama_batch_allocr::reserve_one(uint32_t n_tokens) {
+llama_ubatch llama_batch_allocr::ubatch_reserve(uint32_t n_tokens) {
     clear();
     split_reset();

@@ -389,7 +363,7 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
         }
     }

-    return add_ubatch(idxs, idxs.size(), false);
+    return ubatch_add(idxs, idxs.size(), false);
 }

 llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
@@ -470,7 +444,7 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
         idxs.insert(idxs.end(), idxs_per_seq[s].begin(), idxs_per_seq[s].end());
     }

-    return add_ubatch(idxs, n_seqs, true);
+    return ubatch_add(idxs, n_seqs, true);
 }

 llama_ubatch llama_batch_allocr::split_seq(uint32_t n_ubatch) {
@@ -507,7 +481,7 @@ llama_ubatch llama_batch_allocr::split_seq(uint32_t n_ubatch) {
         cur_seq_set = seq_set[cur_idx];
     }

-    return add_ubatch(idxs, 1, true);
+    return ubatch_add(idxs, 1, true);
 }

 void llama_batch_allocr::clear() {
@@ -533,11 +507,9 @@ void llama_batch_allocr::clear() {
     seq_set_map.clear();
 }

-llama_ubatch llama_batch_allocr::add_ubatch(const std::vector<int32_t> & idxs, uint32_t n_seqs, bool equal_seqs) {
+llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, uint32_t n_seqs, bool equal_seqs) {
     const uint32_t n_tokens = idxs.size();

-    LLAMA_LOG_DEBUG("add_ubatch: n_tokens = %d, n_seqs = %d, equal_seqs = %d", n_tokens, n_seqs, equal_seqs);
-
     assert(n_tokens%n_seqs == 0);

     ubatches.emplace_back();
@@ -584,11 +556,67 @@ llama_ubatch llama_batch_allocr::add_ubatch(const std::vector<int32_t> & idxs, u
         /*.output       =*/ ubatch.output.data(),
     };

-    LLAMA_LOG_DEBUG("%s: added ubatch of size %d\n", __func__, res.n_tokens);
+    LLAMA_LOG_DEBUG("%s: added ubatch %d in split\n", __func__, (int) ubatches.size() - 1);
+
+    if (debug > 0) {
+        ubatch_print(res, debug);
+    }

     return res;
 }

+void llama_batch_allocr::ubatch_print(const llama_ubatch & ubatch, int debug) {
+    if (debug > 0) {
+        LLAMA_LOG_DEBUG("%s: equal_seqs   = %d\n", __func__, ubatch.equal_seqs);
+        LLAMA_LOG_DEBUG("%s: n_tokens     = %d\n", __func__, ubatch.n_tokens);
+        LLAMA_LOG_DEBUG("%s: n_seq_tokens = %d\n", __func__, ubatch.n_seq_tokens);
+        LLAMA_LOG_DEBUG("%s: n_seqs       = %d\n", __func__, ubatch.n_seqs);
+
+        LLAMA_LOG_DEBUG("%s: token     = %p\n", __func__, (void *) ubatch.token);
+        LLAMA_LOG_DEBUG("%s: embd      = %p\n", __func__, (void *) ubatch.embd);
+        LLAMA_LOG_DEBUG("%s: pos       = %p\n", __func__, (void *) ubatch.pos);
+        LLAMA_LOG_DEBUG("%s: n_seq_id  = %p\n", __func__, (void *) ubatch.n_seq_id);
+        LLAMA_LOG_DEBUG("%s: seq_id    = %p\n", __func__, (void *) ubatch.seq_id);
+        LLAMA_LOG_DEBUG("%s: output    = %p\n", __func__, (void *) ubatch.output);
+        LLAMA_LOG_DEBUG("%s: n_outputs = %d\n", __func__, n_outputs);
+
+        if (debug > 1) {
+            int seq_id_max = 0;
+            for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
+                for (int s = 0; s < ubatch.n_seq_id[i]; ++s) {
+                    for (int s = 0; s < ubatch.n_seq_id[i]; ++s) {
+                        seq_id_max = std::max(seq_id_max, ubatch.seq_id[i][s]);
+                    }
+                }
+            }
+            ++seq_id_max;
+
+            LLAMA_LOG_DEBUG("%s: token = [\n", __func__);
+            for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
+                std::vector<int8_t> seq_id(seq_id_max);
+
+                for (int s = 0; s < ubatch.n_seq_id[i]; ++s) {
+                    seq_id[ubatch.seq_id[i][s]] = 1;
+                }
+
+                std::stringstream ss;
+                for (int s = 0; s < seq_id_max; ++s) {
+                    if (seq_id[s]) {
+                        ss << s%10;
+                    } else {
+                        ss << ".";
+                    }
+                }
+
+                LLAMA_LOG_DEBUG("%s: %4d: id = %6d (%16s), pos = %4d, n_seq_id = %2d, seq_id = [%s], output = %d\n",
+                        __func__, i, ubatch.token[i], vocab->token_to_piece(ubatch.token[i]).c_str(),
+                        ubatch.pos[i], ubatch.n_seq_id[i], ss.str().c_str(), ubatch.output[i]);
+            }
+            LLAMA_LOG_DEBUG("%s: ]\n", __func__);
+        }
+    }
+}
+
 //
 // interface implementation
 //
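
The per-token dump that moved into ubatch_print renders each token's sequence membership as a fixed-width string: column s prints s%10 when the token belongs to sequence s, and '.' otherwise. A self-contained sketch of just that rendering step, with hypothetical membership data (the two-pass structure mirrors the function above):

    #include <algorithm>
    #include <cstdio>
    #include <sstream>
    #include <vector>

    int main() {
        // hypothetical 3-token ubatch: token 0 in seq 0, token 1 in seqs 0 and 2, token 2 in seq 2
        const std::vector<std::vector<int>> seq_ids = { {0}, {0, 2}, {2} };

        // pass 1: find the highest sequence id to size the bitmap
        int seq_id_max = 0;
        for (const auto & ids : seq_ids) {
            for (const int s : ids) {
                seq_id_max = std::max(seq_id_max, s);
            }
        }
        ++seq_id_max; // number of columns in the rendered string

        // pass 2: one membership string per token
        for (size_t i = 0; i < seq_ids.size(); ++i) {
            std::vector<int8_t> mask(seq_id_max, 0);
            for (const int s : seq_ids[i]) {
                mask[s] = 1;
            }

            std::stringstream ss;
            for (int s = 0; s < seq_id_max; ++s) {
                if (mask[s]) {
                    ss << s % 10; // digit: token belongs to sequence s
                } else {
                    ss << ".";    // dot: it does not
                }
            }

            std::printf("token %zu: seq_id = [%s]\n", i, ss.str().c_str());
        }

        return 0;
    }

This prints [0..], [0.2], and [..2], matching the seq_id column in the token dump.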

src/llama-batch.h (7 additions & 2 deletions)

@@ -61,14 +61,19 @@ class llama_batch_allocr {
     // sequence-wise split - each ubatch contains a single sequence
     llama_ubatch split_seq(uint32_t n_ubatch);

-    llama_ubatch reserve_one(uint32_t n_tokens);
+    llama_ubatch ubatch_reserve(uint32_t n_tokens);
 private:
     void clear();

-    llama_ubatch add_ubatch(const std::vector<int32_t> & idxs, uint32_t n_seqs, bool equal_seqs);
+    llama_ubatch ubatch_add(const std::vector<int32_t> & idxs, uint32_t n_seqs, bool equal_seqs);
+
+    void ubatch_print(const llama_ubatch & ubatch, int debug);

     llama_batch batch;

+    // only for debugging purposes
+    const llama_vocab * vocab;
+
     uint32_t n_embd;
     uint32_t n_outputs;
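
The renames group the allocator's micro-batch helpers under a common ubatch_ prefix: ubatch_reserve, ubatch_add, and the new ubatch_print. A hedged sketch of the caller flow these names imply; only the method names come from this diff, while the init call and the empty-ubatch termination check are assumptions:

    // hypothetical driver loop: split a prepared llama_batch into ubatches;
    // each split_* call ends in ubatch_add(), which now logs the new ubatch
    // and, when the debug level is set, pretty-prints it via ubatch_print()
    llama_batch_allocr balloc;
    // ... balloc.init(...) with the input llama_batch ...

    while (true) {
        llama_ubatch ubatch = balloc.split_simple(n_ubatch);
        if (ubatch.n_tokens == 0) {
            break; // assumed sentinel: nothing left to split
        }
        // ... process ubatch ...
    }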

src/llama-kv-cache-recurrent.cpp (1 addition & 1 deletion)

@@ -826,7 +826,7 @@ bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t ce

     llama_batch_allocr batch_allocr;

-    llama_ubatch ubatch = batch_allocr.reserve_one(cell_count);
+    llama_ubatch ubatch = batch_allocr.ubatch_reserve(cell_count);

     ubatch.n_tokens = cell_count;
     ubatch.n_seq_tokens = cell_count;

src/llama-kv-cache-unified.cpp (1 addition & 1 deletion)

@@ -1507,7 +1507,7 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell

     llama_batch_allocr batch_allocr;

-    llama_ubatch ubatch = batch_allocr.reserve_one(cell_count);
+    llama_ubatch ubatch = batch_allocr.ubatch_reserve(cell_count);

     ubatch.n_tokens = cell_count;
     ubatch.n_seq_tokens = cell_count;
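
Both cache types follow the same state-restore pattern around the renamed call: reserve a scratch ubatch sized to the number of cells being read, then overwrite its token counts. Assembled from the context lines of the two hunks above:

    llama_batch_allocr batch_allocr;

    // reserve a placeholder ubatch for the cells about to be restored
    llama_ubatch ubatch = batch_allocr.ubatch_reserve(cell_count);

    ubatch.n_tokens = cell_count;
    ubatch.n_seq_tokens = cell_count;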
