 // #include <android/asset_manager_jni.h>
 #include <android/log.h>
 #include <cstdlib>
+#include <ctime>
 #include <sys/sysinfo.h>
 #include <string>
 #include <thread>
@@ -21,6 +22,13 @@ static inline int min(int a, int b) {
     return (a < b) ? a : b;
 }
 
+static void log_callback(lm_ggml_log_level level, const char * fmt, void * data) {
+    if (level == LM_GGML_LOG_LEVEL_ERROR) __android_log_print(ANDROID_LOG_ERROR, TAG, fmt, data);
+    else if (level == LM_GGML_LOG_LEVEL_INFO) __android_log_print(ANDROID_LOG_INFO, TAG, fmt, data);
+    else if (level == LM_GGML_LOG_LEVEL_WARN) __android_log_print(ANDROID_LOG_WARN, TAG, fmt, data);
+    else __android_log_print(ANDROID_LOG_DEFAULT, TAG, fmt, data);
+}
+
 extern "C" {
 
 // Method to create WritableMap
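Note (not part of the diff): the new log_callback follows the ggml/llama logging-hook shape (level, formatted text, user data) and only takes effect once llama_log_set() installs it, which the new logToAndroid entry point in the last hunk does. A minimal standalone sketch of the same wiring, assuming this repo's lm_-prefixed llama.h and a placeholder TAG:

// Sketch only; TAG value and callback name are placeholders, not from the diff.
#include <android/log.h>
#include "llama.h"

#define TAG "rnllama-example"

// llama.cpp/ggml passes already-formatted log text as the second argument.
static void example_log_callback(lm_ggml_log_level level, const char * text, void * /*user_data*/) {
    int prio = (level == LM_GGML_LOG_LEVEL_ERROR) ? ANDROID_LOG_ERROR
             : (level == LM_GGML_LOG_LEVEL_WARN)  ? ANDROID_LOG_WARN
             : ANDROID_LOG_INFO;
    __android_log_print(prio, TAG, "%s", text);
}

static void install_android_logging() {
    // Process-wide hook; NULL user data, matching the call made in this diff.
    llama_log_set(example_log_callback, NULL);
}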
@@ -139,14 +147,20 @@ Java_com_rnllama_LlamaContext_initContext(
     jint n_gpu_layers, // TODO: Support this
     jboolean use_mlock,
     jboolean use_mmap,
+    jboolean vocab_only,
     jstring lora_str,
     jfloat lora_scaled,
     jfloat rope_freq_base,
     jfloat rope_freq_scale
 ) {
     UNUSED(thiz);
 
-    gpt_params defaultParams;
+    common_params defaultParams;
+
+    defaultParams.vocab_only = vocab_only;
+    if (vocab_only) {
+        defaultParams.warmup = false;
+    }
 
     const char *model_path_chars = env->GetStringUTFChars(model_path_str, nullptr);
     defaultParams.model = model_path_chars;
@@ -159,7 +173,7 @@ Java_com_rnllama_LlamaContext_initContext(
     int max_threads = std::thread::hardware_concurrency();
     // Use 2 threads by default on 4-core devices, 4 threads on more cores
     int default_n_threads = max_threads == 4 ? 2 : min(4, max_threads);
-    defaultParams.n_threads = n_threads > 0 ? n_threads : default_n_threads;
+    defaultParams.cpuparams.n_threads = n_threads > 0 ? n_threads : default_n_threads;
 
     defaultParams.n_gpu_layers = n_gpu_layers;
 
@@ -235,7 +249,7 @@ Java_com_rnllama_LlamaContext_getFormattedChat(
     UNUSED(thiz);
     auto llama = context_map[(long) context_ptr];
 
-    std::vector<llama_chat_msg> chat;
+    std::vector<common_chat_msg> chat;
 
     int messages_len = env->GetArrayLength(messages);
     for (int i = 0; i < messages_len; i++) {
@@ -259,7 +273,7 @@ Java_com_rnllama_LlamaContext_getFormattedChat(
     }
 
     const char *tmpl_chars = env->GetStringUTFChars(chat_template, nullptr);
-    std::string formatted_chat = llama_chat_apply_template(llama->model, tmpl_chars, chat, true);
+    std::string formatted_chat = common_chat_apply_template(llama->model, tmpl_chars, chat, true);
 
     return env->NewStringUTF(formatted_chat.c_str());
 }
@@ -364,7 +378,8 @@ Java_com_rnllama_LlamaContext_doCompletion(
     jint top_k,
     jfloat top_p,
     jfloat min_p,
-    jfloat tfs_z,
+    jfloat xtc_threshold,
+    jfloat xtc_probability,
     jfloat typical_p,
     jint seed,
     jobjectArray stop,
@@ -377,18 +392,18 @@ Java_com_rnllama_LlamaContext_doCompletion(
 
     llama->rewind();
 
-    llama_reset_timings(llama->ctx);
+    // llama_reset_timings(llama->ctx);
 
     llama->params.prompt = env->GetStringUTFChars(prompt, nullptr);
-    llama->params.seed = seed;
+    llama->params.sparams.seed = (seed == -1) ? time(NULL) : seed;
 
     int max_threads = std::thread::hardware_concurrency();
     // Use 2 threads by default on 4-core devices, 4 threads on more cores
     int default_n_threads = max_threads == 4 ? 2 : min(4, max_threads);
-    llama->params.n_threads = n_threads > 0 ? n_threads : default_n_threads;
+    llama->params.cpuparams.n_threads = n_threads > 0 ? n_threads : default_n_threads;
 
     llama->params.n_predict = n_predict;
-    llama->params.ignore_eos = ignore_eos;
+    llama->params.sparams.ignore_eos = ignore_eos;
 
     auto & sparams = llama->params.sparams;
     sparams.temp = temperature;
@@ -403,14 +418,15 @@ Java_com_rnllama_LlamaContext_doCompletion(
     sparams.top_k = top_k;
     sparams.top_p = top_p;
     sparams.min_p = min_p;
-    sparams.tfs_z = tfs_z;
-    sparams.typical_p = typical_p;
+    sparams.typ_p = typical_p;
     sparams.n_probs = n_probs;
     sparams.grammar = env->GetStringUTFChars(grammar, nullptr);
+    sparams.xtc_threshold = xtc_threshold;
+    sparams.xtc_probability = xtc_probability;
 
     sparams.logit_bias.clear();
     if (ignore_eos) {
-        sparams.logit_bias[llama_token_eos(llama->model)] = -INFINITY;
+        sparams.logit_bias[llama_token_eos(llama->model)].bias = -INFINITY;
     }
 
     const int n_vocab = llama_n_vocab(llama_get_model(llama->ctx));
@@ -424,9 +440,9 @@ Java_com_rnllama_LlamaContext_doCompletion(
         llama_token tok = static_cast<llama_token>(doubleArray[0]);
         if (tok >= 0 && tok < n_vocab) {
             if (doubleArray[1] != 0) { // If the second element is not false (0)
-                sparams.logit_bias[tok] = doubleArray[1];
+                sparams.logit_bias[tok].bias = doubleArray[1];
             } else {
-                sparams.logit_bias[tok] = -INFINITY;
+                sparams.logit_bias[tok].bias = -INFINITY;
             }
         }
 
@@ -460,7 +476,7 @@ Java_com_rnllama_LlamaContext_doCompletion(
         if (token_with_probs.tok == -1 || llama->incomplete) {
             continue;
         }
-        const std::string token_text = llama_token_to_piece(llama->ctx, token_with_probs.tok);
+        const std::string token_text = common_token_to_piece(llama->ctx, token_with_probs.tok);
 
         size_t pos = std::min(sent_count, llama->generated_text.size());
 
@@ -495,7 +511,7 @@ Java_com_rnllama_LlamaContext_doCompletion(
             putString(env, tokenResult, "token", to_send.c_str());
 
             if (llama->params.sparams.n_probs > 0) {
-                const std::vector<llama_token> to_send_toks = llama_tokenize(llama->ctx, to_send, false);
+                const std::vector<llama_token> to_send_toks = common_tokenize(llama->ctx, to_send, false);
                 size_t probs_pos = std::min(sent_token_probs_index, llama->generated_token_probs.size());
                 size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama->generated_token_probs.size());
                 if (probs_pos < probs_stop_pos) {
@@ -512,7 +528,7 @@ Java_com_rnllama_LlamaContext_doCompletion(
         }
     }
 
-    llama_print_timings(llama->ctx);
+    llama_perf_context_print(llama->ctx);
     llama->is_predicting = false;
 
     auto result = createWriteableMap(env);
@@ -527,16 +543,17 @@ Java_com_rnllama_LlamaContext_doCompletion(
     putString(env, result, "stopping_word", llama->stopping_word.c_str());
     putInt(env, result, "tokens_cached", llama->n_past);
 
-    const auto timings = llama_get_timings(llama->ctx);
+    const auto timings_token = llama_perf_context(llama->ctx);
+
     auto timingsResult = createWriteableMap(env);
-    putInt(env, timingsResult, "prompt_n", timings.n_p_eval);
-    putInt(env, timingsResult, "prompt_ms", timings.t_p_eval_ms);
-    putInt(env, timingsResult, "prompt_per_token_ms", timings.t_p_eval_ms / timings.n_p_eval);
-    putDouble(env, timingsResult, "prompt_per_second", 1e3 / timings.t_p_eval_ms * timings.n_p_eval);
-    putInt(env, timingsResult, "predicted_n", timings.n_eval);
-    putInt(env, timingsResult, "predicted_ms", timings.t_eval_ms);
-    putInt(env, timingsResult, "predicted_per_token_ms", timings.t_eval_ms / timings.n_eval);
-    putDouble(env, timingsResult, "predicted_per_second", 1e3 / timings.t_eval_ms * timings.n_eval);
+    putInt(env, timingsResult, "prompt_n", timings_token.n_p_eval);
+    putInt(env, timingsResult, "prompt_ms", timings_token.t_p_eval_ms);
+    putInt(env, timingsResult, "prompt_per_token_ms", timings_token.t_p_eval_ms / timings_token.n_p_eval);
+    putDouble(env, timingsResult, "prompt_per_second", 1e3 / timings_token.t_p_eval_ms * timings_token.n_p_eval);
+    putInt(env, timingsResult, "predicted_n", timings_token.n_eval);
+    putInt(env, timingsResult, "predicted_ms", timings_token.t_eval_ms);
+    putInt(env, timingsResult, "predicted_per_token_ms", timings_token.t_eval_ms / timings_token.n_eval);
+    putDouble(env, timingsResult, "predicted_per_second", 1e3 / timings_token.t_eval_ms * timings_token.n_eval);
 
     putMap(env, result, "timings", timingsResult);
 
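Note (not part of the diff): llama_get_timings() is replaced here by llama_perf_context(), which returns a llama_perf_context_data struct carrying the same prompt/eval counters (n_p_eval, t_p_eval_ms, n_eval, t_eval_ms), so the renamed timings_token keeps the rate math above working. A minimal sketch of that math in isolation, with the helper name being an illustration only:

// Sketch only: same tokens-per-second computation as above, guarded against
// division by zero when no tokens have been evaluated yet.
#include "llama.h"

static double predicted_tokens_per_second(const llama_context * ctx) {
    const llama_perf_context_data d = llama_perf_context(ctx);
    return d.t_eval_ms > 0.0 ? 1e3 / d.t_eval_ms * d.n_eval : 0.0;
}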
@@ -569,7 +586,7 @@ Java_com_rnllama_LlamaContext_tokenize(
 
     const char *text_chars = env->GetStringUTFChars(text, nullptr);
 
-    const std::vector<llama_token> toks = llama_tokenize(
+    const std::vector<llama_token> toks = common_tokenize(
         llama->ctx,
         text_chars,
         false
@@ -623,7 +640,7 @@ Java_com_rnllama_LlamaContext_embedding(
 
     llama->rewind();
 
-    llama_reset_timings(llama->ctx);
+    llama_perf_context_reset(llama->ctx);
 
     llama->params.prompt = text_chars;
 
@@ -681,9 +698,16 @@ Java_com_rnllama_LlamaContext_freeContext(
     }
     if (llama->ctx_sampling != nullptr)
     {
-        llama_sampling_free(llama->ctx_sampling);
+        common_sampler_free(llama->ctx_sampling);
     }
     context_map.erase((long) llama->ctx);
 }
 
+JNIEXPORT void JNICALL
+Java_com_rnllama_LlamaContext_logToAndroid(JNIEnv *env, jobject thiz) {
+    UNUSED(env);
+    UNUSED(thiz);
+    llama_log_set(log_callback, NULL);
+}
+
 } // extern "C"