
Commit 1ca3044

a-ghorbani and jhen0409 authored

feat: sync llama.cpp (#79)

* feat: sync llama.cpp
* fix: fix submodule update - as part of llama.cpp sync
* chore: remove unnecessary comment
* chore(example): revert unnecessary changes
* feat: sync llama.cpp
* fix: remove tfs_z
  ref: ggml-org/llama.cpp#10071
* fix(cpp): skip gpu device if n_gpu_layers <= 0
  ref: ggml-org/llama.cpp#10132

---------

Co-authored-by: Jhen-Jie Hong <[email protected]>

1 parent f35545b commit 1ca3044
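Note on the n_gpu_layers fix referenced above (ggml-org/llama.cpp#10132): after the sync, no GPU device should be selected when the caller requests zero GPU layers. A minimal caller-side sketch using the public llama.cpp API (illustrative only; the actual upstream patch lives in llama.cpp's device-selection code, not in this repo):

#include "llama.h"

// Minimal sketch, not the upstream patch: with n_gpu_layers <= 0 the
// model is expected to load CPU-only, i.e. no GPU device is picked up.
llama_model_params make_cpu_only_params() {
    llama_model_params mparams = llama_model_default_params();
    mparams.n_gpu_layers = 0;  // <= 0: keep every layer on the CPU
    return mparams;
}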


55 files changed: 21,474 additions & 13,484 deletions

android/src/main/CMakeLists.txt

Lines changed: 17 additions & 5 deletions
@@ -9,21 +9,23 @@ include_directories(${RNLLAMA_LIB_DIR})
 
 set(
     SOURCE_FILES
+    ${RNLLAMA_LIB_DIR}/llama-grammar.cpp
+    ${RNLLAMA_LIB_DIR}/llama-sampling.cpp
+    ${RNLLAMA_LIB_DIR}/llama-vocab.cpp
+    ${RNLLAMA_LIB_DIR}/log.cpp
+
+    ${RNLLAMA_LIB_DIR}/ggml-aarch64.c
     ${RNLLAMA_LIB_DIR}/ggml-alloc.c
-    ${RNLLAMA_LIB_DIR}/ggml-backend.c
+    ${RNLLAMA_LIB_DIR}/ggml-backend.cpp
     ${RNLLAMA_LIB_DIR}/ggml.c
     ${RNLLAMA_LIB_DIR}/ggml-quants.c
     ${RNLLAMA_LIB_DIR}/common.cpp
-    ${RNLLAMA_LIB_DIR}/grammar-parser.cpp
     ${RNLLAMA_LIB_DIR}/json.hpp
     ${RNLLAMA_LIB_DIR}/json-schema-to-grammar.cpp
     ${RNLLAMA_LIB_DIR}/sampling.cpp
     ${RNLLAMA_LIB_DIR}/unicode-data.cpp
     ${RNLLAMA_LIB_DIR}/unicode.cpp
     ${RNLLAMA_LIB_DIR}/llama.cpp
-    ${RNLLAMA_LIB_DIR}/llama-vocab.cpp
-    ${RNLLAMA_LIB_DIR}/llama-sampling.cpp
-    ${RNLLAMA_LIB_DIR}/llama-grammar.cpp
     ${RNLLAMA_LIB_DIR}/sgemm.cpp
     ${RNLLAMA_LIB_DIR}/ggml-aarch64.c
     ${RNLLAMA_LIB_DIR}/rn-llama.hpp
@@ -65,10 +67,20 @@ build_library("rnllama" "")
 
 if (${ANDROID_ABI} STREQUAL "arm64-v8a")
     # ARM64 targets
+    build_library("rnllama_v8_4_fp16_dotprod_sve" "-march=armv8.4-a+fp16+dotprod+sve")
+    build_library("rnllama_v8_4_fp16_dotprod_i8mm_sve" "-march=armv8.4-a+fp16+dotprod+i8mm+sve")
+    build_library("rnllama_v8_4_fp16_dotprod_i8mm" "-march=armv8.4-a+fp16+dotprod+i8mm")
     build_library("rnllama_v8_4_fp16_dotprod" "-march=armv8.4-a+fp16+dotprod")
     build_library("rnllama_v8_2_fp16_dotprod" "-march=armv8.2-a+fp16+dotprod")
     build_library("rnllama_v8_2_fp16" "-march=armv8.2-a+fp16")
     build_library("rnllama_v8" "-march=armv8-a")
+
+    # https://github.com/ggerganov/llama.cpp/blob/master/docs/android.md#cross-compile-using-android-ndk
+    # llama.cpp will deal with the cpu features
+    # build_library("rnllama_v8_7" "-march=armv8.7-a")
+    # TODO: Add support runtime check for cpu features
+    # At the moment runtime check is failing.
+
 elseif (${ANDROID_ABI} STREQUAL "x86_64")
     # x86_64 target
     build_library("rnllama_x86_64" "-march=x86-64" "-mtune=intel" "-msse4.2" "-mpopcnt")

android/src/main/java/com/rnllama/LlamaContext.java

Lines changed: 46 additions & 14 deletions
@@ -35,6 +35,8 @@ public LlamaContext(int id, ReactApplicationContext reactContext, ReadableMap pa
     if (!params.hasKey("model")) {
       throw new IllegalArgumentException("Missing required parameter: model");
     }
+    Log.d(NAME, "Setting log callback");
+    logToAndroid();
     this.id = id;
     this.context = initContext(
       // String model,
@@ -53,6 +55,8 @@ public LlamaContext(int id, ReactApplicationContext reactContext, ReadableMap pa
       params.hasKey("use_mlock") ? params.getBoolean("use_mlock") : true,
       // boolean use_mmap,
       params.hasKey("use_mmap") ? params.getBoolean("use_mmap") : true,
+      //boolean vocab_only,
+      params.hasKey("vocab_only") ? params.getBoolean("vocab_only") : false,
       // String lora,
       params.hasKey("lora") ? params.getString("lora") : "",
       // float lora_scaled,
@@ -181,6 +185,10 @@ public WritableMap completion(ReadableMap params) {
       params.hasKey("top_p") ? (float) params.getDouble("top_p") : 0.95f,
       // float min_p,
       params.hasKey("min_p") ? (float) params.getDouble("min_p") : 0.05f,
+      // float xtc_threshold,
+      params.hasKey("xtc_threshold") ? (float) params.getDouble("xtc_threshold") : 0.00f,
+      // float xtc_probability,
+      params.hasKey("xtc_probability") ? (float) params.getDouble("xtc_probability") : 0.00f,
       // float tfs_z,
       params.hasKey("tfs_z") ? (float) params.getDouble("tfs_z") : 1.00f,
       // float typical_p,
@@ -248,16 +256,34 @@ public void release() {
 
   static {
     Log.d(NAME, "Primary ABI: " + Build.SUPPORTED_ABIS[0]);
-    if (LlamaContext.isArm64V8a()) {
-      String cpuFeatures = LlamaContext.getCpuFeatures();
-      Log.d(NAME, "CPU features: " + cpuFeatures);
-
-      boolean hasFp16 = cpuFeatures.contains("fp16") || cpuFeatures.contains("fphp");
-      boolean hasDotProd = cpuFeatures.contains("dotprod") || cpuFeatures.contains("asimddp");
-      boolean isAtLeastArmV82 = cpuFeatures.contains("asimd") && cpuFeatures.contains("crc32") && cpuFeatures.contains("aes");
-      boolean isAtLeastArmV84 = cpuFeatures.contains("dcpop") && cpuFeatures.contains("uscat");
 
-      if (isAtLeastArmV84 && hasFp16 && hasDotProd) {
+    String cpuFeatures = LlamaContext.getCpuFeatures();
+    Log.d(NAME, "CPU features: " + cpuFeatures);
+    boolean hasFp16 = cpuFeatures.contains("fp16") || cpuFeatures.contains("fphp");
+    boolean hasDotProd = cpuFeatures.contains("dotprod") || cpuFeatures.contains("asimddp");
+    boolean hasSve = cpuFeatures.contains("sve");
+    boolean hasI8mm = cpuFeatures.contains("i8mm");
+    boolean isAtLeastArmV82 = cpuFeatures.contains("asimd") && cpuFeatures.contains("crc32") && cpuFeatures.contains("aes");
+    boolean isAtLeastArmV84 = cpuFeatures.contains("dcpop") && cpuFeatures.contains("uscat");
+    Log.d(NAME, "- hasFp16: " + hasFp16);
+    Log.d(NAME, "- hasDotProd: " + hasDotProd);
+    Log.d(NAME, "- hasSve: " + hasSve);
+    Log.d(NAME, "- hasI8mm: " + hasI8mm);
+    Log.d(NAME, "- isAtLeastArmV82: " + isAtLeastArmV82);
+    Log.d(NAME, "- isAtLeastArmV84: " + isAtLeastArmV84);
+
+    // TODO: Add runtime check for cpu features
+    if (LlamaContext.isArm64V8a()) {
+      if (isAtLeastArmV84 && hasSve && hasI8mm && hasFp16 && hasDotProd) {
+        Log.d(NAME, "Loading librnllama_v8_4_fp16_dotprod_i8mm_sve.so");
+        System.loadLibrary("rnllama_v8_4_fp16_dotprod_i8mm_sve");
+      } else if (isAtLeastArmV84 && hasSve && hasFp16 && hasDotProd) {
+        Log.d(NAME, "Loading librnllama_v8_4_fp16_dotprod_sve.so");
+        System.loadLibrary("rnllama_v8_4_fp16_dotprod_sve");
+      } else if (isAtLeastArmV84 && hasI8mm && hasFp16 && hasDotProd) {
+        Log.d(NAME, "Loading librnllama_v8_4_fp16_dotprod_i8mm.so");
+        System.loadLibrary("rnllama_v8_4_fp16_dotprod_i8mm");
+      } else if (isAtLeastArmV84 && hasFp16 && hasDotProd) {
         Log.d(NAME, "Loading librnllama_v8_4_fp16_dotprod.so");
         System.loadLibrary("rnllama_v8_4_fp16_dotprod");
       } else if (isAtLeastArmV82 && hasFp16 && hasDotProd) {
@@ -270,14 +296,16 @@ public void release() {
         Log.d(NAME, "Loading librnllama_v8.so");
         System.loadLibrary("rnllama_v8");
       }
+      // Log.d(NAME, "Loading librnllama_v8_7.so with runtime feature detection");
+      // System.loadLibrary("rnllama_v8_7");
     } else if (LlamaContext.isX86_64()) {
-        Log.d(NAME, "Loading librnllama_x86_64.so");
-        System.loadLibrary("rnllama_x86_64");
+      Log.d(NAME, "Loading librnllama_x86_64.so");
+      System.loadLibrary("rnllama_x86_64");
     } else {
-        Log.d(NAME, "Loading default librnllama.so");
-        System.loadLibrary("rnllama");
+      Log.d(NAME, "Loading default librnllama.so");
+      System.loadLibrary("rnllama");
     }
-    }
+  }
 
   private static boolean isArm64V8a() {
     return Build.SUPPORTED_ABIS[0].equals("arm64-v8a");
@@ -316,6 +344,7 @@ protected static native long initContext(
     int n_gpu_layers, // TODO: Support this
     boolean use_mlock,
     boolean use_mmap,
+    boolean vocab_only,
     String lora,
     float lora_scaled,
     float rope_freq_base,
@@ -357,6 +386,8 @@ protected static native WritableMap doCompletion(
     int top_k,
     float top_p,
     float min_p,
+    float xtc_threshold,
+    float xtc_probability,
     float tfs_z,
     float typical_p,
     int seed,
@@ -373,4 +404,5 @@ protected static native WritableMap doCompletion(
   protected static native WritableMap embedding(long contextPtr, String text);
   protected static native String bench(long contextPtr, int pp, int tg, int pl, int nr);
   protected static native void freeContext(long contextPtr);
+  protected static native void logToAndroid();
 }
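The loader above keys off LlamaContext.getCpuFeatures(), a native helper whose implementation is not part of this diff. On Android/Linux, feature strings such as "fphp", "asimddp", "sve" and "i8mm" are conventionally read from the "Features" line of /proc/cpuinfo; a hedged sketch of such a helper follows (approach and names assumed, not taken from this repo):

#include <fstream>
#include <string>

// Assumed implementation sketch: return the "Features" line of /proc/cpuinfo,
// which lists kernel hwcap names like "fphp asimddp sve i8mm" on arm64.
std::string get_cpu_features() {
    std::ifstream cpuinfo("/proc/cpuinfo");
    std::string line;
    while (std::getline(cpuinfo, line)) {
        if (line.rfind("Features", 0) == 0) {  // line starts with "Features"
            size_t colon = line.find(':');
            return colon == std::string::npos ? "" : line.substr(colon + 1);
        }
    }
    return "";
}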

android/src/main/jni.cpp

Lines changed: 53 additions & 29 deletions
@@ -3,6 +3,7 @@
 // #include <android/asset_manager_jni.h>
 #include <android/log.h>
 #include <cstdlib>
+#include <ctime>
 #include <sys/sysinfo.h>
 #include <string>
 #include <thread>
@@ -21,6 +22,13 @@ static inline int min(int a, int b) {
     return (a < b) ? a : b;
 }
 
+static void log_callback(lm_ggml_log_level level, const char * fmt, void * data) {
+    if (level == LM_GGML_LOG_LEVEL_ERROR) __android_log_print(ANDROID_LOG_ERROR, TAG, fmt, data);
+    else if (level == LM_GGML_LOG_LEVEL_INFO) __android_log_print(ANDROID_LOG_INFO, TAG, fmt, data);
+    else if (level == LM_GGML_LOG_LEVEL_WARN) __android_log_print(ANDROID_LOG_WARN, TAG, fmt, data);
+    else __android_log_print(ANDROID_LOG_DEFAULT, TAG, fmt, data);
+}
+
 extern "C" {
 
 // Method to create WritableMap
@@ -139,14 +147,20 @@ Java_com_rnllama_LlamaContext_initContext(
     jint n_gpu_layers, // TODO: Support this
     jboolean use_mlock,
     jboolean use_mmap,
+    jboolean vocab_only,
     jstring lora_str,
     jfloat lora_scaled,
     jfloat rope_freq_base,
     jfloat rope_freq_scale
 ) {
     UNUSED(thiz);
 
-    gpt_params defaultParams;
+    common_params defaultParams;
+
+    defaultParams.vocab_only = vocab_only;
+    if(vocab_only) {
+        defaultParams.warmup = false;
+    }
 
     const char *model_path_chars = env->GetStringUTFChars(model_path_str, nullptr);
     defaultParams.model = model_path_chars;
@@ -159,7 +173,7 @@ Java_com_rnllama_LlamaContext_initContext(
     int max_threads = std::thread::hardware_concurrency();
     // Use 2 threads by default on 4-core devices, 4 threads on more cores
     int default_n_threads = max_threads == 4 ? 2 : min(4, max_threads);
-    defaultParams.n_threads = n_threads > 0 ? n_threads : default_n_threads;
+    defaultParams.cpuparams.n_threads = n_threads > 0 ? n_threads : default_n_threads;
 
     defaultParams.n_gpu_layers = n_gpu_layers;
 
@@ -235,7 +249,7 @@ Java_com_rnllama_LlamaContext_getFormattedChat(
     UNUSED(thiz);
     auto llama = context_map[(long) context_ptr];
 
-    std::vector<llama_chat_msg> chat;
+    std::vector<common_chat_msg> chat;
 
     int messages_len = env->GetArrayLength(messages);
     for (int i = 0; i < messages_len; i++) {
@@ -259,7 +273,7 @@ Java_com_rnllama_LlamaContext_getFormattedChat(
     }
 
     const char *tmpl_chars = env->GetStringUTFChars(chat_template, nullptr);
-    std::string formatted_chat = llama_chat_apply_template(llama->model, tmpl_chars, chat, true);
+    std::string formatted_chat = common_chat_apply_template(llama->model, tmpl_chars, chat, true);
 
     return env->NewStringUTF(formatted_chat.c_str());
 }
@@ -364,7 +378,8 @@ Java_com_rnllama_LlamaContext_doCompletion(
     jint top_k,
     jfloat top_p,
     jfloat min_p,
-    jfloat tfs_z,
+    jfloat xtc_threshold,
+    jfloat xtc_probability,
     jfloat typical_p,
     jint seed,
     jobjectArray stop,
@@ -377,18 +392,18 @@ Java_com_rnllama_LlamaContext_doCompletion(
 
     llama->rewind();
 
-    llama_reset_timings(llama->ctx);
+    //llama_reset_timings(llama->ctx);
 
     llama->params.prompt = env->GetStringUTFChars(prompt, nullptr);
-    llama->params.seed = seed;
+    llama->params.sparams.seed = (seed == -1) ? time(NULL) : seed;
 
     int max_threads = std::thread::hardware_concurrency();
     // Use 2 threads by default on 4-core devices, 4 threads on more cores
     int default_n_threads = max_threads == 4 ? 2 : min(4, max_threads);
-    llama->params.n_threads = n_threads > 0 ? n_threads : default_n_threads;
+    llama->params.cpuparams.n_threads = n_threads > 0 ? n_threads : default_n_threads;
 
     llama->params.n_predict = n_predict;
-    llama->params.ignore_eos = ignore_eos;
+    llama->params.sparams.ignore_eos = ignore_eos;
 
     auto & sparams = llama->params.sparams;
     sparams.temp = temperature;
@@ -403,14 +418,15 @@ Java_com_rnllama_LlamaContext_doCompletion(
     sparams.top_k = top_k;
     sparams.top_p = top_p;
     sparams.min_p = min_p;
-    sparams.tfs_z = tfs_z;
-    sparams.typical_p = typical_p;
+    sparams.typ_p = typical_p;
     sparams.n_probs = n_probs;
     sparams.grammar = env->GetStringUTFChars(grammar, nullptr);
+    sparams.xtc_threshold = xtc_threshold;
+    sparams.xtc_probability = xtc_probability;
 
     sparams.logit_bias.clear();
     if (ignore_eos) {
-        sparams.logit_bias[llama_token_eos(llama->model)] = -INFINITY;
+        sparams.logit_bias[llama_token_eos(llama->model)].bias = -INFINITY;
     }
 
     const int n_vocab = llama_n_vocab(llama_get_model(llama->ctx));
@@ -424,9 +440,9 @@ Java_com_rnllama_LlamaContext_doCompletion(
         llama_token tok = static_cast<llama_token>(doubleArray[0]);
         if (tok >= 0 && tok < n_vocab) {
             if (doubleArray[1] != 0) { // If the second element is not false (0)
-                sparams.logit_bias[tok] = doubleArray[1];
+                sparams.logit_bias[tok].bias = doubleArray[1];
             } else {
-                sparams.logit_bias[tok] = -INFINITY;
+                sparams.logit_bias[tok].bias = -INFINITY;
             }
         }
 
@@ -460,7 +476,7 @@ Java_com_rnllama_LlamaContext_doCompletion(
         if (token_with_probs.tok == -1 || llama->incomplete) {
             continue;
         }
-        const std::string token_text = llama_token_to_piece(llama->ctx, token_with_probs.tok);
+        const std::string token_text = common_token_to_piece(llama->ctx, token_with_probs.tok);
 
         size_t pos = std::min(sent_count, llama->generated_text.size());
 
@@ -495,7 +511,7 @@ Java_com_rnllama_LlamaContext_doCompletion(
             putString(env, tokenResult, "token", to_send.c_str());
 
             if (llama->params.sparams.n_probs > 0) {
-                const std::vector<llama_token> to_send_toks = llama_tokenize(llama->ctx, to_send, false);
+                const std::vector<llama_token> to_send_toks = common_tokenize(llama->ctx, to_send, false);
                 size_t probs_pos = std::min(sent_token_probs_index, llama->generated_token_probs.size());
                 size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama->generated_token_probs.size());
                 if (probs_pos < probs_stop_pos) {
@@ -512,7 +528,7 @@ Java_com_rnllama_LlamaContext_doCompletion(
         }
     }
 
-    llama_print_timings(llama->ctx);
+    llama_perf_context_print(llama->ctx);
     llama->is_predicting = false;
 
     auto result = createWriteableMap(env);
@@ -527,16 +543,17 @@ Java_com_rnllama_LlamaContext_doCompletion(
     putString(env, result, "stopping_word", llama->stopping_word.c_str());
     putInt(env, result, "tokens_cached", llama->n_past);
 
-    const auto timings = llama_get_timings(llama->ctx);
+    const auto timings_token = llama_perf_context(llama -> ctx);
+
     auto timingsResult = createWriteableMap(env);
-    putInt(env, timingsResult, "prompt_n", timings.n_p_eval);
-    putInt(env, timingsResult, "prompt_ms", timings.t_p_eval_ms);
-    putInt(env, timingsResult, "prompt_per_token_ms", timings.t_p_eval_ms / timings.n_p_eval);
-    putDouble(env, timingsResult, "prompt_per_second", 1e3 / timings.t_p_eval_ms * timings.n_p_eval);
-    putInt(env, timingsResult, "predicted_n", timings.n_eval);
-    putInt(env, timingsResult, "predicted_ms", timings.t_eval_ms);
-    putInt(env, timingsResult, "predicted_per_token_ms", timings.t_eval_ms / timings.n_eval);
-    putDouble(env, timingsResult, "predicted_per_second", 1e3 / timings.t_eval_ms * timings.n_eval);
+    putInt(env, timingsResult, "prompt_n", timings_token.n_p_eval);
+    putInt(env, timingsResult, "prompt_ms", timings_token.t_p_eval_ms);
+    putInt(env, timingsResult, "prompt_per_token_ms", timings_token.t_p_eval_ms / timings_token.n_p_eval);
+    putDouble(env, timingsResult, "prompt_per_second", 1e3 / timings_token.t_p_eval_ms * timings_token.n_p_eval);
+    putInt(env, timingsResult, "predicted_n", timings_token.n_eval);
+    putInt(env, timingsResult, "predicted_ms", timings_token.t_eval_ms);
+    putInt(env, timingsResult, "predicted_per_token_ms", timings_token.t_eval_ms / timings_token.n_eval);
+    putDouble(env, timingsResult, "predicted_per_second", 1e3 / timings_token.t_eval_ms * timings_token.n_eval);
 
     putMap(env, result, "timings", timingsResult);
 
@@ -569,7 +586,7 @@ Java_com_rnllama_LlamaContext_tokenize(
 
     const char *text_chars = env->GetStringUTFChars(text, nullptr);
 
-    const std::vector<llama_token> toks = llama_tokenize(
+    const std::vector<llama_token> toks = common_tokenize(
         llama->ctx,
         text_chars,
         false
@@ -623,7 +640,7 @@ Java_com_rnllama_LlamaContext_embedding(
 
     llama->rewind();
 
-    llama_reset_timings(llama->ctx);
+    llama_perf_context_reset(llama->ctx);
 
     llama->params.prompt = text_chars;
 
@@ -681,9 +698,16 @@ Java_com_rnllama_LlamaContext_freeContext(
     }
     if (llama->ctx_sampling != nullptr)
     {
-        llama_sampling_free(llama->ctx_sampling);
+        common_sampler_free(llama->ctx_sampling);
    }
    context_map.erase((long) llama->ctx);
 }
 
+JNIEXPORT void JNICALL
+Java_com_rnllama_LlamaContext_logToAndroid(JNIEnv *env, jobject thiz) {
+    UNUSED(env);
+    UNUSED(thiz);
+    llama_log_set(log_callback, NULL);
+}
+
 } // extern "C"
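Most of the churn in this file is mechanical fallout from the sync: gpt_params becomes common_params, llama_tokenize/llama_token_to_piece become common_tokenize/common_token_to_piece, timing calls move to the llama_perf_* family, thread counts move under cpuparams, logit_bias entries gain a .bias field, typical_p is renamed typ_p, and tfs_z disappears in favor of the new XTC controls. A condensed sketch of the post-sync sampling setup, using only field names that appear in this diff (the values are illustrative defaults, not this repo's):

#include "common.h"  // common_params, from the synced llama.cpp "common" library

common_params make_params() {
    common_params params;
    params.cpuparams.n_threads = 4;   // was params.n_threads before the sync
    auto & sparams = params.sparams;
    sparams.top_k = 40;
    sparams.top_p = 0.95f;
    sparams.min_p = 0.05f;
    sparams.typ_p = 1.00f;            // renamed from typical_p
    sparams.xtc_threshold = 0.10f;    // new XTC sampler controls
    sparams.xtc_probability = 0.00f;  // 0 disables XTC
    // tfs_z is gone: tail-free sampling was removed upstream (ggml-org/llama.cpp#10071)
    return params;
}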
