
Commit a6cd1dc

Migrate java api
Differential Revision: D71588845
Pull Request resolved: #9478
1 parent a828307

File tree

7 files changed: +322 −77 lines changed

extension/android/BUCK

+2

@@ -25,6 +25,8 @@ fb_android_library(
     srcs = [
         "src/main/java/org/pytorch/executorch/LlamaCallback.java",
         "src/main/java/org/pytorch/executorch/LlamaModule.java",
+        "src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.java",
+        "src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java",
     ],
     autoglob = False,
     language = "JAVA",

extension/android/jni/jni_layer.cpp

+4 −4

@@ -408,14 +408,14 @@ class ExecuTorchJni : public facebook::jni::HybridClass<ExecuTorchJni> {
 } // namespace executorch::extension
 
 #ifdef EXECUTORCH_BUILD_LLAMA_JNI
-extern void register_natives_for_llama();
+extern void register_natives_for_llm();
 #else
-// No op if we don't build llama
-void register_natives_for_llama() {}
+// No op if we don't build LLM
+void register_natives_for_llm() {}
 #endif
 JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void*) {
   return facebook::jni::initialize(vm, [] {
     executorch::extension::ExecuTorchJni::registerNatives();
-    register_natives_for_llama();
+    register_natives_for_llm();
   });
 }

extension/android/jni/jni_layer_llama.cpp

+19 −20

@@ -75,14 +75,14 @@ std::string token_buffer;
 
 namespace executorch_jni {
 
-class ExecuTorchLlamaCallbackJni
-    : public facebook::jni::JavaClass<ExecuTorchLlamaCallbackJni> {
+class ExecuTorchLlmCallbackJni
+    : public facebook::jni::JavaClass<ExecuTorchLlmCallbackJni> {
  public:
   constexpr static const char* kJavaDescriptor =
-      "Lorg/pytorch/executorch/LlamaCallback;";
+      "Lorg/pytorch/executorch/extension/llm/LlmCallback;";
 
   void onResult(std::string result) const {
-    static auto cls = ExecuTorchLlamaCallbackJni::javaClassStatic();
+    static auto cls = ExecuTorchLlmCallbackJni::javaClassStatic();
     static const auto method =
         cls->getMethod<void(facebook::jni::local_ref<jstring>)>("onResult");
 
@@ -99,7 +99,7 @@ class ExecuTorchLlamaCallbackJni
   }
 
   void onStats(const llm::Stats& result) const {
-    static auto cls = ExecuTorchLlamaCallbackJni::javaClassStatic();
+    static auto cls = ExecuTorchLlmCallbackJni::javaClassStatic();
     static const auto method = cls->getMethod<void(jfloat)>("onStats");
     double eval_time =
         (double)(result.inference_end_ms - result.prompt_eval_end_ms);
@@ -111,8 +111,7 @@ class ExecuTorchLlamaCallbackJni
   }
 };
 
-class ExecuTorchLlamaJni
-    : public facebook::jni::HybridClass<ExecuTorchLlamaJni> {
+class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
  private:
   friend HybridBase;
   int model_type_category_;
@@ -121,7 +120,7 @@ class ExecuTorchLlamaJni
 
  public:
   constexpr static auto kJavaDescriptor =
-      "Lorg/pytorch/executorch/LlamaModule;";
+      "Lorg/pytorch/executorch/extension/llm/LlmModule;";
 
   constexpr static int MODEL_TYPE_CATEGORY_LLM = 1;
   constexpr static int MODEL_TYPE_CATEGORY_MULTIMODAL = 2;
@@ -142,7 +141,7 @@ class ExecuTorchLlamaJni
         data_path);
   }
 
-  ExecuTorchLlamaJni(
+  ExecuTorchLlmJni(
       jint model_type_category,
       facebook::jni::alias_ref<jstring> model_path,
       facebook::jni::alias_ref<jstring> tokenizer_path,
@@ -197,7 +196,7 @@ class ExecuTorchLlamaJni
       jint channels,
       facebook::jni::alias_ref<jstring> prompt,
       jint seq_len,
-      facebook::jni::alias_ref<ExecuTorchLlamaCallbackJni> callback,
+      facebook::jni::alias_ref<ExecuTorchLlmCallbackJni> callback,
       jboolean echo) {
     if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
       auto image_size = image->size();
@@ -296,7 +295,7 @@ class ExecuTorchLlamaJni
       facebook::jni::alias_ref<jstring> prompt,
       jint seq_len,
       jlong start_pos,
-      facebook::jni::alias_ref<ExecuTorchLlamaCallbackJni> callback,
+      facebook::jni::alias_ref<ExecuTorchLlmCallbackJni> callback,
       jboolean echo) {
     if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
       return static_cast<jint>(Error::NotSupported);
@@ -329,22 +328,22 @@ class ExecuTorchLlamaJni
 
   static void registerNatives() {
     registerHybrid({
-        makeNativeMethod("initHybrid", ExecuTorchLlamaJni::initHybrid),
-        makeNativeMethod("generate", ExecuTorchLlamaJni::generate),
-        makeNativeMethod("stop", ExecuTorchLlamaJni::stop),
-        makeNativeMethod("load", ExecuTorchLlamaJni::load),
+        makeNativeMethod("initHybrid", ExecuTorchLlmJni::initHybrid),
+        makeNativeMethod("generate", ExecuTorchLlmJni::generate),
+        makeNativeMethod("stop", ExecuTorchLlmJni::stop),
+        makeNativeMethod("load", ExecuTorchLlmJni::load),
         makeNativeMethod(
-            "prefillImagesNative", ExecuTorchLlamaJni::prefill_images),
+            "prefillImagesNative", ExecuTorchLlmJni::prefill_images),
         makeNativeMethod(
-            "prefillPromptNative", ExecuTorchLlamaJni::prefill_prompt),
+            "prefillPromptNative", ExecuTorchLlmJni::prefill_prompt),
         makeNativeMethod(
-            "generateFromPos", ExecuTorchLlamaJni::generate_from_pos),
+            "generateFromPos", ExecuTorchLlmJni::generate_from_pos),
     });
   }
 };
 
 } // namespace executorch_jni
 
-void register_natives_for_llama() {
-  executorch_jni::ExecuTorchLlamaJni::registerNatives();
+void register_natives_for_llm() {
+  executorch_jni::ExecuTorchLlmJni::registerNatives();
 }

extension/android/src/main/java/org/pytorch/executorch/LlamaCallback.java

+2 −3

@@ -9,15 +9,14 @@
 package org.pytorch.executorch;
 
 import com.facebook.jni.annotations.DoNotStrip;
-import org.pytorch.executorch.annotations.Experimental;
 
 /**
  * Callback interface for Llama model. Users can implement this interface to receive the generated
  * tokens and statistics.
  *
- * <p>Warning: These APIs are experimental and subject to change without notice
+ * <p>Note: deprecated! Please use {@link org.pytorch.executorch.extension.llm.LlmCallback} instead.
  */
-@Experimental
+@Deprecated
 public interface LlamaCallback {
   /**
    * Called when a new result is available from JNI. Users will keep getting onResult() invocations

extension/android/src/main/java/org/pytorch/executorch/LlamaModule.java

+59 −50

@@ -8,59 +8,45 @@
 
 package org.pytorch.executorch;
 
-import com.facebook.jni.HybridData;
-import com.facebook.jni.annotations.DoNotStrip;
-import com.facebook.soloader.nativeloader.NativeLoader;
-import com.facebook.soloader.nativeloader.SystemDelegate;
-import org.pytorch.executorch.annotations.Experimental;
+import org.pytorch.executorch.extension.llm.LlmCallback;
+import org.pytorch.executorch.extension.llm.LlmModule;
 
 /**
  * LlamaModule is a wrapper around the Executorch Llama model. It provides a simple interface to
  * generate text from the model.
  *
- * <p>Warning: These APIs are experimental and subject to change without notice
+ * <p>Note: deprecated! Please use {@link org.pytorch.executorch.extension.llm.LlmModule} instead.
  */
-@Experimental
+@Deprecated
 public class LlamaModule {
 
   public static final int MODEL_TYPE_TEXT = 1;
   public static final int MODEL_TYPE_TEXT_VISION = 2;
 
-  static {
-    if (!NativeLoader.isInitialized()) {
-      NativeLoader.init(new SystemDelegate());
-    }
-    NativeLoader.loadLibrary("executorch");
-  }
-
-  private final HybridData mHybridData;
+  private LlmModule mModule;
   private static final int DEFAULT_SEQ_LEN = 128;
   private static final boolean DEFAULT_ECHO = true;
 
-  @DoNotStrip
-  private static native HybridData initHybrid(
-      int modelType, String modulePath, String tokenizerPath, float temperature, String dataPath);
-
   /** Constructs a LLAMA Module for a model with given model path, tokenizer, temperature. */
   public LlamaModule(String modulePath, String tokenizerPath, float temperature) {
-    mHybridData = initHybrid(MODEL_TYPE_TEXT, modulePath, tokenizerPath, temperature, null);
+    mModule = new LlmModule(modulePath, tokenizerPath, temperature);
   }
 
   /**
    * Constructs a LLAMA Module for a model with given model path, tokenizer, temperature and data
    * path.
    */
   public LlamaModule(String modulePath, String tokenizerPath, float temperature, String dataPath) {
-    mHybridData = initHybrid(MODEL_TYPE_TEXT, modulePath, tokenizerPath, temperature, dataPath);
+    mModule = new LlmModule(modulePath, tokenizerPath, temperature, dataPath);
   }
 
   /** Constructs a LLM Module for a model with given path, tokenizer, and temperature. */
   public LlamaModule(int modelType, String modulePath, String tokenizerPath, float temperature) {
-    mHybridData = initHybrid(modelType, modulePath, tokenizerPath, temperature, null);
+    mModule = new LlmModule(modelType, modulePath, tokenizerPath, temperature);
   }
 
   public void resetNative() {
-    mHybridData.resetNative();
+    mModule.resetNative();
   }
 
   /**
@@ -70,7 +56,7 @@ public void resetNative() {
    * @param llamaCallback callback object to receive results.
    */
   public int generate(String prompt, LlamaCallback llamaCallback) {
-    return generate(prompt, DEFAULT_SEQ_LEN, llamaCallback, DEFAULT_ECHO);
+    return generate(null, 0, 0, 0, prompt, DEFAULT_SEQ_LEN, llamaCallback, DEFAULT_ECHO);
   }
 
   /**
@@ -119,16 +105,35 @@ public int generate(String prompt, int seqLen, LlamaCallback llamaCallback, bool
    * @param llamaCallback callback object to receive results.
    * @param echo indicate whether to echo the input prompt or not (text completion vs chat)
    */
-  @DoNotStrip
-  public native int generate(
+  public int generate(
       int[] image,
       int width,
       int height,
      int channels,
      String prompt,
      int seqLen,
      LlamaCallback llamaCallback,
-      boolean echo);
+      boolean echo) {
+    return mModule.generate(
+        image,
+        width,
+        height,
+        channels,
+        prompt,
+        seqLen,
+        new LlmCallback() {
+          @Override
+          public void onResult(String result) {
+            llamaCallback.onResult(result);
+          }
+
+          @Override
+          public void onStats(float tps) {
+            llamaCallback.onStats(tps);
+          }
+        },
+        echo);
+  }
 
   /**
    * Prefill an LLaVA Module with the given images input.
@@ -142,17 +147,9 @@ public native int generate(
    * @throws RuntimeException if the prefill failed
    */
   public long prefillImages(int[] image, int width, int height, int channels, long startPos) {
-    long[] nativeResult = prefillImagesNative(image, width, height, channels, startPos);
-    if (nativeResult[0] != 0) {
-      throw new RuntimeException("Prefill failed with error code: " + nativeResult[0]);
-    }
-    return nativeResult[1];
+    return mModule.prefillImages(image, width, height, channels, startPos);
   }
 
-  // returns a tuple of (status, updated startPos)
-  private native long[] prefillImagesNative(
-      int[] image, int width, int height, int channels, long startPos);
-
   /**
    * Prefill an LLaVA Module with the given text input.
    *
@@ -165,16 +162,9 @@ private native long[] prefillImagesNative(
    * @throws RuntimeException if the prefill failed
   */
   public long prefillPrompt(String prompt, long startPos, int bos, int eos) {
-    long[] nativeResult = prefillPromptNative(prompt, startPos, bos, eos);
-    if (nativeResult[0] != 0) {
-      throw new RuntimeException("Prefill failed with error code: " + nativeResult[0]);
-    }
-    return nativeResult[1];
+    return mModule.prefillPrompt(prompt, startPos, bos, eos);
   }
 
-  // returns a tuple of (status, updated startPos)
-  private native long[] prefillPromptNative(String prompt, long startPos, int bos, int eos);
-
   /**
    * Generate tokens from the given prompt, starting from the given position.
    *
@@ -185,14 +175,33 @@ public long prefillPrompt(String prompt, long startPos, int bos, int eos) {
    * @param echo indicate whether to echo the input prompt or not.
    * @return The error code.
    */
-  public native int generateFromPos(
-      String prompt, int seqLen, long startPos, LlamaCallback callback, boolean echo);
+  public int generateFromPos(
+      String prompt, int seqLen, long startPos, LlamaCallback callback, boolean echo) {
+    return mModule.generateFromPos(
+        prompt,
+        seqLen,
+        startPos,
+        new LlmCallback() {
+          @Override
+          public void onResult(String result) {
+            callback.onResult(result);
+          }
+
+          @Override
+          public void onStats(float tps) {
+            callback.onStats(tps);
+          }
+        },
+        echo);
+  }
 
   /** Stop current generate() before it finishes. */
-  @DoNotStrip
-  public native void stop();
+  public void stop() {
+    mModule.stop();
+  }
 
   /** Force loading the module. Otherwise the model is loaded during first generate(). */
-  @DoNotStrip
-  public native int load();
+  public int load() {
+    return mModule.load();
+  }
 }
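For callers, the practical effect of this migration: org.pytorch.executorch.LlamaModule and LlamaCallback still compile, but only as deprecated shims over the new extension.llm package. A minimal migration sketch follows — the file paths and class name are placeholders, and only the LlmModule constructors and the eight-argument generate() visible in the wrapper above are assumed to exist:

import org.pytorch.executorch.extension.llm.LlmCallback;
import org.pytorch.executorch.extension.llm.LlmModule;

public class MigrationSketch {
  public static void main(String[] args) {
    // Hypothetical paths; point these at your exported .pte model and tokenizer.
    LlmModule module =
        new LlmModule("/data/local/tmp/model.pte", "/data/local/tmp/tokenizer.bin", 0.8f);
    module.load(); // optional: load eagerly instead of on the first generate()

    // Mirrors the deprecated wrapper's delegation: text-only generation passes
    // (null, 0, 0, 0) for the image arguments of the multimodal generate().
    module.generate(
        null,
        0,
        0,
        0,
        "Hello, world",
        128, // seqLen, the wrapper's DEFAULT_SEQ_LEN
        new LlmCallback() {
          @Override
          public void onResult(String result) {
            System.out.print(result); // one generated token per invocation
          }

          @Override
          public void onStats(float tps) {
            System.out.println("\ntokens/sec: " + tps);
          }
        },
        true); // echo, the wrapper's DEFAULT_ECHO
    module.resetNative(); // release the native module when done
  }
}

This is exactly the call the deprecated generate(prompt, llamaCallback) overload now forwards, so behavior should be unchanged for existing code.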
extension/android/src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.java

+38

@@ -0,0 +1,38 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+package org.pytorch.executorch.extension.llm;
+
+import com.facebook.jni.annotations.DoNotStrip;
+import org.pytorch.executorch.annotations.Experimental;
+
+/**
+ * Callback interface for Llama model. Users can implement this interface to receive the generated
+ * tokens and statistics.
+ *
+ * <p>Warning: These APIs are experimental and subject to change without notice
+ */
+@Experimental
+public interface LlmCallback {
+  /**
+   * Called when a new result is available from JNI. Users will keep getting onResult() invocations
+   * until generate() finishes.
+   *
+   * @param result Last generated token
+   */
+  @DoNotStrip
+  public void onResult(String result);
+
+  /**
+   * Called when the statistics for the generate() is available.
+   *
+   * @param tps Tokens/second for generated tokens.
+   */
+  @DoNotStrip
+  public void onStats(float tps);
+}
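Since LlmCallback is the piece most apps implement, here is a small illustrative implementation that buffers the streamed tokens and records the reported throughput; the class name and buffering behavior are this sketch's own, not part of the commit:

import org.pytorch.executorch.extension.llm.LlmCallback;

/** Illustrative sketch: collects streamed tokens and the final tokens/sec. */
public class BufferingLlmCallback implements LlmCallback {
  private final StringBuilder text = new StringBuilder();
  private float tokensPerSecond;

  @Override
  public void onResult(String result) {
    // Invoked once per generated token until generate() returns.
    text.append(result);
  }

  @Override
  public void onStats(float tps) {
    // Tokens/second for the generated tokens, reported when stats are ready.
    tokensPerSecond = tps;
  }

  public String getText() {
    return text.toString();
  }

  public float getTokensPerSecond() {
    return tokensPerSecond;
  }
}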
