@@ -287,41 +287,68 @@ std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::stri
     return tokens;
 }
 
+// TODO: Calculate this constant from the vocabulary
+#define MAX_TOKEN_LEN 18
+// SentencePiece implementation after https://guillaume-be.github.io/2020-05-30/sentence_piece
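+// Instead of greedily taking the longest matching token at each position,
+// pick the segmentation with the highest total score, computed below.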
 std::vector<gpt_vocab::id> llama_tokenize(const gpt_vocab & vocab, const std::string & text, bool bos) {
-    //auto res = gpt_tokenize(vocab, text);
-
-    //if (bos) {
-    //    res.insert(res.begin(), 1); // TODO: replace with vocab.bos
-    //}
-
     std::vector<gpt_vocab::id> res;
-
-    if (bos) {
-        res.push_back(1); // TODO: replace with vocab.bos
-    }
-
-    // find the longest token that matches the text
-    int pos = 0;
-    while (true) {
-        int l = 0;
-        int t = 0;
-        for (const auto & kv : vocab.id_to_token) {
-            if (kv.second.size() < l) continue;
-            if (kv.second.size() > text.size() - pos) continue;
-            if (text.substr(pos, kv.second.size()) == kv.second) {
-                l = kv.second.size();
-                t = kv.first;
+    std::vector<int> score;
+    std::vector<gpt_vocab::id> prev;
+    int len = text.length();
+
+    score.resize(len + 1);
+    prev.resize(len + 1);
+
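+    // score[i] is the best score found for tokenizing the first i bytes of
+    // the text; prev[i] is the id of the token that ends that segmentation
+    // (0 if position i has not been reached).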
+    // Forward pass
+    for (int i = 0; i < len; i++) {
+        int max_len = std::min(len - i, MAX_TOKEN_LEN);
+        for (int sub_len = 1; sub_len <= max_len; sub_len++) {
+            auto sub = text.substr(i, sub_len);
+            auto token = vocab.token_to_id.find(sub);
+            if (token != vocab.token_to_id.end()) {
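+                // Scoring a piece by its squared length makes a few long
+                // tokens beat many short ones.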
+                int token_score = sub.length() * sub.length();
+                int local_score = score[i] + token_score;
+                int next = i + sub_len;
+                if (score[next] < local_score) {
+                    score[next] = local_score;
+                    prev[next] = (*token).second;
+                }
             }
         }
+    }
 
-        if (l == 0) {
-            break;
+    // Backward pass
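+    // Walk back from the end of the text, emitting the token that ends the
+    // best segmentation at each position; prev[i] == 0 means the forward
+    // pass never reached position i.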
+    int i = len;
+    while (i > 0) {
+        gpt_vocab::id token_id = prev[i];
+        if (token_id == 0) {
+            // TODO: Return error or something more meaningful
+            printf("failed to tokenize string!\n");
+            break;
         }
+        res.push_back(token_id);
+        auto token = (*vocab.id_to_token.find(token_id)).second;
+        i -= token.length();
+    }
 
-        res.push_back(t);
-        pos += l;
+    if (bos) {
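+        // Appended after the pieces; the reverse below moves it to the front.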
+        res.push_back(1); // TODO: replace with vocab.bos
     }
 
+    // Pieces are in reverse order so correct that
+    std::reverse(res.begin(), res.end());
+
     return res;
 }
 
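To make the two passes concrete, here is a minimal, self-contained sketch of the same squared-length scoring on a toy vocabulary. It is illustrative only, not code from this commit: the segment helper and the vocabulary are invented, and it records the winning piece at each end position instead of a token id.

#include <algorithm>
#include <cstdio>
#include <map>
#include <string>
#include <vector>

// Hypothetical helper: segment `text` against `vocab` using the same
// forward/backward dynamic programming as llama_tokenize above.
static std::vector<std::string> segment(const std::map<std::string, int> & vocab, const std::string & text) {
    int len = text.length();
    std::vector<int> score(len + 1, 0);     // best score for the first i bytes
    std::vector<std::string> prev(len + 1); // winning piece ending at byte i

    // Forward pass: try every substring that is in the vocabulary.
    for (int i = 0; i < len; i++) {
        for (int sub_len = 1; sub_len <= len - i; sub_len++) {
            auto sub = text.substr(i, sub_len);
            if (vocab.count(sub) && score[i + sub_len] < score[i] + sub_len * sub_len) {
                score[i + sub_len] = score[i] + sub_len * sub_len;
                prev[i + sub_len] = sub;
            }
        }
    }

    // Backward pass: follow prev[] from the end, then restore order.
    std::vector<std::string> pieces;
    for (int i = len; i > 0 && !prev[i].empty(); i -= prev[i].length()) {
        pieces.push_back(prev[i]);
    }
    std::reverse(pieces.begin(), pieces.end());
    return pieces;
}

int main() {
    const std::map<std::string, int> vocab = {
        {"h", 1}, {"e", 2}, {"l", 3}, {"o", 4}, {"he", 5}, {"ll", 6}, {"hell", 7}, {"hello", 8},
    };
    // "hello" comes out as one piece: 5*5 = 25 beats e.g. "hell"+"o" (16 + 1).
    for (const auto & piece : segment(vocab, "hello")) {
        printf("%s\n", piece.c_str());
    }
    return 0;
}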