diff --git a/extension/benchmark/android/benchmark/README.md b/extension/benchmark/android/benchmark/README.md index f6731023f4..a5cdd22774 100644 --- a/extension/benchmark/android/benchmark/README.md +++ b/extension/benchmark/android/benchmark/README.md @@ -43,13 +43,13 @@ adb push tokenizer.bin /data/local/tmp/minibench ### Generic model ``` -adb shell am start -W -S -n org.pytorch.minibench/org.pytorch.minibench.LlmBenchmarkActivity \ +adb shell am start -W -S -n org.pytorch.minibench/org.pytorch.minibench.BenchmarkActivity \ --es model_dir /data/local/tmp/minibench ``` ### LLM ``` -adb shell am start -W -S -n org.pytorch.minibench/org.pytorch.minibench.LlmBenchmarkActivity \ +adb shell am start -W -S -n org.pytorch.minibench/org.pytorch.minibench.BenchmarkActivity \ --es model_dir /data/local/tmp/minibench --es tokenizer_path /data/local/tmp/minibench/tokenizer.bin ``` diff --git a/extension/benchmark/android/benchmark/android-llm-device-farm-test-spec.yml.j2 b/extension/benchmark/android/benchmark/android-llm-device-farm-test-spec.yml.j2 index 4f8e72d21b..7d668e90c8 100644 --- a/extension/benchmark/android/benchmark/android-llm-device-farm-test-spec.yml.j2 +++ b/extension/benchmark/android/benchmark/android-llm-device-farm-test-spec.yml.j2 @@ -114,11 +114,11 @@ phases: adb -s $DEVICEFARM_DEVICE_UDID shell sleep 180 if [ -n "$BIN_FOUND" ]; then - adb -s $DEVICEFARM_DEVICE_UDID shell am start -W -n org.pytorch.minibench/.LlmBenchmarkActivity \ + adb -s $DEVICEFARM_DEVICE_UDID shell am start -W -n org.pytorch.minibench/.BenchmarkActivity \ --es "model_dir" "/data/local/tmp/minibench" \ --es "tokenizer_path" "/data/local/tmp/minibench/tokenizer.bin" elif [ -n "$MODEL_FOUND" ]; then - adb -s $DEVICEFARM_DEVICE_UDID shell am start -W -n org.pytorch.minibench/.LlmBenchmarkActivity \ + adb -s $DEVICEFARM_DEVICE_UDID shell am start -W -n org.pytorch.minibench/.BenchmarkActivity \ --es "model_dir" "/data/local/tmp/minibench" \ --es "tokenizer_path" "/data/local/tmp/minibench/tokenizer.model" else diff --git a/extension/benchmark/android/benchmark/app/build.gradle.kts b/extension/benchmark/android/benchmark/app/build.gradle.kts index 28dfc8ae49..4ee7efd1f9 100644 --- a/extension/benchmark/android/benchmark/app/build.gradle.kts +++ b/extension/benchmark/android/benchmark/app/build.gradle.kts @@ -6,7 +6,9 @@ * LICENSE file in the root directory of this source tree. */ -plugins { id("com.android.application") } +plugins { id("com.android.application") + id("org.jetbrains.kotlin.android") +} android { namespace = "org.pytorch.minibench" @@ -29,8 +31,11 @@ android { } } compileOptions { - sourceCompatibility = JavaVersion.VERSION_1_8 - targetCompatibility = JavaVersion.VERSION_1_8 + sourceCompatibility = JavaVersion.VERSION_17 + targetCompatibility = JavaVersion.VERSION_17 + } + kotlinOptions { + jvmTarget = "17" } } @@ -40,6 +45,7 @@ dependencies { implementation("com.facebook.fbjni:fbjni:0.5.1") implementation("com.google.code.gson:gson:2.8.6") implementation("org.json:json:20250107") + implementation("androidx.core:core-ktx:1.13.1") testImplementation("junit:junit:4.13.2") androidTestImplementation("androidx.test.ext:junit:1.2.1") androidTestImplementation("androidx.test.espresso:espresso-core:3.6.1") diff --git a/extension/benchmark/android/benchmark/app/src/main/AndroidManifest.xml b/extension/benchmark/android/benchmark/app/src/main/AndroidManifest.xml index 7f62c509d5..723829de98 100644 --- a/extension/benchmark/android/benchmark/app/src/main/AndroidManifest.xml +++ b/extension/benchmark/android/benchmark/app/src/main/AndroidManifest.xml @@ -21,14 +21,6 @@ - - - - - - diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java index 78830d5a54..5e1dd48926 100644 --- a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java +++ b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java @@ -10,9 +10,10 @@ import android.app.Activity; import android.content.Intent; -import android.os.AsyncTask; import android.os.Bundle; -import android.os.Debug; +import android.os.Handler; +import android.os.HandlerThread; +import android.os.Looper; import android.system.ErrnoException; import android.system.Os; import com.google.gson.Gson; @@ -21,12 +22,22 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; -import java.util.Collections; import java.util.List; -import java.util.stream.Collectors; -import org.pytorch.executorch.Module; public class BenchmarkActivity extends Activity { + + File mModel; + int mNumIter; + int mNumWarmupIter; + String mTokenizerPath; + float mTemperature; + String mPrompt; + + HandlerThread mHandlerThread; + BenchmarkHandler mHandler; + + List mResult; + @Override protected void onCreate(Bundle savedInstanceState) { super.onCreate(savedInstanceState); @@ -47,95 +58,79 @@ protected void onCreate(Bundle savedInstanceState) { int numIter = intent.getIntExtra("num_iter", 50); int numWarmupIter = intent.getIntExtra("num_warm_up_iter", 10); + String tokenizerPath = intent.getStringExtra("tokenizer_path"); + float temperature = intent.getFloatExtra("temperature", 0.8f); + String prompt = intent.getStringExtra("prompt"); + + mModel = model; + mNumIter = numIter; + mNumWarmupIter = numWarmupIter; + mTokenizerPath = tokenizerPath; + mTemperature = temperature; + mPrompt = prompt; + if (mPrompt == null) { + mPrompt = "The ultimate answer"; + } + mResult = new ArrayList<>(); - long pssIdle = Debug.getPss(); + mHandlerThread = new HandlerThread("ModelRunner"); + mHandlerThread.start(); + mHandler = new BenchmarkHandler(mHandlerThread.getLooper(), this); - // TODO: Format the string with a parsable format - Stats stats = new Stats(); + mHandler.sendEmptyMessage(BenchmarkHandler.MESSAGE_RUN_BENCHMARK); + } - new AsyncTask() { - @Override - protected Void doInBackground(Void... voids) { + void writeResult() { + try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) { + Gson gson = new Gson(); + writer.write(gson.toJson(mResult)); + } catch (IOException e) { + e.printStackTrace(); + } finally { + finish(); + } + } +} - // Record the time it takes to load the model and the forward method - stats.loadStart = System.nanoTime(); - Module module = Module.load(model.getPath()); - stats.errorCode = module.loadMethod("forward"); - stats.loadEnd = System.nanoTime(); +class BenchmarkHandler extends Handler { + public static int MESSAGE_RUN_BENCHMARK = 1; + public static int MESSAGE_LLM_RUN_BENCHMARK = 2; - for (int i = 0; i < numWarmupIter; i++) { - module.forward(); - } + ModelRunner mModelRunner; + BenchmarkActivity mBenchmarkActivity; - for (int i = 0; i < numIter; i++) { - long start = System.nanoTime(); - module.forward(); - double forwardMs = (System.nanoTime() - start) * 1e-6; - stats.latency.add(forwardMs); - } - return null; - } + LlmModelRunner mLlmModelRunner; + LlmBenchmark mLlmBenchmark; - @Override - protected void onPostExecute(Void aVoid) { - - final BenchmarkMetric.BenchmarkModel benchmarkModel = - BenchmarkMetric.extractBackendAndQuantization(model.getName().replace(".pte", "")); - final List results = new ArrayList<>(); - // The list of metrics we have atm includes: - // Avg inference latency after N iterations - // Currently the result has large variance from outliers, so only use - // 80% samples in the middle (trimmean 0.2) - Collections.sort(stats.latency); - int resultSize = stats.latency.size(); - List usedLatencyResults = - stats.latency.subList(resultSize / 10, resultSize * 9 / 10); - - results.add( - new BenchmarkMetric( - benchmarkModel, - "avg_inference_latency(ms)", - stats.latency.stream().mapToDouble(l -> l).average().orElse(0.0f), - 0.0f)); - results.add( - new BenchmarkMetric( - benchmarkModel, - "trimmean_inference_latency(ms)", - usedLatencyResults.stream().mapToDouble(l -> l).average().orElse(0.0f), - 0.0f)); - // Model load time - results.add( - new BenchmarkMetric( - benchmarkModel, - "model_load_time(ms)", - (stats.loadEnd - stats.loadStart) * 1e-6, - 0.0f)); - // Load status - results.add(new BenchmarkMetric(benchmarkModel, "load_status", stats.errorCode, 0)); - // RAM PSS usage - results.add( - new BenchmarkMetric( - benchmarkModel, "ram_pss_usage(mb)", (Debug.getPss() - pssIdle) / 1024, 0)); - - try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) { - Gson gson = new Gson(); - writer.write(gson.toJson(results)); - } catch (IOException e) { - e.printStackTrace(); - } - } - }.execute(); + public BenchmarkHandler(Looper looper, BenchmarkActivity benchmarkActivity) { + super(looper); + mModelRunner = new ModelRunner(); + mBenchmarkActivity = benchmarkActivity; } -} - -class Stats { - long loadStart; - long loadEnd; - List latency = new ArrayList<>(); - int errorCode = 0; @Override - public String toString() { - return "latency: " + latency.stream().map(Object::toString).collect(Collectors.joining("")); + public void handleMessage(android.os.Message msg) { + if (msg.what == MESSAGE_RUN_BENCHMARK) { + mModelRunner.runBenchmark( + mBenchmarkActivity.mModel, + mBenchmarkActivity.mNumWarmupIter, + mBenchmarkActivity.mNumIter, + mBenchmarkActivity.mResult); + + if (mBenchmarkActivity.mTokenizerPath == null) { + mBenchmarkActivity.writeResult(); + } else { + this.sendEmptyMessage(MESSAGE_LLM_RUN_BENCHMARK); + } + } else if (msg.what == MESSAGE_LLM_RUN_BENCHMARK) { + mLlmBenchmark = + new LlmBenchmark( + mBenchmarkActivity, + mBenchmarkActivity.mModel.getPath(), + mBenchmarkActivity.mTokenizerPath, + mBenchmarkActivity.mPrompt, + mBenchmarkActivity.mTemperature, + mBenchmarkActivity.mResult); + } } } diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmarkActivity.java b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmark.java similarity index 57% rename from extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmarkActivity.java rename to extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmark.java index f6a894d6a1..0c0436d267 100644 --- a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmarkActivity.java +++ b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmark.java @@ -8,57 +8,33 @@ package org.pytorch.minibench; -import android.app.Activity; -import android.content.Intent; -import android.os.Bundle; -import android.system.ErrnoException; -import android.system.Os; import android.util.Log; -import com.google.gson.Gson; -import java.io.File; -import java.io.FileWriter; -import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; import java.util.List; import org.json.JSONException; import org.json.JSONObject; -public class LlmBenchmarkActivity extends Activity implements ModelRunnerCallback { - ModelRunner mModelRunner; +public class LlmBenchmark implements LlmModelRunnerCallback { + LlmModelRunner mLlmModelRunner; String mPrompt; StatsInfo mStatsInfo; - @Override - protected void onCreate(Bundle savedInstanceState) { - super.onCreate(savedInstanceState); - - Intent intent = getIntent(); - - File modelDir = new File(intent.getStringExtra("model_dir")); - File model = - Arrays.stream(modelDir.listFiles()) - .filter(file -> file.getName().endsWith(".pte")) - .findFirst() - .get(); - String tokenizerPath = intent.getStringExtra("tokenizer_path"); - - float temperature = intent.getFloatExtra("temperature", 0.8f); - mPrompt = intent.getStringExtra("prompt"); - if (mPrompt == null) { - mPrompt = "The ultimate answer"; - } - - try { - Os.setenv("ADSP_LIBRARY_PATH", getApplicationInfo().nativeLibraryDir, true); - } catch (ErrnoException e) { - finish(); - } - + List mResults; + BenchmarkActivity mActivity; + + LlmBenchmark( + BenchmarkActivity activity, + String modelFile, + String tokenizerPath, + String prompt, + float temperature, + List results) { + mResults = results; + mActivity = activity; mStatsInfo = new StatsInfo(); - mStatsInfo.modelName = model.getName().replace(".pte", ""); - mModelRunner = new ModelRunner(model.getPath(), tokenizerPath, temperature, this); + mStatsInfo.modelName = modelFile.substring(modelFile.lastIndexOf('/') + 1).replace(".pte", ""); + mPrompt = prompt; + mLlmModelRunner = new LlmModelRunner(modelFile, tokenizerPath, temperature, this); mStatsInfo.loadStart = System.nanoTime(); } @@ -72,7 +48,7 @@ public void onModelLoaded(int status) { return; } mStatsInfo.generateStart = System.nanoTime(); - mModelRunner.generate(mPrompt); + mLlmModelRunner.generate(mPrompt); } @Override @@ -99,33 +75,26 @@ public void onGenerationStopped() { final BenchmarkMetric.BenchmarkModel benchmarkModel = BenchmarkMetric.extractBackendAndQuantization(mStatsInfo.modelName); - final List results = new ArrayList<>(); // The list of metrics we have atm includes: // Load status - results.add(new BenchmarkMetric(benchmarkModel, "load_status", mStatsInfo.loadStatus, 0)); + mResults.add(new BenchmarkMetric(benchmarkModel, "load_status", mStatsInfo.loadStatus, 0)); // Model load time - results.add( + mResults.add( new BenchmarkMetric( benchmarkModel, - "model_load_time(ms)", + "llm_model_load_time(ms)", (mStatsInfo.loadEnd - mStatsInfo.loadStart) * 1e-6, 0.0f)); // LLM generate time - results.add( + mResults.add( new BenchmarkMetric( benchmarkModel, "generate_time(ms)", (mStatsInfo.generateEnd - mStatsInfo.generateStart) * 1e-6, 0.0f)); // Token per second - results.add(new BenchmarkMetric(benchmarkModel, "token_per_sec", mStatsInfo.tps, 0.0f)); - - try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) { - Gson gson = new Gson(); - writer.write(gson.toJson(results)); - } catch (IOException e) { - e.printStackTrace(); - } + mResults.add(new BenchmarkMetric(benchmarkModel, "token_per_sec", mStatsInfo.tps, 0.0f)); + mActivity.writeResult(); } } diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmModelRunner.java b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmModelRunner.java new file mode 100644 index 0000000000..a1b434a37b --- /dev/null +++ b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmModelRunner.java @@ -0,0 +1,97 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.minibench; + +import android.os.Handler; +import android.os.HandlerThread; +import android.os.Looper; +import android.os.Message; +import org.pytorch.executorch.extension.llm.LlmCallback; +import org.pytorch.executorch.extension.llm.LlmModule; + +/** A helper class to handle all model running logic within this class. */ +public class LlmModelRunner implements LlmCallback { + LlmModule mModule = null; + + String mModelFilePath = ""; + String mTokenizerFilePath = ""; + + LlmModelRunnerCallback mCallback = null; + + HandlerThread mHandlerThread = null; + Handler mHandler = null; + + /** + * ] Helper class to separate between UI logic and model runner logic. Automatically handle + * generate() request on worker thread. + * + * @param modelFilePath + * @param tokenizerFilePath + * @param callback + */ + LlmModelRunner( + String modelFilePath, + String tokenizerFilePath, + float temperature, + LlmModelRunnerCallback callback) { + mModelFilePath = modelFilePath; + mTokenizerFilePath = tokenizerFilePath; + mCallback = callback; + + mModule = new LlmModule(mModelFilePath, mTokenizerFilePath, 0.8f); + mHandlerThread = new HandlerThread("LlmModelRunner"); + mHandlerThread.start(); + mHandler = new LlmModelRunnerHandler(mHandlerThread.getLooper(), this); + + mHandler.sendEmptyMessage(LlmModelRunnerHandler.MESSAGE_LOAD_MODEL); + } + + int generate(String prompt) { + Message msg = Message.obtain(mHandler, LlmModelRunnerHandler.MESSAGE_GENERATE, prompt); + msg.sendToTarget(); + return 0; + } + + void stop() { + mModule.stop(); + } + + @Override + public void onResult(String result) { + mCallback.onTokenGenerated(result); + } + + @Override + public void onStats(String result) { + mCallback.onStats(result); + } +} + +class LlmModelRunnerHandler extends Handler { + public static int MESSAGE_LOAD_MODEL = 1; + public static int MESSAGE_GENERATE = 2; + + private final LlmModelRunner mLlmModelRunner; + + public LlmModelRunnerHandler(Looper looper, LlmModelRunner llmModelRunner) { + super(looper); + mLlmModelRunner = llmModelRunner; + } + + @Override + public void handleMessage(android.os.Message msg) { + if (msg.what == MESSAGE_LOAD_MODEL) { + int status = mLlmModelRunner.mModule.load(); + mLlmModelRunner.mCallback.onModelLoaded(status); + } else if (msg.what == MESSAGE_GENERATE) { + mLlmModelRunner.mModule.generate((String) msg.obj, mLlmModelRunner); + mLlmModelRunner.mCallback.onGenerationStopped(); + } + } +} diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunnerCallback.java b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmModelRunnerCallback.kt similarity index 62% rename from extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunnerCallback.java rename to extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmModelRunnerCallback.kt index 8503d47ccc..cd2fecdf81 100644 --- a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunnerCallback.java +++ b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmModelRunnerCallback.kt @@ -6,19 +6,21 @@ * LICENSE file in the root directory of this source tree. */ -package org.pytorch.minibench; + +package org.pytorch.minibench + /** * A helper interface within the app for MainActivity and Benchmarking to handle callback from * ModelRunner. */ -public interface ModelRunnerCallback { +interface LlmModelRunnerCallback { - void onModelLoaded(int status); + fun onModelLoaded(status: Int) - void onTokenGenerated(String token); + fun onTokenGenerated(token: String) - void onStats(String result); + fun onStats(result: String) - void onGenerationStopped(); + fun onGenerationStopped() } diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java index 0a75b47f3a..3913a8d76f 100644 --- a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java +++ b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java @@ -8,90 +8,70 @@ package org.pytorch.minibench; -import android.os.Handler; -import android.os.HandlerThread; -import android.os.Looper; -import android.os.Message; -import org.pytorch.executorch.extension.llm.LlmCallback; -import org.pytorch.executorch.extension.llm.LlmModule; - -/** A helper class to handle all model running logic within this class. */ -public class ModelRunner implements LlmCallback { - LlmModule mModule = null; - - String mModelFilePath = ""; - String mTokenizerFilePath = ""; - - ModelRunnerCallback mCallback = null; - - HandlerThread mHandlerThread = null; - Handler mHandler = null; - +import android.os.Debug; +import java.io.File; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import org.pytorch.executorch.Module; + +public class ModelRunner { /** - * ] Helper class to separate between UI logic and model runner logic. Automatically handle - * generate() request on worker thread. - * - * @param modelFilePath - * @param tokenizerFilePath - * @param callback + * @return list of #BenchmarkMetric */ - ModelRunner( - String modelFilePath, - String tokenizerFilePath, - float temperature, - ModelRunnerCallback callback) { - mModelFilePath = modelFilePath; - mTokenizerFilePath = tokenizerFilePath; - mCallback = callback; - - mModule = new LlmModule(mModelFilePath, mTokenizerFilePath, 0.8f); - mHandlerThread = new HandlerThread("ModelRunner"); - mHandlerThread.start(); - mHandler = new ModelRunnerHandler(mHandlerThread.getLooper(), this); - - mHandler.sendEmptyMessage(ModelRunnerHandler.MESSAGE_LOAD_MODEL); - } - - int generate(String prompt) { - Message msg = Message.obtain(mHandler, ModelRunnerHandler.MESSAGE_GENERATE, prompt); - msg.sendToTarget(); - return 0; - } + public void runBenchmark( + File model, int numWarmupIter, int numIter, List results) { + long pssIdle = Debug.getPss(); - void stop() { - mModule.stop(); - } - - @Override - public void onResult(String result) { - mCallback.onTokenGenerated(result); - } - - @Override - public void onStats(String result) { - mCallback.onStats(result); - } -} - -class ModelRunnerHandler extends Handler { - public static int MESSAGE_LOAD_MODEL = 1; - public static int MESSAGE_GENERATE = 2; + List latency = new ArrayList<>(); - private final ModelRunner mModelRunner; + long loadStart = System.nanoTime(); + Module module = Module.load(model.getPath()); + int errorCode = module.loadMethod("forward"); + long loadEnd = System.nanoTime(); - public ModelRunnerHandler(Looper looper, ModelRunner modelRunner) { - super(looper); - mModelRunner = modelRunner; - } + for (int i = 0; i < numWarmupIter; i++) { + module.forward(); + } - @Override - public void handleMessage(android.os.Message msg) { - if (msg.what == MESSAGE_LOAD_MODEL) { - int status = mModelRunner.mModule.load(); - mModelRunner.mCallback.onModelLoaded(status); - } else if (msg.what == MESSAGE_GENERATE) { - mModelRunner.mModule.generate((String) msg.obj, mModelRunner); - mModelRunner.mCallback.onGenerationStopped(); + for (int i = 0; i < numIter; i++) { + long start = System.nanoTime(); + module.forward(); + double forwardMs = (System.nanoTime() - start) * 1e-6; + latency.add(forwardMs); } + + final BenchmarkMetric.BenchmarkModel benchmarkModel = + BenchmarkMetric.extractBackendAndQuantization(model.getName().replace(".pte", "")); + // The list of metrics we have atm includes: + // Avg inference latency after N iterations + // Currently the result has large variance from outliers, so only use + // 80% samples in the middle (trimmean 0.2) + Collections.sort(latency); + int resultSize = latency.size(); + List usedLatencyResults = latency.subList(resultSize / 10, resultSize * 9 / 10); + + results.add( + new BenchmarkMetric( + benchmarkModel, + "avg_inference_latency(ms)", + latency.stream().mapToDouble(l -> l).average().orElse(0.0f), + 0.0f)); + results.add( + new BenchmarkMetric( + benchmarkModel, + "trimmean_inference_latency(ms)", + usedLatencyResults.stream().mapToDouble(l -> l).average().orElse(0.0f), + 0.0f)); + // Model load time + results.add( + new BenchmarkMetric( + benchmarkModel, "model_load_time(ms)", (loadEnd - loadStart) * 1e-6, 0.0f)); + // Load status + results.add(new BenchmarkMetric(benchmarkModel, "load_status", errorCode, 0)); + // RAM PSS usage + results.add( + new BenchmarkMetric( + benchmarkModel, "ram_pss_usage(mb)", (Debug.getPss() - pssIdle) / 1024, 0)); } } diff --git a/extension/benchmark/android/benchmark/build.gradle.kts b/extension/benchmark/android/benchmark/build.gradle.kts index ac625be8e0..b1ed5127df 100644 --- a/extension/benchmark/android/benchmark/build.gradle.kts +++ b/extension/benchmark/android/benchmark/build.gradle.kts @@ -7,4 +7,6 @@ */ // Top-level build file where you can add configuration options common to all sub-projects/modules. -plugins { id("com.android.application") version "8.1.0" apply false } +plugins { id("com.android.application") version "8.1.0" apply false + id("org.jetbrains.kotlin.android") version "2.1.10" apply false +}