 package org.pytorch.executorch;

-import com.facebook.jni.HybridData;
-import com.facebook.jni.annotations.DoNotStrip;
-import com.facebook.soloader.nativeloader.NativeLoader;
-import com.facebook.soloader.nativeloader.SystemDelegate;
-import org.pytorch.executorch.annotations.Experimental;
+import org.pytorch.executorch.extension.llm.LlmCallback;
+import org.pytorch.executorch.extension.llm.LlmModule;

 /**
  * LlamaModule is a wrapper around the Executorch Llama model. It provides a simple interface to
  * generate text from the model.
  *
- * <p>Warning: These APIs are experimental and subject to change without notice
+ * <p>Note: deprecated! Please use {@link org.pytorch.executorch.extension.llm.LlmModule} instead.
  */
-@Experimental
+@Deprecated
 public class LlamaModule {

   public static final int MODEL_TYPE_TEXT = 1;
   public static final int MODEL_TYPE_TEXT_VISION = 2;

-  static {
-    if (!NativeLoader.isInitialized()) {
-      NativeLoader.init(new SystemDelegate());
-    }
-    NativeLoader.loadLibrary("executorch");
-  }
-
-  private final HybridData mHybridData;
+  private LlmModule mModule;

   private static final int DEFAULT_SEQ_LEN = 128;
   private static final boolean DEFAULT_ECHO = true;

-  @DoNotStrip
-  private static native HybridData initHybrid(
-      int modelType, String modulePath, String tokenizerPath, float temperature, String dataPath);
-
   /** Constructs a LLAMA Module for a model with given model path, tokenizer, temperature. */
   public LlamaModule(String modulePath, String tokenizerPath, float temperature) {
-    mHybridData = initHybrid(MODEL_TYPE_TEXT, modulePath, tokenizerPath, temperature, null);
+    mModule = new LlmModule(modulePath, tokenizerPath, temperature);
   }

   /**
    * Constructs a LLAMA Module for a model with given model path, tokenizer, temperature and data
    * path.
    */
   public LlamaModule(String modulePath, String tokenizerPath, float temperature, String dataPath) {
-    mHybridData = initHybrid(MODEL_TYPE_TEXT, modulePath, tokenizerPath, temperature, dataPath);
+    mModule = new LlmModule(modulePath, tokenizerPath, temperature, dataPath);
   }

   /** Constructs a LLM Module for a model with given path, tokenizer, and temperature. */
   public LlamaModule(int modelType, String modulePath, String tokenizerPath, float temperature) {
-    mHybridData = initHybrid(modelType, modulePath, tokenizerPath, temperature, null);
+    mModule = new LlmModule(modelType, modulePath, tokenizerPath, temperature);
   }

   public void resetNative() {
-    mHybridData.resetNative();
+    mModule.resetNative();
   }

   /**
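Migration note: the swap for downstream callers is mechanical, since `LlmModule` keeps the constructor shapes shown above. A minimal sketch (the paths and temperature are illustrative, not part of this diff):

```java
import org.pytorch.executorch.extension.llm.LlmModule;

public class MigrationExample {
  public static void main(String[] args) {
    // Before: LlamaModule module = new LlamaModule(modelPath, tokenizerPath, 0.8f);
    // After: identical arguments, new class.
    LlmModule module =
        new LlmModule("/data/local/tmp/llama.pte", "/data/local/tmp/tokenizer.bin", 0.8f);
    module.load(); // optional: force-load now instead of lazily on first generate()
  }
}
```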
@@ -70,7 +56,7 @@ public void resetNative() {
    * @param llamaCallback callback object to receive results.
    */
   public int generate(String prompt, LlamaCallback llamaCallback) {
-    return generate(prompt, DEFAULT_SEQ_LEN, llamaCallback, DEFAULT_ECHO);
+    return generate(null, 0, 0, 0, prompt, DEFAULT_SEQ_LEN, llamaCallback, DEFAULT_ECHO);
   }

   /**
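Call sites of the one-argument overload are unaffected by the rerouting above; it now passes a null image through the full multimodal signature. A hedged usage sketch, assuming `LlamaCallback` is implementable as the same-package interface this class already references (paths are illustrative):

```java
import org.pytorch.executorch.LlamaCallback;
import org.pytorch.executorch.LlamaModule;

public class GenerateExample {
  public static void main(String[] args) {
    LlamaModule module =
        new LlamaModule("/data/local/tmp/llama.pte", "/data/local/tmp/tokenizer.bin", 0.8f);
    // Uses DEFAULT_SEQ_LEN (128) and echoes the prompt (DEFAULT_ECHO = true).
    module.generate(
        "Hello, ",
        new LlamaCallback() {
          @Override
          public void onResult(String result) {
            System.out.print(result); // streamed per decoded token
          }

          @Override
          public void onStats(float tps) {
            System.out.println("\ntokens/sec: " + tps);
          }
        });
  }
}
```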
@@ -119,16 +105,35 @@ public int generate(String prompt, int seqLen, LlamaCallback llamaCallback, bool
    * @param llamaCallback callback object to receive results.
    * @param echo indicate whether to echo the input prompt or not (text completion vs chat)
    */
-  @DoNotStrip
-  public native int generate(
+  public int generate(
       int[] image,
       int width,
       int height,
       int channels,
       String prompt,
       int seqLen,
       LlamaCallback llamaCallback,
-      boolean echo);
+      boolean echo) {
+    return mModule.generate(
+        image,
+        width,
+        height,
+        channels,
+        prompt,
+        seqLen,
+        new LlmCallback() {
+          @Override
+          public void onResult(String result) {
+            llamaCallback.onResult(result);
+          }
+
+          @Override
+          public void onStats(float tps) {
+            llamaCallback.onStats(tps);
+          }
+        },
+        echo);
+  }

   /**
    * Prefill an LLaVA Module with the given images input.
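Design note on the hunk above: rather than caching a single adapter, each call wraps the caller's `LlamaCallback` in a fresh anonymous `LlmCallback`. The two interfaces are method-for-method identical here (`onResult(String)`, `onStats(float)`), so the adapter is pure forwarding and the deprecated wrapper stays stateless apart from the delegate module itself.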
@@ -142,17 +147,9 @@ public native int generate(
    * @throws RuntimeException if the prefill failed
    */
   public long prefillImages(int[] image, int width, int height, int channels, long startPos) {
-    long[] nativeResult = prefillImagesNative(image, width, height, channels, startPos);
-    if (nativeResult[0] != 0) {
-      throw new RuntimeException("Prefill failed with error code: " + nativeResult[0]);
-    }
-    return nativeResult[1];
+    return mModule.prefillImages(image, width, height, channels, startPos);
   }

-  // returns a tuple of (status, updated startPos)
-  private native long[] prefillImagesNative(
-      int[] image, int width, int height, int channels, long startPos);
-
   /**
    * Prefill an LLaVA Module with the given text input.
    *
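Note: the deleted lines above unpacked a `(status, updated startPos)` pair from the native call and threw on a non-zero status. That check now has to live inside `LlmModule.prefillImages`; the unchanged `@throws RuntimeException` contract suggests the delegate raises the equivalent exception, though that is implied by this diff rather than shown in it. The same applies to `prefillPrompt` below.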
@@ -165,16 +162,9 @@ private native long[] prefillImagesNative(
    * @throws RuntimeException if the prefill failed
    */
   public long prefillPrompt(String prompt, long startPos, int bos, int eos) {
-    long[] nativeResult = prefillPromptNative(prompt, startPos, bos, eos);
-    if (nativeResult[0] != 0) {
-      throw new RuntimeException("Prefill failed with error code: " + nativeResult[0]);
-    }
-    return nativeResult[1];
+    return mModule.prefillPrompt(prompt, startPos, bos, eos);
   }

-  // returns a tuple of (status, updated startPos)
-  private native long[] prefillPromptNative(String prompt, long startPos, int bos, int eos);
-
   /**
    * Generate tokens from the given prompt, starting from the given position.
    *
@@ -185,14 +175,33 @@ public long prefillPrompt(String prompt, long startPos, int bos, int eos) {
    * @param echo indicate whether to echo the input prompt or not.
    * @return The error code.
    */
-  public native int generateFromPos(
-      String prompt, int seqLen, long startPos, LlamaCallback callback, boolean echo);
+  public int generateFromPos(
+      String prompt, int seqLen, long startPos, LlamaCallback callback, boolean echo) {
+    return mModule.generateFromPos(
+        prompt,
+        seqLen,
+        startPos,
+        new LlmCallback() {
+          @Override
+          public void onResult(String result) {
+            callback.onResult(result);
+          }
+
+          @Override
+          public void onStats(float tps) {
+            callback.onStats(tps);
+          }
+        },
+        echo);
+  }

   /** Stop current generate() before it finishes. */
-  @DoNotStrip
-  public native void stop();
+  public void stop() {
+    mModule.stop();
+  }

   /** Force loading the module. Otherwise the model is loaded during first generate(). */
-  @DoNotStrip
-  public native int load();
+  public int load() {
+    return mModule.load();
+  }
 }
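For reference, the LLaVA-style entry points above compose into a prefill-then-decode flow. A sketch under stated assumptions: the image size, start position, file paths, and prompt formatting are illustrative stand-ins, and the real conventions come from the exported model and tokenizer:

```java
import org.pytorch.executorch.LlamaCallback;
import org.pytorch.executorch.LlamaModule;

public class LlavaFlowExample {
  public static void main(String[] args) {
    LlamaModule module =
        new LlamaModule(
            LlamaModule.MODEL_TYPE_TEXT_VISION,
            "/data/local/tmp/llava.pte",
            "/data/local/tmp/tokenizer.bin",
            0.8f);

    int width = 336, height = 336, channels = 3;
    int[] image = new int[width * height * channels]; // flattened pixels (stand-in data)

    // Prefill the context with the image; the return value is the updated position.
    long pos = module.prefillImages(image, width, height, channels, /* startPos */ 0);

    // Decode from the accumulated position; echo=false streams only new tokens.
    module.generateFromPos(
        "What is in this image?",
        /* seqLen */ 128,
        pos,
        new LlamaCallback() {
          @Override
          public void onResult(String result) {
            System.out.print(result);
          }

          @Override
          public void onStats(float tps) {
            System.out.println("\ntokens/sec: " + tps);
          }
        },
        /* echo */ false);
  }
}
```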