Commit c9f670a

thement and jxhor authored
Implement non-greedy tokenizer that tries to maximize token lengths (#242)
* Implement non-greedy tokenizer that tries to maximize token lengths

* Insert single space in front of the prompt - this is to match original llama tokenizer behavior

---------

Co-authored-by: Jakub Horak <[email protected]>
1 parent 4f54609 commit c9f670a

2 files changed: 44 additions & 26 deletions

main.cpp

Lines changed: 2 additions & 0 deletions

@@ -845,6 +845,8 @@ int main(int argc, char ** argv) {
 
     std::vector<float> logits;
 
+    // Add a space in front of the first character to match OG llama tokenizer behavior
+    params.prompt.insert(0, 1, ' ');
     // tokenize the prompt
     std::vector<gpt_vocab::id> embd_inp = ::llama_tokenize(vocab, params.prompt, true);

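The inserted space matters because the llama vocabulary, following SentencePiece, stores word-initial pieces with a word-boundary prefix (the "▁" marker, kept as a plain space in this vocab), so the first word of a prompt only matches those pieces if it, too, follows a space. A minimal sketch of the effect of the call above, assuming only that params.prompt is a std::string as in main.cpp:

#include <cstdio>
#include <string>

int main() {
    std::string prompt = "Hello world";
    // Same call as the diff: insert one copy of ' ' at index 0 so the first
    // word carries the same word-boundary prefix as every later word.
    prompt.insert(0, 1, ' ');
    std::printf("[%s]\n", prompt.c_str()); // prints "[ Hello world]"
    return 0;
}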
utils.cpp

Lines changed: 42 additions & 26 deletions

@@ -287,41 +287,57 @@ std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::stri
     return tokens;
 }
 
+// TODO: Calculate this constant from the vocabulary
+#define MAX_TOKEN_LEN 18
+// SentencePiece implementation after https://guillaume-be.github.io/2020-05-30/sentence_piece
 std::vector<gpt_vocab::id> llama_tokenize(const gpt_vocab & vocab, const std::string & text, bool bos) {
-    //auto res = gpt_tokenize(vocab, text);
-
-    //if (bos) {
-    //    res.insert(res.begin(), 1); // TODO: replace with vocab.bos
-    //}
-
     std::vector<gpt_vocab::id> res;
-
-    if (bos) {
-        res.push_back(1); // TODO: replace with vocab.bos
-    }
-
-    //find the longest token that matches the text
-    int pos = 0;
-    while (true) {
-        int l = 0;
-        int t = 0;
-        for (const auto & kv : vocab.id_to_token) {
-            if (kv.second.size() < l) continue;
-            if (kv.second.size() > text.size() - pos) continue;
-            if (text.substr(pos, kv.second.size()) == kv.second) {
-                l = kv.second.size();
-                t = kv.first;
+    std::vector<int> score;
+    std::vector<gpt_vocab::id> prev;
+    int len = text.length();
+
+    score.resize(len + 1);
+    prev.resize(len + 1);
+
+    // Forward pass
+    for (int i = 0; i < len; i++) {
+        int max_len = std::min(len - i, MAX_TOKEN_LEN);
+        for (int sub_len = 1; sub_len <= max_len; sub_len++) {
+            auto sub = text.substr(i, sub_len);
+            auto token = vocab.token_to_id.find(sub);
+            if (token != vocab.token_to_id.end()) {
+                int token_score = sub.length() * sub.length();
+                int local_score = score[i] + token_score;
+                int next = i + sub_len;
+                if (score[next] < local_score) {
+                    score[next] = local_score;
+                    prev[next] = (*token).second;
+                }
             }
         }
+    }
 
-        if (l == 0) {
-            break;
+    // Backward pass
+    int i = len;
+    while (i > 0) {
+        gpt_vocab::id token_id = prev[i];
+        if (token_id == 0) {
+            // TODO: Return error or something more meaningful
+            printf("failed to tokenize string!\n");
+            break;
         }
+        res.push_back(token_id);
+        auto token = (*vocab.id_to_token.find(token_id)).second;
+        i -= token.length();
+    }
 
-        res.push_back(t);
-        pos += l;
+    if (bos) {
+        res.push_back(1); // TODO: replace with vocab.bos
     }
 
+    // Pieces are in reverse order so correct that
+    std::reverse(res.begin(), res.end());
+
     return res;
 }
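To make the forward and backward passes concrete, here is a self-contained sketch of the same dynamic-programming scheme against a toy vocabulary. The names tokenize_dp and toy_vocab are illustrative, not from the commit; the sketch records each winning piece's length in a side array instead of recovering it through id_to_token as the diff does, and drops the MAX_TOKEN_LEN cap for brevity. score[j] holds the best score of any segmentation of the first j characters, where a piece of length L contributes L*L, so segmentations into longer pieces win.

#include <algorithm>
#include <cstdio>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-in for gpt_vocab: maps piece -> token id.
static std::vector<int> tokenize_dp(const std::map<std::string, int> & toy_vocab,
                                    const std::string & text) {
    const int len = (int) text.length();
    std::vector<int> score(len + 1, 0); // best score for text[0..j)
    std::vector<int> prev(len + 1, 0);  // id of the piece ending at j on the best path
    std::vector<int> plen(len + 1, 0);  // length of that piece; 0 means "unreachable"

    // Forward pass: relax every vocabulary piece that starts at position i
    for (int i = 0; i < len; i++) {
        for (int sub_len = 1; sub_len <= len - i; sub_len++) {
            auto it = toy_vocab.find(text.substr(i, sub_len));
            if (it == toy_vocab.end()) continue;
            int local_score = score[i] + sub_len * sub_len;
            if (score[i + sub_len] < local_score) {
                score[i + sub_len] = local_score;
                prev[i + sub_len]  = it->second;
                plen[i + sub_len]  = sub_len;
            }
        }
    }

    // Backward pass: walk the winning pieces back from the end of the string
    std::vector<int> res;
    for (int i = len; i > 0; i -= plen[i]) {
        if (plen[i] == 0) { std::printf("failed to tokenize string!\n"); break; }
        res.push_back(prev[i]);
    }
    std::reverse(res.begin(), res.end()); // pieces were collected in reverse order
    return res;
}

int main() {
    // A greedy longest-match tokenizer would consume "ab" first and then get
    // stuck, because "c" alone is not in this toy vocab; the DP finds "a" + "bc".
    std::map<std::string, int> toy_vocab = {{"a", 2}, {"ab", 3}, {"bc", 4}};
    for (int id : tokenize_dp(toy_vocab, "abc")) std::printf("%d ", id);
    std::printf("\n"); // prints: 2 4
    return 0;
}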
