
Commit deb5328

ubatch : new splitting logic
ggml-ci
1 parent edc4a29 commit deb5328

19 files changed (+991 / -914 lines)

src/llama-batch.cpp

Lines changed: 557 additions & 359 deletions
Large diffs are not rendered by default.
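
The new splitting logic itself lives in the llama-batch.cpp diff above, which is not rendered here. As a rough sketch of how the reworked allocator is driven, inferred only from the llama-batch.h API shown below (the process_batch wrapper and its setup are hypothetical, not code from this commit):

// hypothetical caller-side loop (not part of this commit's diff), assuming the
// in-tree private header shown below is available
#include "llama-batch.h"

#include <cstdint>

static void process_batch(llama_batch_allocr & balloc, uint32_t n_ubatch) {
    // reset the internal splitting state before producing ubatches
    balloc.split_reset();

    while (true) {
        // split_equal() or split_seq() could be used here instead, depending on
        // what the memory module requires
        llama_ubatch ubatch = balloc.split_simple(n_ubatch);
        if (ubatch.n_tokens == 0) {
            break; // the entire batch has been consumed
        }

        // ... process `ubatch` (build and evaluate the compute graph) ...
    }
}

Each returned llama_ubatch only points into buffers owned by the allocator (the ubatches vector declared in the header below), so the caller does not free anything between splits.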

src/llama-batch.h

Lines changed: 98 additions & 70 deletions
@@ -2,118 +2,146 @@
 
 #include "llama.h"
 
+#include "llama-cparams.h"
+
 #include <array>
 #include <vector>
 #include <set>
+#include <bitset>
+#include <unordered_map>
 
-// very similar to llama_batch,
-// but has more metadata about sequences
+// keep this struct lightweight
+// it points to data in `llama_batch_allocr`
 struct llama_ubatch {
     bool equal_seqs;
     // TODO: whole_seqs for embeddings?
 
     uint32_t n_tokens;     // total tokens (n_seq_tokens * n_seqs)
-    uint32_t n_seq_tokens; // tokens per sequence
-    uint32_t n_seqs;
-
-    llama_token  *  token;    // [n_tokens]
-    float        *  embd;     // [n_embd, n_tokens]
-    llama_pos    *  pos;      // [n_tokens]
-    int32_t      *  n_seq_id; // [n_seqs]
-    llama_seq_id ** seq_id;   // [n_seqs]
-    int8_t       *  output;   // [n_tokens]
-};
-
-struct llama_sbatch_seq {
-    int32_t n_seq_id;
-
-    llama_seq_id * seq_id;
-
-    size_t offset;
-    size_t length;
-};
-
-// sequence-length-aware batch splitting
-struct llama_sbatch {
-    // tokens left in this batch
-    size_t n_tokens;
-
-    size_t n_embd;
-
-    // sorted indices into the batch
-    std::vector<int64_t> ids;
-    // batch indices of the output
-    std::vector<int64_t> out_ids;
-    std::vector<llama_sbatch_seq> seq;
-
-    const llama_batch * batch = nullptr;
-
-    // buffers for the ubatches
-    // TODO: very hacky, this needs a complete rework
-    struct ubatch_data {
-        std::vector<llama_token>    token;
-        std::vector<float>          embd;
-        std::vector<llama_pos>      pos;
-        std::vector<int32_t>        n_seq_id;
-        std::vector<llama_seq_id *> seq_id;
-        std::vector<int8_t>         output;
-    };
-
-    std::vector<ubatch_data> udatas;
-
-    llama_ubatch reserve_ubatch(size_t n_ubatch, bool has_embd = false);
-
-    void add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length);
-
-    // simple split, unknown number of sequences of unequal lengths
-    llama_ubatch split_simple(size_t n_ubatch);
-
-    // make batches of equal-length sequences
-    llama_ubatch split_equal(size_t n_ubatch);
-
-    // sequence-wise split
-    llama_ubatch split_seq(size_t n_ubatch);
-
-    llama_sbatch() = default;
-    llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false);
+    uint32_t n_seq_tokens; // tokens per sequence set
+    uint32_t n_seqs;       // sequence sets in the ubatch
+    uint32_t n_seqs_unq;   // unique sequence ids in the ubatch
+
+    // seq_id_unq: unique sequence ids in the ubatch
+    // seq_idx:    indices of the unique sequence ids in the ubatch in [0, n_seqs_unq)
+    //             used for extracting sequence pooled embeddings
+
+    //                          // size               | idx | val
+    llama_token  *  token;      // [n_tokens]         | i   | id, token
+    float        *  embd;       // [n_embd, n_tokens] | i   | embd
+    llama_pos    *  pos;        // [n_tokens]         | i   | pos
+    int32_t      *  n_seq_id;   // [n_tokens]         | i   | -
+    llama_seq_id ** seq_id;     // [n_tokens]         | s   | s0, s1, seq_id
+    llama_seq_id *  seq_id_unq; // [n_seqs_unq]       | s   | seq_id
+    int32_t      *  seq_idx;    // [LLAMA_MAX_SEQ]    | -   | seq_idx
+    int8_t       *  output;     // [n_tokens]         | i   | -
 };
 
-// a helper for sanitizing and fulfilling a batch
+// a helper for sanitizing, fulfilling and splitting a batch
 class llama_batch_allocr {
 public:
-    llama_batch_allocr();
+    llama_batch_allocr(uint32_t n_pos_per_embd);
 
     // sanitize and auto-gen missing data in the input batch
     // memory is optional. if provided will be used to check for sequence continuity and to determine the positions
     bool init(
             const llama_batch & batch_inp,
             const llama_vocab & vocab,
             const llama_memory_i * memory,
-            bool embd_all);
+            uint32_t n_embd,
+            bool output_all);
 
     const llama_batch & get_batch() const;
 
+    uint32_t get_n_tokens()  const;
     uint32_t get_n_outputs() const;
 
+    // the array of output indices in the order they were encountered during the ubatch splitting
+    std::vector<int32_t> & get_out_ids();
+
+    // min/max positions of each sequence in the current ubatch
     llama_pos seq_pos_min(llama_seq_id seq_id) const;
     llama_pos seq_pos_max(llama_seq_id seq_id) const;
 
+    // call once before splitting the batch to reset the internal state
+    void split_reset();
+
+    // simple split, unknown number of sequence sets of unequal lengths
+    llama_ubatch split_simple(uint32_t n_ubatch);
+
+    // make batches of equal-length sequences sets
+    llama_ubatch split_equal(uint32_t n_ubatch);
+
+    // sequence-set-wise split - each ubatch contains a single sequence-set
+    llama_ubatch split_seq(uint32_t n_ubatch);
+
+    // a helper method for creating a well-defined ubatch of tokens
+    // TODO: support embeddings if needed in the future
+    llama_ubatch ubatch_reserve(uint32_t n_seq_tokens, uint32_t n_seqs);
+
 private:
     void clear();
 
+    // create the next ubatch based on the provided batch indices (idxs) and the number of sequence sets (n_seqs)
+    // return llama_ubatch.n_tokens == 0 if the entire batch was consumed
+    llama_ubatch ubatch_add(const std::vector<int32_t> & idxs, uint32_t n_seqs, bool equal_seqs);
+
+    // for debugging, start with LLAMA_BATCH_DEBUG=2
+    void ubatch_print(const llama_ubatch & ubatch, int debug);
+
     llama_batch batch;
 
+    // only for debugging purposes
+    const llama_vocab * vocab;
+
+    // TODO: this is more of a temporary solution until we have a better way to handle multiple positions per token/embd
+    //       ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
+    const uint32_t n_pos_per_embd;
+
+    uint32_t n_embd;
     uint32_t n_outputs;
 
     std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id
 
     std::vector<llama_pos>      pos;
     std::vector<int32_t>        n_seq_id;
     std::vector<llama_seq_id *> seq_id;
+    std::vector<llama_seq_id>   seq_id_unq;
+    std::vector<int32_t>        seq_idx;
     std::vector<int8_t>         output;
 
-    std::vector<std::set<llama_pos>> seq_pos; // seq_pos[s]: the set of positions in sequence s
-    std::vector<std::vector<bool>>   seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
+    using pos_set_t = std::set<llama_pos>;
+    using seq_cpl_t = std::vector<bool>;
+
+    std::vector<pos_set_t> seq_pos; // seq_pos[s]: the set of positions in sequence s
+    std::vector<seq_cpl_t> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
+
+    using idx_vec_t = std::vector<int32_t>;
+    using seq_set_t = std::bitset<LLAMA_MAX_SEQ>;
+
+    std::vector<seq_set_t> seq_set; // seq_set[i]: the sequence set of token i
+
+    std::unordered_map<seq_set_t, idx_vec_t> seq_set_map; // the indices at which the sequence set appears
+
+    // batch indices of the output
+    std::vector<int32_t> out_ids;
+
+    // used[i] indicates if token i has already been used in a previous ubatch
+    std::vector<bool> used;
+
+    // llama_ubatch points to this data:
+    struct ubatch {
+        std::vector<llama_token>    token;
+        std::vector<float>          embd;
+        std::vector<llama_pos>      pos;
+        std::vector<int32_t>        n_seq_id;
+        std::vector<llama_seq_id *> seq_id;
+        std::vector<llama_seq_id>   seq_id_unq;
+        std::vector<int32_t>        seq_idx;
+        std::vector<int8_t>         output;
+    };
+
+    // current splitting state:
+    std::vector<ubatch> ubatches;
 
     int debug;
 };
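
The bucketing that the equal and per-sequence splits build on can be read directly off the new seq_set / seq_set_map members: each token's sequence ids are packed into a LLAMA_MAX_SEQ-wide bitset, and tokens with identical bitsets land in the same bucket. A minimal standalone sketch of that grouping step (illustrative only; collect_seq_sets and the MAX_SEQ constant are stand-ins, not names from the commit):

#include <bitset>
#include <cstdint>
#include <unordered_map>
#include <vector>

constexpr size_t MAX_SEQ = 64; // stand-in for LLAMA_MAX_SEQ

using seq_set_t = std::bitset<MAX_SEQ>;
using idx_vec_t = std::vector<int32_t>;

// bucket token indices by their sequence set, mirroring seq_set_map in the header:
// tokens that belong to exactly the same set of sequences end up in the same vector
std::unordered_map<seq_set_t, idx_vec_t> collect_seq_sets(
        const std::vector<std::vector<int32_t>> & seq_ids_per_token) {
    std::unordered_map<seq_set_t, idx_vec_t> seq_set_map;

    for (int32_t i = 0; i < (int32_t) seq_ids_per_token.size(); ++i) {
        seq_set_t cur;
        for (int32_t s : seq_ids_per_token[i]) {
            cur.set(s);
        }
        seq_set_map[cur].push_back(i); // std::hash<std::bitset> makes the bitset a valid key
    }

    return seq_set_map;
}

From buckets like these, split_equal can draw the same number of tokens from each sequence set to form equal-length ubatches, while split_seq drains one sequence set per ubatch, matching the method comments in the header above.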
