diff --git a/extension/benchmark/android/benchmark/README.md b/extension/benchmark/android/benchmark/README.md
index f6731023f4..a5cdd22774 100644
--- a/extension/benchmark/android/benchmark/README.md
+++ b/extension/benchmark/android/benchmark/README.md
@@ -43,13 +43,13 @@ adb push tokenizer.bin /data/local/tmp/minibench
### Generic model
```
-adb shell am start -W -S -n org.pytorch.minibench/org.pytorch.minibench.LlmBenchmarkActivity \
+adb shell am start -W -S -n org.pytorch.minibench/org.pytorch.minibench.BenchmarkActivity \
--es model_dir /data/local/tmp/minibench
```
### LLM
```
-adb shell am start -W -S -n org.pytorch.minibench/org.pytorch.minibench.LlmBenchmarkActivity \
+adb shell am start -W -S -n org.pytorch.minibench/org.pytorch.minibench.BenchmarkActivity \
--es model_dir /data/local/tmp/minibench --es tokenizer_path /data/local/tmp/minibench/tokenizer.bin
```
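
When the run finishes, the app writes its metrics to `benchmark_results.json` in its private files directory (see `writeResult()` in `BenchmarkActivity` below). A sketch of one way to retrieve it, assuming a debuggable build so `run-as` is permitted:

```
adb shell run-as org.pytorch.minibench cat files/benchmark_results.json
```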
diff --git a/extension/benchmark/android/benchmark/android-llm-device-farm-test-spec.yml.j2 b/extension/benchmark/android/benchmark/android-llm-device-farm-test-spec.yml.j2
index 4f8e72d21b..7d668e90c8 100644
--- a/extension/benchmark/android/benchmark/android-llm-device-farm-test-spec.yml.j2
+++ b/extension/benchmark/android/benchmark/android-llm-device-farm-test-spec.yml.j2
@@ -114,11 +114,11 @@ phases:
adb -s $DEVICEFARM_DEVICE_UDID shell sleep 180
if [ -n "$BIN_FOUND" ]; then
- adb -s $DEVICEFARM_DEVICE_UDID shell am start -W -n org.pytorch.minibench/.LlmBenchmarkActivity \
+ adb -s $DEVICEFARM_DEVICE_UDID shell am start -W -n org.pytorch.minibench/.BenchmarkActivity \
--es "model_dir" "/data/local/tmp/minibench" \
--es "tokenizer_path" "/data/local/tmp/minibench/tokenizer.bin"
elif [ -n "$MODEL_FOUND" ]; then
- adb -s $DEVICEFARM_DEVICE_UDID shell am start -W -n org.pytorch.minibench/.LlmBenchmarkActivity \
+ adb -s $DEVICEFARM_DEVICE_UDID shell am start -W -n org.pytorch.minibench/.BenchmarkActivity \
--es "model_dir" "/data/local/tmp/minibench" \
--es "tokenizer_path" "/data/local/tmp/minibench/tokenizer.model"
else
diff --git a/extension/benchmark/android/benchmark/app/build.gradle.kts b/extension/benchmark/android/benchmark/app/build.gradle.kts
index 28dfc8ae49..4ee7efd1f9 100644
--- a/extension/benchmark/android/benchmark/app/build.gradle.kts
+++ b/extension/benchmark/android/benchmark/app/build.gradle.kts
@@ -6,7 +6,10 @@
* LICENSE file in the root directory of this source tree.
*/
-plugins { id("com.android.application") }
+plugins { id("com.android.application")
+ id("org.jetbrains.kotlin.android")
+}
android {
namespace = "org.pytorch.minibench"
@@ -29,8 +31,11 @@ android {
}
}
compileOptions {
- sourceCompatibility = JavaVersion.VERSION_1_8
- targetCompatibility = JavaVersion.VERSION_1_8
+ sourceCompatibility = JavaVersion.VERSION_17
+ targetCompatibility = JavaVersion.VERSION_17
+ }
+ kotlinOptions {
+ jvmTarget = "17"
}
}
@@ -40,6 +45,7 @@ dependencies {
implementation("com.facebook.fbjni:fbjni:0.5.1")
implementation("com.google.code.gson:gson:2.8.6")
implementation("org.json:json:20250107")
+ implementation("androidx.core:core-ktx:1.13.1")
testImplementation("junit:junit:4.13.2")
androidTestImplementation("androidx.test.ext:junit:1.2.1")
androidTestImplementation("androidx.test.espresso:espresso-core:3.6.1")
diff --git a/extension/benchmark/android/benchmark/app/src/main/AndroidManifest.xml b/extension/benchmark/android/benchmark/app/src/main/AndroidManifest.xml
index 7f62c509d5..723829de98 100644
--- a/extension/benchmark/android/benchmark/app/src/main/AndroidManifest.xml
+++ b/extension/benchmark/android/benchmark/app/src/main/AndroidManifest.xml
@@ -21,14 +21,6 @@
- <activity
- android:name=".LlmBenchmarkActivity"
- android:exported="true">
- <intent-filter>
- <action android:name="android.intent.action.MAIN" />
- </intent-filter>
- </activity>
-
diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java
index 78830d5a54..5e1dd48926 100644
--- a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java
+++ b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java
@@ -10,9 +10,10 @@
import android.app.Activity;
import android.content.Intent;
-import android.os.AsyncTask;
import android.os.Bundle;
-import android.os.Debug;
+import android.os.Handler;
+import android.os.HandlerThread;
+import android.os.Looper;
import android.system.ErrnoException;
import android.system.Os;
import com.google.gson.Gson;
@@ -21,12 +22,22 @@
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
-import java.util.Collections;
import java.util.List;
-import java.util.stream.Collectors;
-import org.pytorch.executorch.Module;
public class BenchmarkActivity extends Activity {
+
+ File mModel;
+ int mNumIter;
+ int mNumWarmupIter;
+ String mTokenizerPath;
+ float mTemperature;
+ String mPrompt;
+
+ HandlerThread mHandlerThread;
+ BenchmarkHandler mHandler;
+
+ List<BenchmarkMetric> mResult;
+
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
@@ -47,95 +58,79 @@ protected void onCreate(Bundle savedInstanceState) {
int numIter = intent.getIntExtra("num_iter", 50);
int numWarmupIter = intent.getIntExtra("num_warm_up_iter", 10);
+ String tokenizerPath = intent.getStringExtra("tokenizer_path");
+ float temperature = intent.getFloatExtra("temperature", 0.8f);
+ String prompt = intent.getStringExtra("prompt");
+
+ mModel = model;
+ mNumIter = numIter;
+ mNumWarmupIter = numWarmupIter;
+ mTokenizerPath = tokenizerPath;
+ mTemperature = temperature;
+ mPrompt = prompt;
+ if (mPrompt == null) {
+ mPrompt = "The ultimate answer";
+ }
+ mResult = new ArrayList<>();
- long pssIdle = Debug.getPss();
+ mHandlerThread = new HandlerThread("ModelRunner");
+ mHandlerThread.start();
+ mHandler = new BenchmarkHandler(mHandlerThread.getLooper(), this);
- // TODO: Format the string with a parsable format
- Stats stats = new Stats();
+ mHandler.sendEmptyMessage(BenchmarkHandler.MESSAGE_RUN_BENCHMARK);
+ }
- new AsyncTask<Void, Void, Void>() {
- @Override
- protected Void doInBackground(Void... voids) {
+ void writeResult() {
+ try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) {
+ Gson gson = new Gson();
+ writer.write(gson.toJson(mResult));
+ } catch (IOException e) {
+ e.printStackTrace();
+ } finally {
+ finish();
+ }
+ }
+}
- // Record the time it takes to load the model and the forward method
- stats.loadStart = System.nanoTime();
- Module module = Module.load(model.getPath());
- stats.errorCode = module.loadMethod("forward");
- stats.loadEnd = System.nanoTime();
+class BenchmarkHandler extends Handler {
+ public static final int MESSAGE_RUN_BENCHMARK = 1;
+ public static final int MESSAGE_LLM_RUN_BENCHMARK = 2;
- for (int i = 0; i < numWarmupIter; i++) {
- module.forward();
- }
+ ModelRunner mModelRunner;
+ BenchmarkActivity mBenchmarkActivity;
- for (int i = 0; i < numIter; i++) {
- long start = System.nanoTime();
- module.forward();
- double forwardMs = (System.nanoTime() - start) * 1e-6;
- stats.latency.add(forwardMs);
- }
- return null;
- }
+ LlmModelRunner mLlmModelRunner;
+ LlmBenchmark mLlmBenchmark;
- @Override
- protected void onPostExecute(Void aVoid) {
-
- final BenchmarkMetric.BenchmarkModel benchmarkModel =
- BenchmarkMetric.extractBackendAndQuantization(model.getName().replace(".pte", ""));
- final List<BenchmarkMetric> results = new ArrayList<>();
- // The list of metrics we have atm includes:
- // Avg inference latency after N iterations
- // Currently the result has large variance from outliers, so only use
- // 80% samples in the middle (trimmean 0.2)
- Collections.sort(stats.latency);
- int resultSize = stats.latency.size();
- List<Double> usedLatencyResults =
- stats.latency.subList(resultSize / 10, resultSize * 9 / 10);
-
- results.add(
- new BenchmarkMetric(
- benchmarkModel,
- "avg_inference_latency(ms)",
- stats.latency.stream().mapToDouble(l -> l).average().orElse(0.0f),
- 0.0f));
- results.add(
- new BenchmarkMetric(
- benchmarkModel,
- "trimmean_inference_latency(ms)",
- usedLatencyResults.stream().mapToDouble(l -> l).average().orElse(0.0f),
- 0.0f));
- // Model load time
- results.add(
- new BenchmarkMetric(
- benchmarkModel,
- "model_load_time(ms)",
- (stats.loadEnd - stats.loadStart) * 1e-6,
- 0.0f));
- // Load status
- results.add(new BenchmarkMetric(benchmarkModel, "load_status", stats.errorCode, 0));
- // RAM PSS usage
- results.add(
- new BenchmarkMetric(
- benchmarkModel, "ram_pss_usage(mb)", (Debug.getPss() - pssIdle) / 1024, 0));
-
- try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) {
- Gson gson = new Gson();
- writer.write(gson.toJson(results));
- } catch (IOException e) {
- e.printStackTrace();
- }
- }
- }.execute();
+ public BenchmarkHandler(Looper looper, BenchmarkActivity benchmarkActivity) {
+ super(looper);
+ mModelRunner = new ModelRunner();
+ mBenchmarkActivity = benchmarkActivity;
}
-}
-
-class Stats {
- long loadStart;
- long loadEnd;
- List<Double> latency = new ArrayList<>();
- int errorCode = 0;
@Override
- public String toString() {
- return "latency: " + latency.stream().map(Object::toString).collect(Collectors.joining(""));
+ public void handleMessage(android.os.Message msg) {
+ if (msg.what == MESSAGE_RUN_BENCHMARK) {
+ mModelRunner.runBenchmark(
+ mBenchmarkActivity.mModel,
+ mBenchmarkActivity.mNumWarmupIter,
+ mBenchmarkActivity.mNumIter,
+ mBenchmarkActivity.mResult);
+
+ if (mBenchmarkActivity.mTokenizerPath == null) {
+ mBenchmarkActivity.writeResult();
+ } else {
+ this.sendEmptyMessage(MESSAGE_LLM_RUN_BENCHMARK);
+ }
+ } else if (msg.what == MESSAGE_LLM_RUN_BENCHMARK) {
+ mLlmBenchmark =
+ new LlmBenchmark(
+ mBenchmarkActivity,
+ mBenchmarkActivity.mModel.getPath(),
+ mBenchmarkActivity.mTokenizerPath,
+ mBenchmarkActivity.mPrompt,
+ mBenchmarkActivity.mTemperature,
+ mBenchmarkActivity.mResult);
+ }
}
}
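
With the two activities merged, one launch command drives both passes; each extra below maps to a `getIntExtra`/`getFloatExtra`/`getStringExtra` call in `onCreate` above, and the values shown are the code's own defaults:

```
adb shell am start -W -S -n org.pytorch.minibench/.BenchmarkActivity \
  --es model_dir /data/local/tmp/minibench \
  --es tokenizer_path /data/local/tmp/minibench/tokenizer.bin \
  --es prompt "The ultimate answer" \
  --ef temperature 0.8 \
  --ei num_iter 50 --ei num_warm_up_iter 10
```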
diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmarkActivity.java b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmark.java
similarity index 57%
rename from extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmarkActivity.java
rename to extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmark.java
index f6a894d6a1..0c0436d267 100644
--- a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmarkActivity.java
+++ b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmark.java
@@ -8,57 +8,33 @@
package org.pytorch.minibench;
-import android.app.Activity;
-import android.content.Intent;
-import android.os.Bundle;
-import android.system.ErrnoException;
-import android.system.Os;
import android.util.Log;
-import com.google.gson.Gson;
-import java.io.File;
-import java.io.FileWriter;
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.Arrays;
import java.util.List;
import org.json.JSONException;
import org.json.JSONObject;
-public class LlmBenchmarkActivity extends Activity implements ModelRunnerCallback {
- ModelRunner mModelRunner;
+public class LlmBenchmark implements LlmModelRunnerCallback {
+ LlmModelRunner mLlmModelRunner;
String mPrompt;
StatsInfo mStatsInfo;
- @Override
- protected void onCreate(Bundle savedInstanceState) {
- super.onCreate(savedInstanceState);
-
- Intent intent = getIntent();
-
- File modelDir = new File(intent.getStringExtra("model_dir"));
- File model =
- Arrays.stream(modelDir.listFiles())
- .filter(file -> file.getName().endsWith(".pte"))
- .findFirst()
- .get();
- String tokenizerPath = intent.getStringExtra("tokenizer_path");
-
- float temperature = intent.getFloatExtra("temperature", 0.8f);
- mPrompt = intent.getStringExtra("prompt");
- if (mPrompt == null) {
- mPrompt = "The ultimate answer";
- }
-
- try {
- Os.setenv("ADSP_LIBRARY_PATH", getApplicationInfo().nativeLibraryDir, true);
- } catch (ErrnoException e) {
- finish();
- }
-
+ List<BenchmarkMetric> mResults;
+ BenchmarkActivity mActivity;
+
+ LlmBenchmark(
+ BenchmarkActivity activity,
+ String modelFile,
+ String tokenizerPath,
+ String prompt,
+ float temperature,
+ List<BenchmarkMetric> results) {
+ mResults = results;
+ mActivity = activity;
mStatsInfo = new StatsInfo();
- mStatsInfo.modelName = model.getName().replace(".pte", "");
- mModelRunner = new ModelRunner(model.getPath(), tokenizerPath, temperature, this);
+ mStatsInfo.modelName = modelFile.substring(modelFile.lastIndexOf('/') + 1).replace(".pte", "");
+ mPrompt = prompt;
+ mLlmModelRunner = new LlmModelRunner(modelFile, tokenizerPath, temperature, this);
mStatsInfo.loadStart = System.nanoTime();
}
@@ -72,7 +48,7 @@ public void onModelLoaded(int status) {
return;
}
mStatsInfo.generateStart = System.nanoTime();
- mModelRunner.generate(mPrompt);
+ mLlmModelRunner.generate(mPrompt);
}
@Override
@@ -99,33 +75,26 @@ public void onGenerationStopped() {
final BenchmarkMetric.BenchmarkModel benchmarkModel =
BenchmarkMetric.extractBackendAndQuantization(mStatsInfo.modelName);
- final List<BenchmarkMetric> results = new ArrayList<>();
// The list of metrics we have atm includes:
// Load status
- results.add(new BenchmarkMetric(benchmarkModel, "load_status", mStatsInfo.loadStatus, 0));
+ mResults.add(new BenchmarkMetric(benchmarkModel, "load_status", mStatsInfo.loadStatus, 0));
// Model load time
- results.add(
+ mResults.add(
new BenchmarkMetric(
benchmarkModel,
- "model_load_time(ms)",
+ "llm_model_load_time(ms)",
(mStatsInfo.loadEnd - mStatsInfo.loadStart) * 1e-6,
0.0f));
// LLM generate time
- results.add(
+ mResults.add(
new BenchmarkMetric(
benchmarkModel,
"generate_time(ms)",
(mStatsInfo.generateEnd - mStatsInfo.generateStart) * 1e-6,
0.0f));
// Token per second
- results.add(new BenchmarkMetric(benchmarkModel, "token_per_sec", mStatsInfo.tps, 0.0f));
-
- try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) {
- Gson gson = new Gson();
- writer.write(gson.toJson(results));
- } catch (IOException e) {
- e.printStackTrace();
- }
+ mResults.add(new BenchmarkMetric(benchmarkModel, "token_per_sec", mStatsInfo.tps, 0.0f));
+ mActivity.writeResult();
}
}
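
Note the metric rename from `model_load_time(ms)` to `llm_model_load_time(ms)`: both passes now append to one shared result list written to a single file, so the LLM load time needs a distinct key to avoid colliding with the generic pass's `model_load_time(ms)` entry.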
diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmModelRunner.java b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmModelRunner.java
new file mode 100644
index 0000000000..a1b434a37b
--- /dev/null
+++ b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmModelRunner.java
@@ -0,0 +1,98 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+package org.pytorch.minibench;
+
+import android.os.Handler;
+import android.os.HandlerThread;
+import android.os.Looper;
+import android.os.Message;
+import org.pytorch.executorch.extension.llm.LlmCallback;
+import org.pytorch.executorch.extension.llm.LlmModule;
+
+/** A helper class that encapsulates LLM model loading and generation on a worker thread. */
+public class LlmModelRunner implements LlmCallback {
+ LlmModule mModule = null;
+
+ String mModelFilePath = "";
+ String mTokenizerFilePath = "";
+
+ LlmModelRunnerCallback mCallback = null;
+
+ HandlerThread mHandlerThread = null;
+ Handler mHandler = null;
+
+ /**
+ * Helper class that separates UI logic from model runner logic. Automatically handles
+ * generate() requests on a worker thread.
+ *
+ * @param modelFilePath path to the .pte model file
+ * @param tokenizerFilePath path to the tokenizer file
+ * @param temperature sampling temperature passed to the LlmModule
+ * @param callback receives model load, token, stats, and generation-stop events
+ */
+ LlmModelRunner(
+ String modelFilePath,
+ String tokenizerFilePath,
+ float temperature,
+ LlmModelRunnerCallback callback) {
+ mModelFilePath = modelFilePath;
+ mTokenizerFilePath = tokenizerFilePath;
+ mCallback = callback;
+
+ mModule = new LlmModule(mModelFilePath, mTokenizerFilePath, temperature);
+ mHandlerThread = new HandlerThread("LlmModelRunner");
+ mHandlerThread.start();
+ mHandler = new LlmModelRunnerHandler(mHandlerThread.getLooper(), this);
+
+ mHandler.sendEmptyMessage(LlmModelRunnerHandler.MESSAGE_LOAD_MODEL);
+ }
+
+ int generate(String prompt) {
+ Message msg = Message.obtain(mHandler, LlmModelRunnerHandler.MESSAGE_GENERATE, prompt);
+ msg.sendToTarget();
+ return 0;
+ }
+
+ void stop() {
+ mModule.stop();
+ }
+
+ @Override
+ public void onResult(String result) {
+ mCallback.onTokenGenerated(result);
+ }
+
+ @Override
+ public void onStats(String result) {
+ mCallback.onStats(result);
+ }
+}
+
+class LlmModelRunnerHandler extends Handler {
+ public static final int MESSAGE_LOAD_MODEL = 1;
+ public static final int MESSAGE_GENERATE = 2;
+
+ private final LlmModelRunner mLlmModelRunner;
+
+ public LlmModelRunnerHandler(Looper looper, LlmModelRunner llmModelRunner) {
+ super(looper);
+ mLlmModelRunner = llmModelRunner;
+ }
+
+ @Override
+ public void handleMessage(android.os.Message msg) {
+ if (msg.what == MESSAGE_LOAD_MODEL) {
+ int status = mLlmModelRunner.mModule.load();
+ mLlmModelRunner.mCallback.onModelLoaded(status);
+ } else if (msg.what == MESSAGE_GENERATE) {
+ mLlmModelRunner.mModule.generate((String) msg.obj, mLlmModelRunner);
+ mLlmModelRunner.mCallback.onGenerationStopped();
+ }
+ }
+}
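
For orientation, a minimal sketch of a hypothetical caller wiring up `LlmModelRunner` directly (the class name and paths are placeholders; in this change the only real caller is `LlmBenchmark`):

```java
// Hypothetical usage sketch, not part of this change.
class LlmRunnerDemo implements LlmModelRunnerCallback {
  LlmModelRunner mRunner;

  void start() {
    // Constructing the runner spins up its worker thread and immediately
    // posts MESSAGE_LOAD_MODEL; onModelLoaded() fires when loading is done.
    mRunner =
        new LlmModelRunner(
            "/data/local/tmp/minibench/model.pte", // placeholder path
            "/data/local/tmp/minibench/tokenizer.bin", // placeholder path
            0.8f,
            this);
  }

  @Override
  public void onModelLoaded(int status) {
    if (status == 0) {
      mRunner.generate("The ultimate answer"); // enqueues MESSAGE_GENERATE
    }
  }

  @Override
  public void onTokenGenerated(String token) {
    // Called once per generated token, from the worker thread.
  }

  @Override
  public void onStats(String result) {
    // Raw stats string from the runtime, from the worker thread.
  }

  @Override
  public void onGenerationStopped() {
    // Generation finished; safe to collect metrics or tear down.
  }
}
```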
diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunnerCallback.java b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmModelRunnerCallback.kt
similarity index 62%
rename from extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunnerCallback.java
rename to extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmModelRunnerCallback.kt
index 8503d47ccc..cd2fecdf81 100644
--- a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunnerCallback.java
+++ b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmModelRunnerCallback.kt
@@ -6,19 +6,21 @@
* LICENSE file in the root directory of this source tree.
*/
-package org.pytorch.minibench;
+
+package org.pytorch.minibench
+
/**
* A helper interface within the app for MainActivity and Benchmarking to handle callback from
* ModelRunner.
*/
-public interface ModelRunnerCallback {
+interface LlmModelRunnerCallback {
- void onModelLoaded(int status);
+ fun onModelLoaded(status: Int)
- void onTokenGenerated(String token);
+ fun onTokenGenerated(token: String)
- void onStats(String result);
+ fun onStats(result: String)
- void onGenerationStopped();
+ fun onGenerationStopped()
}
diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java
index 0a75b47f3a..3913a8d76f 100644
--- a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java
+++ b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java
@@ -8,90 +8,70 @@
package org.pytorch.minibench;
-import android.os.Handler;
-import android.os.HandlerThread;
-import android.os.Looper;
-import android.os.Message;
-import org.pytorch.executorch.extension.llm.LlmCallback;
-import org.pytorch.executorch.extension.llm.LlmModule;
-
-/** A helper class to handle all model running logic within this class. */
-public class ModelRunner implements LlmCallback {
- LlmModule mModule = null;
-
- String mModelFilePath = "";
- String mTokenizerFilePath = "";
-
- ModelRunnerCallback mCallback = null;
-
- HandlerThread mHandlerThread = null;
- Handler mHandler = null;
-
+import android.os.Debug;
+import java.io.File;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import org.pytorch.executorch.Module;
+
+public class ModelRunner {
/**
- * ] Helper class to separate between UI logic and model runner logic. Automatically handle
- * generate() request on worker thread.
- *
- * @param modelFilePath
- * @param tokenizerFilePath
- * @param callback
+ * Runs the benchmark on the given model and appends {@link BenchmarkMetric} entries to {@code results}.
*/
- ModelRunner(
- String modelFilePath,
- String tokenizerFilePath,
- float temperature,
- ModelRunnerCallback callback) {
- mModelFilePath = modelFilePath;
- mTokenizerFilePath = tokenizerFilePath;
- mCallback = callback;
-
- mModule = new LlmModule(mModelFilePath, mTokenizerFilePath, 0.8f);
- mHandlerThread = new HandlerThread("ModelRunner");
- mHandlerThread.start();
- mHandler = new ModelRunnerHandler(mHandlerThread.getLooper(), this);
-
- mHandler.sendEmptyMessage(ModelRunnerHandler.MESSAGE_LOAD_MODEL);
- }
-
- int generate(String prompt) {
- Message msg = Message.obtain(mHandler, ModelRunnerHandler.MESSAGE_GENERATE, prompt);
- msg.sendToTarget();
- return 0;
- }
+ public void runBenchmark(
+ File model, int numWarmupIter, int numIter, List<BenchmarkMetric> results) {
+ long pssIdle = Debug.getPss();
- void stop() {
- mModule.stop();
- }
-
- @Override
- public void onResult(String result) {
- mCallback.onTokenGenerated(result);
- }
-
- @Override
- public void onStats(String result) {
- mCallback.onStats(result);
- }
-}
-
-class ModelRunnerHandler extends Handler {
- public static int MESSAGE_LOAD_MODEL = 1;
- public static int MESSAGE_GENERATE = 2;
+ List<Double> latency = new ArrayList<>();
- private final ModelRunner mModelRunner;
+ long loadStart = System.nanoTime();
+ Module module = Module.load(model.getPath());
+ int errorCode = module.loadMethod("forward");
+ long loadEnd = System.nanoTime();
- public ModelRunnerHandler(Looper looper, ModelRunner modelRunner) {
- super(looper);
- mModelRunner = modelRunner;
- }
+ for (int i = 0; i < numWarmupIter; i++) {
+ module.forward();
+ }
- @Override
- public void handleMessage(android.os.Message msg) {
- if (msg.what == MESSAGE_LOAD_MODEL) {
- int status = mModelRunner.mModule.load();
- mModelRunner.mCallback.onModelLoaded(status);
- } else if (msg.what == MESSAGE_GENERATE) {
- mModelRunner.mModule.generate((String) msg.obj, mModelRunner);
- mModelRunner.mCallback.onGenerationStopped();
+ for (int i = 0; i < numIter; i++) {
+ long start = System.nanoTime();
+ module.forward();
+ double forwardMs = (System.nanoTime() - start) * 1e-6;
+ latency.add(forwardMs);
}
+
+ final BenchmarkMetric.BenchmarkModel benchmarkModel =
+ BenchmarkMetric.extractBackendAndQuantization(model.getName().replace(".pte", ""));
+ // The list of metrics we have atm includes:
+ // Avg inference latency after N iterations
+ // Currently the result has large variance from outliers, so only use
+ // 80% samples in the middle (trimmean 0.2)
+ Collections.sort(latency);
+ int resultSize = latency.size();
+ List<Double> usedLatencyResults = latency.subList(resultSize / 10, resultSize * 9 / 10);
+
+ results.add(
+ new BenchmarkMetric(
+ benchmarkModel,
+ "avg_inference_latency(ms)",
+ latency.stream().mapToDouble(l -> l).average().orElse(0.0f),
+ 0.0f));
+ results.add(
+ new BenchmarkMetric(
+ benchmarkModel,
+ "trimmean_inference_latency(ms)",
+ usedLatencyResults.stream().mapToDouble(l -> l).average().orElse(0.0f),
+ 0.0f));
+ // Model load time
+ results.add(
+ new BenchmarkMetric(
+ benchmarkModel, "model_load_time(ms)", (loadEnd - loadStart) * 1e-6, 0.0f));
+ // Load status
+ results.add(new BenchmarkMetric(benchmarkModel, "load_status", errorCode, 0));
+ // RAM PSS usage
+ results.add(
+ new BenchmarkMetric(
+ benchmarkModel, "ram_pss_usage(mb)", (Debug.getPss() - pssIdle) / 1024, 0));
}
}
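
On the trim arithmetic: `subList(resultSize / 10, resultSize * 9 / 10)` drops the lowest and highest deciles, keeping the middle ~80% of samples, which is what the "trimmean 0.2" comment refers to. A self-contained sketch of the same computation:

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

class TrimmedMean {
  // Mirrors ModelRunner's trimming: sort, drop the bottom and top 10%,
  // then average the remaining middle ~80% of samples.
  static double trimmean20(List<Double> samples) {
    List<Double> sorted = new ArrayList<>(samples);
    Collections.sort(sorted);
    int n = sorted.size();
    List<Double> middle = sorted.subList(n / 10, n * 9 / 10);
    return middle.stream().mapToDouble(d -> d).average().orElse(0.0);
  }

  public static void main(String[] args) {
    // 50 samples: indices 5..44 survive, i.e. 40 of 50 (80%).
    List<Double> latency = new ArrayList<>();
    for (int i = 1; i <= 50; i++) latency.add((double) i);
    System.out.println(trimmean20(latency)); // prints 25.5
  }
}
```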
diff --git a/extension/benchmark/android/benchmark/build.gradle.kts b/extension/benchmark/android/benchmark/build.gradle.kts
index ac625be8e0..b1ed5127df 100644
--- a/extension/benchmark/android/benchmark/build.gradle.kts
+++ b/extension/benchmark/android/benchmark/build.gradle.kts
@@ -7,4 +7,7 @@
*/
// Top-level build file where you can add configuration options common to all sub-projects/modules.
-plugins { id("com.android.application") version "8.1.0" apply false }
+plugins { id("com.android.application") version "8.1.0" apply false
+ id("org.jetbrains.kotlin.android") version "2.1.10" apply false
+}