From 428a472b5e1dde8f1066f28a677069373ad7a708 Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Tue, 22 Apr 2025 17:06:09 -0700 Subject: [PATCH 1/3] Minibench refactor Allow running generic model benchmark before LLM --- .../benchmark/android/benchmark/README.md | 4 +- .../android-llm-device-farm-test-spec.yml.j2 | 4 +- .../android/benchmark/app/build.gradle.kts | 12 +- .../app/src/main/AndroidManifest.xml | 8 - .../pytorch/minibench/BenchmarkActivity.java | 188 ++++++++---------- .../pytorch/minibench/BenchmarkMetric.java | 95 ++++----- .../org/pytorch/minibench/LlmBenchmark.java | 120 +++++++++++ .../minibench/LlmBenchmarkActivity.java | 154 -------------- .../org/pytorch/minibench/LlmModelRunner.java | 100 ++++++++++ ...allback.java => LlmModelRunnerCallback.kt} | 14 +- .../org/pytorch/minibench/ModelRunner.java | 155 +++++++-------- .../android/benchmark/build.gradle.kts | 4 +- 12 files changed, 449 insertions(+), 409 deletions(-) create mode 100644 extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmark.java delete mode 100644 extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmarkActivity.java create mode 100644 extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmModelRunner.java rename extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/{ModelRunnerCallback.java => LlmModelRunnerCallback.kt} (62%) diff --git a/extension/benchmark/android/benchmark/README.md b/extension/benchmark/android/benchmark/README.md index f6731023f4..a5cdd22774 100644 --- a/extension/benchmark/android/benchmark/README.md +++ b/extension/benchmark/android/benchmark/README.md @@ -43,13 +43,13 @@ adb push tokenizer.bin /data/local/tmp/minibench ### Generic model ``` -adb shell am start -W -S -n org.pytorch.minibench/org.pytorch.minibench.LlmBenchmarkActivity \ +adb shell am start -W -S -n org.pytorch.minibench/org.pytorch.minibench.BenchmarkActivity \ --es model_dir /data/local/tmp/minibench ``` ### LLM ``` -adb shell am start -W -S -n org.pytorch.minibench/org.pytorch.minibench.LlmBenchmarkActivity \ +adb shell am start -W -S -n org.pytorch.minibench/org.pytorch.minibench.BenchmarkActivity \ --es model_dir /data/local/tmp/minibench --es tokenizer_path /data/local/tmp/minibench/tokenizer.bin ``` diff --git a/extension/benchmark/android/benchmark/android-llm-device-farm-test-spec.yml.j2 b/extension/benchmark/android/benchmark/android-llm-device-farm-test-spec.yml.j2 index 4f8e72d21b..7d668e90c8 100644 --- a/extension/benchmark/android/benchmark/android-llm-device-farm-test-spec.yml.j2 +++ b/extension/benchmark/android/benchmark/android-llm-device-farm-test-spec.yml.j2 @@ -114,11 +114,11 @@ phases: adb -s $DEVICEFARM_DEVICE_UDID shell sleep 180 if [ -n "$BIN_FOUND" ]; then - adb -s $DEVICEFARM_DEVICE_UDID shell am start -W -n org.pytorch.minibench/.LlmBenchmarkActivity \ + adb -s $DEVICEFARM_DEVICE_UDID shell am start -W -n org.pytorch.minibench/.BenchmarkActivity \ --es "model_dir" "/data/local/tmp/minibench" \ --es "tokenizer_path" "/data/local/tmp/minibench/tokenizer.bin" elif [ -n "$MODEL_FOUND" ]; then - adb -s $DEVICEFARM_DEVICE_UDID shell am start -W -n org.pytorch.minibench/.LlmBenchmarkActivity \ + adb -s $DEVICEFARM_DEVICE_UDID shell am start -W -n org.pytorch.minibench/.BenchmarkActivity \ --es "model_dir" "/data/local/tmp/minibench" \ --es "tokenizer_path" "/data/local/tmp/minibench/tokenizer.model" else diff --git a/extension/benchmark/android/benchmark/app/build.gradle.kts 
b/extension/benchmark/android/benchmark/app/build.gradle.kts index 28dfc8ae49..4ee7efd1f9 100644 --- a/extension/benchmark/android/benchmark/app/build.gradle.kts +++ b/extension/benchmark/android/benchmark/app/build.gradle.kts @@ -6,7 +6,9 @@ * LICENSE file in the root directory of this source tree. */ -plugins { id("com.android.application") } +plugins { id("com.android.application") + id("org.jetbrains.kotlin.android") +} android { namespace = "org.pytorch.minibench" @@ -29,8 +31,11 @@ android { } } compileOptions { - sourceCompatibility = JavaVersion.VERSION_1_8 - targetCompatibility = JavaVersion.VERSION_1_8 + sourceCompatibility = JavaVersion.VERSION_17 + targetCompatibility = JavaVersion.VERSION_17 + } + kotlinOptions { + jvmTarget = "17" } } @@ -40,6 +45,7 @@ dependencies { implementation("com.facebook.fbjni:fbjni:0.5.1") implementation("com.google.code.gson:gson:2.8.6") implementation("org.json:json:20250107") + implementation("androidx.core:core-ktx:1.13.1") testImplementation("junit:junit:4.13.2") androidTestImplementation("androidx.test.ext:junit:1.2.1") androidTestImplementation("androidx.test.espresso:espresso-core:3.6.1") diff --git a/extension/benchmark/android/benchmark/app/src/main/AndroidManifest.xml b/extension/benchmark/android/benchmark/app/src/main/AndroidManifest.xml index 7f62c509d5..723829de98 100644 --- a/extension/benchmark/android/benchmark/app/src/main/AndroidManifest.xml +++ b/extension/benchmark/android/benchmark/app/src/main/AndroidManifest.xml @@ -21,14 +21,6 @@ - - - - - - diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java index 78830d5a54..54eafaa555 100644 --- a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java +++ b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java @@ -10,132 +10,118 @@ import android.app.Activity; import android.content.Intent; -import android.os.AsyncTask; import android.os.Bundle; -import android.os.Debug; +import android.os.Handler; +import android.os.HandlerThread; +import android.os.Looper; import android.system.ErrnoException; import android.system.Os; + import com.google.gson.Gson; + import java.io.File; import java.io.FileWriter; import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; -import java.util.Collections; import java.util.List; -import java.util.stream.Collectors; -import org.pytorch.executorch.Module; public class BenchmarkActivity extends Activity { - @Override - protected void onCreate(Bundle savedInstanceState) { - super.onCreate(savedInstanceState); - - try { - Os.setenv("ADSP_LIBRARY_PATH", getApplicationInfo().nativeLibraryDir, true); - } catch (ErrnoException e) { - finish(); - } - - Intent intent = getIntent(); - File modelDir = new File(intent.getStringExtra("model_dir")); - File model = - Arrays.stream(modelDir.listFiles()) - .filter(file -> file.getName().endsWith(".pte")) - .findFirst() - .get(); - int numIter = intent.getIntExtra("num_iter", 50); - int numWarmupIter = intent.getIntExtra("num_warm_up_iter", 10); + File mModel; + int mNumIter; + int mNumWarmupIter; + String mTokenizerPath; + float mTemperature; + String mPrompt; - long pssIdle = Debug.getPss(); + HandlerThread mHandlerThread; + BenchmarkHandler mHandler; - // TODO: Format the string with a parsable format - Stats stats = new Stats(); + List mResult; - 
new AsyncTask() { - @Override - protected Void doInBackground(Void... voids) { + @Override + protected void onCreate(Bundle savedInstanceState) { + super.onCreate(savedInstanceState); - // Record the time it takes to load the model and the forward method - stats.loadStart = System.nanoTime(); - Module module = Module.load(model.getPath()); - stats.errorCode = module.loadMethod("forward"); - stats.loadEnd = System.nanoTime(); - - for (int i = 0; i < numWarmupIter; i++) { - module.forward(); + try { + Os.setenv("ADSP_LIBRARY_PATH", getApplicationInfo().nativeLibraryDir, true); + } catch (ErrnoException e) { + finish(); } - for (int i = 0; i < numIter; i++) { - long start = System.nanoTime(); - module.forward(); - double forwardMs = (System.nanoTime() - start) * 1e-6; - stats.latency.add(forwardMs); + Intent intent = getIntent(); + File modelDir = new File(intent.getStringExtra("model_dir")); + File model = + Arrays.stream(modelDir.listFiles()) + .filter(file -> file.getName().endsWith(".pte")) + .findFirst() + .get(); + + int numIter = intent.getIntExtra("num_iter", 50); + int numWarmupIter = intent.getIntExtra("num_warm_up_iter", 10); + String tokenizerPath = intent.getStringExtra("tokenizer_path"); + float temperature = intent.getFloatExtra("temperature", 0.8f); + String prompt = intent.getStringExtra("prompt"); + + mModel = model; + mNumIter = numIter; + mNumWarmupIter = numWarmupIter; + mTokenizerPath = tokenizerPath; + mTemperature = temperature; + mPrompt = prompt; + if (mPrompt == null) { + mPrompt = "The ultimate answer"; } - return null; - } - - @Override - protected void onPostExecute(Void aVoid) { - - final BenchmarkMetric.BenchmarkModel benchmarkModel = - BenchmarkMetric.extractBackendAndQuantization(model.getName().replace(".pte", "")); - final List results = new ArrayList<>(); - // The list of metrics we have atm includes: - // Avg inference latency after N iterations - // Currently the result has large variance from outliers, so only use - // 80% samples in the middle (trimmean 0.2) - Collections.sort(stats.latency); - int resultSize = stats.latency.size(); - List usedLatencyResults = - stats.latency.subList(resultSize / 10, resultSize * 9 / 10); - - results.add( - new BenchmarkMetric( - benchmarkModel, - "avg_inference_latency(ms)", - stats.latency.stream().mapToDouble(l -> l).average().orElse(0.0f), - 0.0f)); - results.add( - new BenchmarkMetric( - benchmarkModel, - "trimmean_inference_latency(ms)", - usedLatencyResults.stream().mapToDouble(l -> l).average().orElse(0.0f), - 0.0f)); - // Model load time - results.add( - new BenchmarkMetric( - benchmarkModel, - "model_load_time(ms)", - (stats.loadEnd - stats.loadStart) * 1e-6, - 0.0f)); - // Load status - results.add(new BenchmarkMetric(benchmarkModel, "load_status", stats.errorCode, 0)); - // RAM PSS usage - results.add( - new BenchmarkMetric( - benchmarkModel, "ram_pss_usage(mb)", (Debug.getPss() - pssIdle) / 1024, 0)); + mResult = new ArrayList<>(); + + mHandlerThread = new HandlerThread("ModelRunner"); + mHandlerThread.start(); + mHandler = new BenchmarkHandler(mHandlerThread.getLooper(), this); + + mHandler.sendEmptyMessage(BenchmarkHandler.MESSAGE_RUN_BENCHMARK); + } + void writeResult() { try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) { - Gson gson = new Gson(); - writer.write(gson.toJson(results)); + Gson gson = new Gson(); + writer.write(gson.toJson(mResult)); } catch (IOException e) { - e.printStackTrace(); + e.printStackTrace(); + } finally { + finish(); } - } - }.execute(); - } 
+ } } -class Stats { - long loadStart; - long loadEnd; - List latency = new ArrayList<>(); - int errorCode = 0; +class BenchmarkHandler extends Handler { + public static int MESSAGE_RUN_BENCHMARK = 1; + public static int MESSAGE_LLM_RUN_BENCHMARK = 2; + + ModelRunner mModelRunner; + BenchmarkActivity mBenchmarkActivity; - @Override - public String toString() { - return "latency: " + latency.stream().map(Object::toString).collect(Collectors.joining("")); - } + LlmModelRunner mLlmModelRunner; + LlmBenchmark mLlmBenchmark; + + public BenchmarkHandler(Looper looper, BenchmarkActivity benchmarkActivity) { + super(looper); + mModelRunner = new ModelRunner(); + mBenchmarkActivity = benchmarkActivity; + } + + @Override + public void handleMessage(android.os.Message msg) { + if (msg.what == MESSAGE_RUN_BENCHMARK) { + mModelRunner.runBenchmark(mBenchmarkActivity.mModel, mBenchmarkActivity.mNumWarmupIter, mBenchmarkActivity.mNumIter, mBenchmarkActivity.mResult); + + if (mBenchmarkActivity.mTokenizerPath == null) { + mBenchmarkActivity.writeResult(); + } else { + this.sendEmptyMessage(MESSAGE_LLM_RUN_BENCHMARK); + } + } else if (msg.what == MESSAGE_LLM_RUN_BENCHMARK) { + mLlmBenchmark = new LlmBenchmark(mBenchmarkActivity, mBenchmarkActivity.mModel.getPath(), mBenchmarkActivity.mTokenizerPath, mBenchmarkActivity.mPrompt, mBenchmarkActivity.mTemperature, mBenchmarkActivity.mResult); + } + } } diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkMetric.java b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkMetric.java index 66ab50550a..0c09614780 100644 --- a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkMetric.java +++ b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkMetric.java @@ -10,65 +10,66 @@ import android.app.ActivityManager; import android.os.Build; + import java.util.regex.Matcher; import java.util.regex.Pattern; class BenchmarkMetric { - public static class BenchmarkModel { - // The model name, i.e. stories110M - String name; - String backend; - String quantization; + public static class BenchmarkModel { + // The model name, i.e. stories110M + String name; + String backend; + String quantization; - public BenchmarkModel(final String name, final String backend, final String quantization) { - this.name = name; - this.backend = backend; - this.quantization = quantization; + public BenchmarkModel(final String name, final String backend, final String quantization) { + this.name = name; + this.backend = backend; + this.quantization = quantization; + } } - } - BenchmarkModel benchmarkModel; + BenchmarkModel benchmarkModel; - // The metric name, i.e. TPS - String metric; + // The metric name, i.e. 
TPS
+    String metric;

-  // The actual value and the option target value
-  double actualValue;
-  double targetValue;
+    // The actual value and the option target value
+    double actualValue;
+    double targetValue;

-  public static class DeviceInfo {
-    // Let's see which information we want to include here
-    final String device = Build.BRAND;
-    // The phone model and Android release version
-    final String arch = Build.MODEL;
-    final String os = "Android " + Build.VERSION.RELEASE;
-    final long totalMem = new ActivityManager.MemoryInfo().totalMem;
-    final long availMem = new ActivityManager.MemoryInfo().availMem;
-  }
+    public static class DeviceInfo {
+        // Let's see which information we want to include here
+        final String device = Build.BRAND;
+        // The phone model and Android release version
+        final String arch = Build.MODEL;
+        final String os = "Android " + Build.VERSION.RELEASE;
+        final long totalMem = new ActivityManager.MemoryInfo().totalMem;
+        final long availMem = new ActivityManager.MemoryInfo().availMem;
+    }

-  DeviceInfo deviceInfo = new DeviceInfo();
+    DeviceInfo deviceInfo = new DeviceInfo();

-  public BenchmarkMetric(
-      final BenchmarkModel benchmarkModel,
-      final String metric,
-      final double actualValue,
-      final double targetValue) {
-    this.benchmarkModel = benchmarkModel;
-    this.metric = metric;
-    this.actualValue = actualValue;
-    this.targetValue = targetValue;
-  }
+    public BenchmarkMetric(
+            final BenchmarkModel benchmarkModel,
+            final String metric,
+            final double actualValue,
+            final double targetValue) {
+        this.benchmarkModel = benchmarkModel;
+        this.metric = metric;
+        this.actualValue = actualValue;
+        this.targetValue = targetValue;
+    }

-  // TODO (huydhn): Figure out a way to extract the backend and quantization information from
-  // the .pte model itself instead of parsing its name
-  public static BenchmarkMetric.BenchmarkModel extractBackendAndQuantization(final String model) {
-    final Matcher m =
-        Pattern.compile("(?<name>\\w+)_(?<backend>[\\w\\+]+)_(?<quantization>\\w+)").matcher(model);
-    if (m.matches()) {
-      return new BenchmarkMetric.BenchmarkModel(
-          m.group("name"), m.group("backend"), m.group("quantization"));
-    } else {
-      return new BenchmarkMetric.BenchmarkModel(model, "", "");
-    }
-  }
+    // TODO (huydhn): Figure out a way to extract the backend and quantization information from
+    // the .pte model itself instead of parsing its name
+    public static BenchmarkMetric.BenchmarkModel extractBackendAndQuantization(final String model) {
+        final Matcher m =
+                Pattern.compile("(?<name>\\w+)_(?<backend>[\\w\\+]+)_(?<quantization>\\w+)")
+                        .matcher(model);
+        if (m.matches()) {
+            return new BenchmarkMetric.BenchmarkModel(
+                    m.group("name"), m.group("backend"), m.group("quantization"));
+        } else {
+            return new BenchmarkMetric.BenchmarkModel(model, "", "");
+        }
+    }
 }
diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmark.java b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmark.java
new file mode 100644
index 0000000000..c29002ea9c
--- /dev/null
+++ b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmark.java
@@ -0,0 +1,120 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */ + +package org.pytorch.minibench; + +import android.util.Log; + +import org.json.JSONException; +import org.json.JSONObject; + +import java.util.List; + +public class LlmBenchmark implements LlmModelRunnerCallback { + LlmModelRunner mLlmModelRunner; + + String mPrompt; + StatsInfo mStatsInfo; + + List mResults; + BenchmarkActivity mActivity; + + LlmBenchmark(BenchmarkActivity activity, String modelFile, String tokenizerPath, String prompt, float temperature, List results) { + mResults = results; + mActivity = activity; + mStatsInfo = new StatsInfo(); + mStatsInfo.modelName = modelFile.replace(".pte", ""); + mPrompt = prompt; + mLlmModelRunner = new LlmModelRunner(modelFile, tokenizerPath, temperature, this); + mStatsInfo.loadStart = System.nanoTime(); + } + + @Override + public void onModelLoaded(int status) { + mStatsInfo.loadEnd = System.nanoTime(); + mStatsInfo.loadStatus = status; + if (status != 0) { + Log.e("LlmBenchmarkRunner", "Loaded failed: " + status); + onGenerationStopped(); + return; + } + mStatsInfo.generateStart = System.nanoTime(); + mLlmModelRunner.generate(mPrompt); + } + + @Override + public void onTokenGenerated(String token) { + } + + @Override + public void onStats(String stats) { + float tps = 0; + try { + JSONObject jsonObject = new JSONObject(stats); + int numGeneratedTokens = jsonObject.getInt("generated_tokens"); + int inferenceEndMs = jsonObject.getInt("inference_end_ms"); + int promptEvalEndMs = jsonObject.getInt("prompt_eval_end_ms"); + tps = (float) numGeneratedTokens / (inferenceEndMs - promptEvalEndMs) * 1000; + mStatsInfo.tps = tps; + } catch (JSONException e) { + Log.e("LLM", "Error parsing JSON: " + e.getMessage()); + } + } + + @Override + public void onGenerationStopped() { + mStatsInfo.generateEnd = System.nanoTime(); + + final BenchmarkMetric.BenchmarkModel benchmarkModel = + BenchmarkMetric.extractBackendAndQuantization(mStatsInfo.modelName); + // The list of metrics we have atm includes: + // Load status + mResults.add(new BenchmarkMetric(benchmarkModel, "load_status", mStatsInfo.loadStatus, 0)); + // Model load time + mResults.add( + new BenchmarkMetric( + benchmarkModel, + "llm_model_load_time(ms)", + (mStatsInfo.loadEnd - mStatsInfo.loadStart) * 1e-6, + 0.0f)); + // LLM generate time + mResults.add( + new BenchmarkMetric( + benchmarkModel, + "generate_time(ms)", + (mStatsInfo.generateEnd - mStatsInfo.generateStart) * 1e-6, + 0.0f)); + // Token per second + mResults.add(new BenchmarkMetric(benchmarkModel, "token_per_sec", mStatsInfo.tps, 0.0f)); + mActivity.writeResult(); + } +} + +class StatsInfo { + int loadStatus; + long loadStart; + long loadEnd; + long generateStart; + long generateEnd; + float tps; + String modelName; + + @Override + public String toString() { + return "loadStart: " + + loadStart + + "\nloadEnd: " + + loadEnd + + "\ngenerateStart: " + + generateStart + + "\ngenerateEnd: " + + generateEnd + + "\n" + + tps; + } +} diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmarkActivity.java b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmarkActivity.java deleted file mode 100644 index f6a894d6a1..0000000000 --- a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmarkActivity.java +++ /dev/null @@ -1,154 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. 
- * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -package org.pytorch.minibench; - -import android.app.Activity; -import android.content.Intent; -import android.os.Bundle; -import android.system.ErrnoException; -import android.system.Os; -import android.util.Log; -import com.google.gson.Gson; -import java.io.File; -import java.io.FileWriter; -import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import org.json.JSONException; -import org.json.JSONObject; - -public class LlmBenchmarkActivity extends Activity implements ModelRunnerCallback { - ModelRunner mModelRunner; - - String mPrompt; - StatsInfo mStatsInfo; - - @Override - protected void onCreate(Bundle savedInstanceState) { - super.onCreate(savedInstanceState); - - Intent intent = getIntent(); - - File modelDir = new File(intent.getStringExtra("model_dir")); - File model = - Arrays.stream(modelDir.listFiles()) - .filter(file -> file.getName().endsWith(".pte")) - .findFirst() - .get(); - String tokenizerPath = intent.getStringExtra("tokenizer_path"); - - float temperature = intent.getFloatExtra("temperature", 0.8f); - mPrompt = intent.getStringExtra("prompt"); - if (mPrompt == null) { - mPrompt = "The ultimate answer"; - } - - try { - Os.setenv("ADSP_LIBRARY_PATH", getApplicationInfo().nativeLibraryDir, true); - } catch (ErrnoException e) { - finish(); - } - - mStatsInfo = new StatsInfo(); - mStatsInfo.modelName = model.getName().replace(".pte", ""); - mModelRunner = new ModelRunner(model.getPath(), tokenizerPath, temperature, this); - mStatsInfo.loadStart = System.nanoTime(); - } - - @Override - public void onModelLoaded(int status) { - mStatsInfo.loadEnd = System.nanoTime(); - mStatsInfo.loadStatus = status; - if (status != 0) { - Log.e("LlmBenchmarkRunner", "Loaded failed: " + status); - onGenerationStopped(); - return; - } - mStatsInfo.generateStart = System.nanoTime(); - mModelRunner.generate(mPrompt); - } - - @Override - public void onTokenGenerated(String token) {} - - @Override - public void onStats(String stats) { - float tps = 0; - try { - JSONObject jsonObject = new JSONObject(stats); - int numGeneratedTokens = jsonObject.getInt("generated_tokens"); - int inferenceEndMs = jsonObject.getInt("inference_end_ms"); - int promptEvalEndMs = jsonObject.getInt("prompt_eval_end_ms"); - tps = (float) numGeneratedTokens / (inferenceEndMs - promptEvalEndMs) * 1000; - mStatsInfo.tps = tps; - } catch (JSONException e) { - Log.e("LLM", "Error parsing JSON: " + e.getMessage()); - } - } - - @Override - public void onGenerationStopped() { - mStatsInfo.generateEnd = System.nanoTime(); - - final BenchmarkMetric.BenchmarkModel benchmarkModel = - BenchmarkMetric.extractBackendAndQuantization(mStatsInfo.modelName); - final List results = new ArrayList<>(); - // The list of metrics we have atm includes: - // Load status - results.add(new BenchmarkMetric(benchmarkModel, "load_status", mStatsInfo.loadStatus, 0)); - // Model load time - results.add( - new BenchmarkMetric( - benchmarkModel, - "model_load_time(ms)", - (mStatsInfo.loadEnd - mStatsInfo.loadStart) * 1e-6, - 0.0f)); - // LLM generate time - results.add( - new BenchmarkMetric( - benchmarkModel, - "generate_time(ms)", - (mStatsInfo.generateEnd - mStatsInfo.generateStart) * 1e-6, - 0.0f)); - // Token per second - results.add(new BenchmarkMetric(benchmarkModel, "token_per_sec", mStatsInfo.tps, 0.0f)); - - try (FileWriter writer = new 
FileWriter(getFilesDir() + "/benchmark_results.json")) {
-      Gson gson = new Gson();
-      writer.write(gson.toJson(results));
-    } catch (IOException e) {
-      e.printStackTrace();
-    }
-  }
-}
-
-class StatsInfo {
-  int loadStatus;
-  long loadStart;
-  long loadEnd;
-  long generateStart;
-  long generateEnd;
-  float tps;
-  String modelName;
-
-  @Override
-  public String toString() {
-    return "loadStart: "
-        + loadStart
-        + "\nloadEnd: "
-        + loadEnd
-        + "\ngenerateStart: "
-        + generateStart
-        + "\ngenerateEnd: "
-        + generateEnd
-        + "\n"
-        + tps;
-  }
-}
diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmModelRunner.java b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmModelRunner.java
new file mode 100644
index 0000000000..ffa998665c
--- /dev/null
+++ b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmModelRunner.java
@@ -0,0 +1,100 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+package org.pytorch.minibench;
+
+import android.os.Handler;
+import android.os.HandlerThread;
+import android.os.Looper;
+import android.os.Message;
+
+import org.pytorch.executorch.extension.llm.LlmCallback;
+import org.pytorch.executorch.extension.llm.LlmModule;
+
+/**
+ * A helper class to handle all model running logic within this class.
+ */
+public class LlmModelRunner implements LlmCallback {
+    LlmModule mModule = null;
+
+    String mModelFilePath = "";
+    String mTokenizerFilePath = "";
+
+    LlmModelRunnerCallback mCallback = null;
+
+    HandlerThread mHandlerThread = null;
+    Handler mHandler = null;
+
+    /**
+     * Helper class to separate UI logic from model runner logic. Automatically handles
+     * generate() requests on a worker thread.
+ * + * @param modelFilePath + * @param tokenizerFilePath + * @param callback + */ + LlmModelRunner( + String modelFilePath, + String tokenizerFilePath, + float temperature, + LlmModelRunnerCallback callback) { + mModelFilePath = modelFilePath; + mTokenizerFilePath = tokenizerFilePath; + mCallback = callback; + + mModule = new LlmModule(mModelFilePath, mTokenizerFilePath, 0.8f); + mHandlerThread = new HandlerThread("LlmModelRunner"); + mHandlerThread.start(); + mHandler = new LlmModelRunnerHandler(mHandlerThread.getLooper(), this); + + mHandler.sendEmptyMessage(LlmModelRunnerHandler.MESSAGE_LOAD_MODEL); + } + + int generate(String prompt) { + Message msg = Message.obtain(mHandler, LlmModelRunnerHandler.MESSAGE_GENERATE, prompt); + msg.sendToTarget(); + return 0; + } + + void stop() { + mModule.stop(); + } + + @Override + public void onResult(String result) { + mCallback.onTokenGenerated(result); + } + + @Override + public void onStats(String result) { + mCallback.onStats(result); + } +} + +class LlmModelRunnerHandler extends Handler { + public static int MESSAGE_LOAD_MODEL = 1; + public static int MESSAGE_GENERATE = 2; + + private final LlmModelRunner mLlmModelRunner; + + public LlmModelRunnerHandler(Looper looper, LlmModelRunner llmModelRunner) { + super(looper); + mLlmModelRunner = llmModelRunner; + } + + @Override + public void handleMessage(android.os.Message msg) { + if (msg.what == MESSAGE_LOAD_MODEL) { + int status = mLlmModelRunner.mModule.load(); + mLlmModelRunner.mCallback.onModelLoaded(status); + } else if (msg.what == MESSAGE_GENERATE) { + mLlmModelRunner.mModule.generate((String) msg.obj, mLlmModelRunner); + mLlmModelRunner.mCallback.onGenerationStopped(); + } + } +} diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunnerCallback.java b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmModelRunnerCallback.kt similarity index 62% rename from extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunnerCallback.java rename to extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmModelRunnerCallback.kt index 8503d47ccc..cd2fecdf81 100644 --- a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunnerCallback.java +++ b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmModelRunnerCallback.kt @@ -6,19 +6,21 @@ * LICENSE file in the root directory of this source tree. */ -package org.pytorch.minibench; + +package org.pytorch.minibench + /** * A helper interface within the app for MainActivity and Benchmarking to handle callback from * ModelRunner. 
*/ -public interface ModelRunnerCallback { +interface LlmModelRunnerCallback { - void onModelLoaded(int status); + fun onModelLoaded(status: Int) - void onTokenGenerated(String token); + fun onTokenGenerated(token: String) - void onStats(String result); + fun onStats(result: String) - void onGenerationStopped(); + fun onGenerationStopped() } diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java index 0a75b47f3a..190bf3284f 100644 --- a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java +++ b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java @@ -8,90 +8,75 @@ package org.pytorch.minibench; -import android.os.Handler; -import android.os.HandlerThread; -import android.os.Looper; -import android.os.Message; -import org.pytorch.executorch.extension.llm.LlmCallback; -import org.pytorch.executorch.extension.llm.LlmModule; - -/** A helper class to handle all model running logic within this class. */ -public class ModelRunner implements LlmCallback { - LlmModule mModule = null; - - String mModelFilePath = ""; - String mTokenizerFilePath = ""; - - ModelRunnerCallback mCallback = null; - - HandlerThread mHandlerThread = null; - Handler mHandler = null; - - /** - * ] Helper class to separate between UI logic and model runner logic. Automatically handle - * generate() request on worker thread. - * - * @param modelFilePath - * @param tokenizerFilePath - * @param callback - */ - ModelRunner( - String modelFilePath, - String tokenizerFilePath, - float temperature, - ModelRunnerCallback callback) { - mModelFilePath = modelFilePath; - mTokenizerFilePath = tokenizerFilePath; - mCallback = callback; - - mModule = new LlmModule(mModelFilePath, mTokenizerFilePath, 0.8f); - mHandlerThread = new HandlerThread("ModelRunner"); - mHandlerThread.start(); - mHandler = new ModelRunnerHandler(mHandlerThread.getLooper(), this); - - mHandler.sendEmptyMessage(ModelRunnerHandler.MESSAGE_LOAD_MODEL); - } - - int generate(String prompt) { - Message msg = Message.obtain(mHandler, ModelRunnerHandler.MESSAGE_GENERATE, prompt); - msg.sendToTarget(); - return 0; - } - - void stop() { - mModule.stop(); - } - - @Override - public void onResult(String result) { - mCallback.onTokenGenerated(result); - } - - @Override - public void onStats(String result) { - mCallback.onStats(result); - } -} - -class ModelRunnerHandler extends Handler { - public static int MESSAGE_LOAD_MODEL = 1; - public static int MESSAGE_GENERATE = 2; - - private final ModelRunner mModelRunner; - - public ModelRunnerHandler(Looper looper, ModelRunner modelRunner) { - super(looper); - mModelRunner = modelRunner; - } - - @Override - public void handleMessage(android.os.Message msg) { - if (msg.what == MESSAGE_LOAD_MODEL) { - int status = mModelRunner.mModule.load(); - mModelRunner.mCallback.onModelLoaded(status); - } else if (msg.what == MESSAGE_GENERATE) { - mModelRunner.mModule.generate((String) msg.obj, mModelRunner); - mModelRunner.mCallback.onGenerationStopped(); +import android.os.Debug; + +import org.pytorch.executorch.Module; + +import java.io.File; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +public class ModelRunner { + /** + * @return list of #BenchmarkMetric + */ + public void runBenchmark(File model, int numWarmupIter, int numIter, List results) { + long pssIdle = 
Debug.getPss(); + + List latency = new ArrayList<>(); + + long loadStart = System.nanoTime(); + Module module = Module.load(model.getPath()); + int errorCode = module.loadMethod("forward"); + long loadEnd = System.nanoTime(); + + for (int i = 0; i < numWarmupIter; i++) { + module.forward(); + } + + for (int i = 0; i < numIter; i++) { + long start = System.nanoTime(); + module.forward(); + double forwardMs = (System.nanoTime() - start) * 1e-6; + latency.add(forwardMs); + } + + final BenchmarkMetric.BenchmarkModel benchmarkModel = + BenchmarkMetric.extractBackendAndQuantization(model.getName().replace(".pte", "")); + // The list of metrics we have atm includes: + // Avg inference latency after N iterations + // Currently the result has large variance from outliers, so only use + // 80% samples in the middle (trimmean 0.2) + Collections.sort(latency); + int resultSize = latency.size(); + List usedLatencyResults = + latency.subList(resultSize / 10, resultSize * 9 / 10); + + results.add( + new BenchmarkMetric( + benchmarkModel, + "avg_inference_latency(ms)", + latency.stream().mapToDouble(l -> l).average().orElse(0.0f), + 0.0f)); + results.add( + new BenchmarkMetric( + benchmarkModel, + "trimmean_inference_latency(ms)", + usedLatencyResults.stream().mapToDouble(l -> l).average().orElse(0.0f), + 0.0f)); + // Model load time + results.add( + new BenchmarkMetric( + benchmarkModel, + "model_load_time(ms)", + (loadEnd - loadStart) * 1e-6, + 0.0f)); + // Load status + results.add(new BenchmarkMetric(benchmarkModel, "load_status", errorCode, 0)); + // RAM PSS usage + results.add( + new BenchmarkMetric( + benchmarkModel, "ram_pss_usage(mb)", (Debug.getPss() - pssIdle) / 1024, 0)); } - } } diff --git a/extension/benchmark/android/benchmark/build.gradle.kts b/extension/benchmark/android/benchmark/build.gradle.kts index ac625be8e0..b1ed5127df 100644 --- a/extension/benchmark/android/benchmark/build.gradle.kts +++ b/extension/benchmark/android/benchmark/build.gradle.kts @@ -7,4 +7,6 @@ */ // Top-level build file where you can add configuration options common to all sub-projects/modules. 
-plugins { id("com.android.application") version "8.1.0" apply false } +plugins { id("com.android.application") version "8.1.0" apply false + id("org.jetbrains.kotlin.android") version "2.1.10" apply false +} From 740bbac0e6b1697526963c61cac746bdea8d5988 Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Tue, 22 Apr 2025 21:21:41 -0700 Subject: [PATCH 2/3] Linter --- .../pytorch/minibench/BenchmarkActivity.java | 190 +++++++++--------- .../pytorch/minibench/BenchmarkMetric.java | 95 +++++---- .../org/pytorch/minibench/LlmBenchmark.java | 183 ++++++++--------- .../org/pytorch/minibench/LlmModelRunner.java | 151 +++++++------- .../org/pytorch/minibench/ModelRunner.java | 112 +++++------ 5 files changed, 368 insertions(+), 363 deletions(-) diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java index 54eafaa555..b956375688 100644 --- a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java +++ b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java @@ -16,9 +16,8 @@ import android.os.Looper; import android.system.ErrnoException; import android.system.Os; - +import android.util.Log; import com.google.gson.Gson; - import java.io.File; import java.io.FileWriter; import java.io.IOException; @@ -28,100 +27,111 @@ public class BenchmarkActivity extends Activity { - File mModel; - int mNumIter; - int mNumWarmupIter; - String mTokenizerPath; - float mTemperature; - String mPrompt; - - HandlerThread mHandlerThread; - BenchmarkHandler mHandler; - - List mResult; - - @Override - protected void onCreate(Bundle savedInstanceState) { - super.onCreate(savedInstanceState); - - try { - Os.setenv("ADSP_LIBRARY_PATH", getApplicationInfo().nativeLibraryDir, true); - } catch (ErrnoException e) { - finish(); - } - - Intent intent = getIntent(); - File modelDir = new File(intent.getStringExtra("model_dir")); - File model = - Arrays.stream(modelDir.listFiles()) - .filter(file -> file.getName().endsWith(".pte")) - .findFirst() - .get(); - - int numIter = intent.getIntExtra("num_iter", 50); - int numWarmupIter = intent.getIntExtra("num_warm_up_iter", 10); - String tokenizerPath = intent.getStringExtra("tokenizer_path"); - float temperature = intent.getFloatExtra("temperature", 0.8f); - String prompt = intent.getStringExtra("prompt"); - - mModel = model; - mNumIter = numIter; - mNumWarmupIter = numWarmupIter; - mTokenizerPath = tokenizerPath; - mTemperature = temperature; - mPrompt = prompt; - if (mPrompt == null) { - mPrompt = "The ultimate answer"; - } - mResult = new ArrayList<>(); - - mHandlerThread = new HandlerThread("ModelRunner"); - mHandlerThread.start(); - mHandler = new BenchmarkHandler(mHandlerThread.getLooper(), this); - - mHandler.sendEmptyMessage(BenchmarkHandler.MESSAGE_RUN_BENCHMARK); - } + File mModel; + int mNumIter; + int mNumWarmupIter; + String mTokenizerPath; + float mTemperature; + String mPrompt; - void writeResult() { - try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) { - Gson gson = new Gson(); - writer.write(gson.toJson(mResult)); - } catch (IOException e) { - e.printStackTrace(); - } finally { - finish(); - } - } -} + HandlerThread mHandlerThread; + BenchmarkHandler mHandler; -class BenchmarkHandler extends Handler { - public static int MESSAGE_RUN_BENCHMARK = 1; - public static int 
MESSAGE_LLM_RUN_BENCHMARK = 2; + List mResult; - ModelRunner mModelRunner; - BenchmarkActivity mBenchmarkActivity; + @Override + protected void onCreate(Bundle savedInstanceState) { + super.onCreate(savedInstanceState); - LlmModelRunner mLlmModelRunner; - LlmBenchmark mLlmBenchmark; + try { + Os.setenv("ADSP_LIBRARY_PATH", getApplicationInfo().nativeLibraryDir, true); + } catch (ErrnoException e) { + finish(); + } - public BenchmarkHandler(Looper looper, BenchmarkActivity benchmarkActivity) { - super(looper); - mModelRunner = new ModelRunner(); - mBenchmarkActivity = benchmarkActivity; + Intent intent = getIntent(); + File modelDir = new File(intent.getStringExtra("model_dir")); + File model = + Arrays.stream(modelDir.listFiles()) + .filter(file -> file.getName().endsWith(".pte")) + .findFirst() + .get(); + + int numIter = intent.getIntExtra("num_iter", 50); + int numWarmupIter = intent.getIntExtra("num_warm_up_iter", 10); + String tokenizerPath = intent.getStringExtra("tokenizer_path"); + float temperature = intent.getFloatExtra("temperature", 0.8f); + String prompt = intent.getStringExtra("prompt"); + + mModel = model; + mNumIter = numIter; + mNumWarmupIter = numWarmupIter; + mTokenizerPath = tokenizerPath; + mTemperature = temperature; + mPrompt = prompt; + if (mPrompt == null) { + mPrompt = "The ultimate answer"; + } + mResult = new ArrayList<>(); + + mHandlerThread = new HandlerThread("ModelRunner"); + mHandlerThread.start(); + mHandler = new BenchmarkHandler(mHandlerThread.getLooper(), this); + + mHandler.sendEmptyMessage(BenchmarkHandler.MESSAGE_RUN_BENCHMARK); + } + + void writeResult() { + try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) { + Gson gson = new Gson(); + writer.write(gson.toJson(mResult)); + } catch (IOException e) { + e.printStackTrace(); + } finally { + finish(); } + } +} - @Override - public void handleMessage(android.os.Message msg) { - if (msg.what == MESSAGE_RUN_BENCHMARK) { - mModelRunner.runBenchmark(mBenchmarkActivity.mModel, mBenchmarkActivity.mNumWarmupIter, mBenchmarkActivity.mNumIter, mBenchmarkActivity.mResult); - - if (mBenchmarkActivity.mTokenizerPath == null) { - mBenchmarkActivity.writeResult(); - } else { - this.sendEmptyMessage(MESSAGE_LLM_RUN_BENCHMARK); - } - } else if (msg.what == MESSAGE_LLM_RUN_BENCHMARK) { - mLlmBenchmark = new LlmBenchmark(mBenchmarkActivity, mBenchmarkActivity.mModel.getPath(), mBenchmarkActivity.mTokenizerPath, mBenchmarkActivity.mPrompt, mBenchmarkActivity.mTemperature, mBenchmarkActivity.mResult); - } +class BenchmarkHandler extends Handler { + public static int MESSAGE_RUN_BENCHMARK = 1; + public static int MESSAGE_LLM_RUN_BENCHMARK = 2; + + ModelRunner mModelRunner; + BenchmarkActivity mBenchmarkActivity; + + LlmModelRunner mLlmModelRunner; + LlmBenchmark mLlmBenchmark; + + public BenchmarkHandler(Looper looper, BenchmarkActivity benchmarkActivity) { + super(looper); + mModelRunner = new ModelRunner(); + mBenchmarkActivity = benchmarkActivity; + } + + @Override + public void handleMessage(android.os.Message msg) { + if (msg.what == MESSAGE_RUN_BENCHMARK) { + mModelRunner.runBenchmark( + mBenchmarkActivity.mModel, + mBenchmarkActivity.mNumWarmupIter, + mBenchmarkActivity.mNumIter, + mBenchmarkActivity.mResult); + + if (mBenchmarkActivity.mTokenizerPath == null) { + mBenchmarkActivity.writeResult(); + } else { + this.sendEmptyMessage(MESSAGE_LLM_RUN_BENCHMARK); + } + } else if (msg.what == MESSAGE_LLM_RUN_BENCHMARK) { + mLlmBenchmark = + new LlmBenchmark( + mBenchmarkActivity, + 
mBenchmarkActivity.mModel.getName(), + mBenchmarkActivity.mTokenizerPath, + mBenchmarkActivity.mPrompt, + mBenchmarkActivity.mTemperature, + mBenchmarkActivity.mResult); } + } } diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkMetric.java b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkMetric.java index 0c09614780..66ab50550a 100644 --- a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkMetric.java +++ b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkMetric.java @@ -10,66 +10,65 @@ import android.app.ActivityManager; import android.os.Build; - import java.util.regex.Matcher; import java.util.regex.Pattern; class BenchmarkMetric { - public static class BenchmarkModel { - // The model name, i.e. stories110M - String name; - String backend; - String quantization; + public static class BenchmarkModel { + // The model name, i.e. stories110M + String name; + String backend; + String quantization; - public BenchmarkModel(final String name, final String backend, final String quantization) { - this.name = name; - this.backend = backend; - this.quantization = quantization; - } + public BenchmarkModel(final String name, final String backend, final String quantization) { + this.name = name; + this.backend = backend; + this.quantization = quantization; } + } - BenchmarkModel benchmarkModel; + BenchmarkModel benchmarkModel; - // The metric name, i.e. TPS - String metric; + // The metric name, i.e. TPS + String metric; - // The actual value and the option target value - double actualValue; - double targetValue; + // The actual value and the option target value + double actualValue; + double targetValue; - public static class DeviceInfo { - // Let's see which information we want to include here - final String device = Build.BRAND; - // The phone model and Android release version - final String arch = Build.MODEL; - final String os = "Android " + Build.VERSION.RELEASE; - final long totalMem = new ActivityManager.MemoryInfo().totalMem; - final long availMem = new ActivityManager.MemoryInfo().availMem; - } + public static class DeviceInfo { + // Let's see which information we want to include here + final String device = Build.BRAND; + // The phone model and Android release version + final String arch = Build.MODEL; + final String os = "Android " + Build.VERSION.RELEASE; + final long totalMem = new ActivityManager.MemoryInfo().totalMem; + final long availMem = new ActivityManager.MemoryInfo().availMem; + } - DeviceInfo deviceInfo = new DeviceInfo(); + DeviceInfo deviceInfo = new DeviceInfo(); - public BenchmarkMetric( - final BenchmarkModel benchmarkModel, - final String metric, - final double actualValue, - final double targetValue) { - this.benchmarkModel = benchmarkModel; - this.metric = metric; - this.actualValue = actualValue; - this.targetValue = targetValue; - } + public BenchmarkMetric( + final BenchmarkModel benchmarkModel, + final String metric, + final double actualValue, + final double targetValue) { + this.benchmarkModel = benchmarkModel; + this.metric = metric; + this.actualValue = actualValue; + this.targetValue = targetValue; + } - // TODO (huydhn): Figure out a way to extract the backend and quantization information from - // the .pte model itself instead of parsing its name - public static BenchmarkMetric.BenchmarkModel extractBackendAndQuantization(final String model) { - final Matcher m = - 
Pattern.compile("(?\\w+)_(?[\\w\\+]+)_(?\\w+)").matcher(model); - if (m.matches()) { - return new BenchmarkMetric.BenchmarkModel( - m.group("name"), m.group("backend"), m.group("quantization")); - } else { - return new BenchmarkMetric.BenchmarkModel(model, "", ""); - } + // TODO (huydhn): Figure out a way to extract the backend and quantization information from + // the .pte model itself instead of parsing its name + public static BenchmarkMetric.BenchmarkModel extractBackendAndQuantization(final String model) { + final Matcher m = + Pattern.compile("(?\\w+)_(?[\\w\\+]+)_(?\\w+)").matcher(model); + if (m.matches()) { + return new BenchmarkMetric.BenchmarkModel( + m.group("name"), m.group("backend"), m.group("quantization")); + } else { + return new BenchmarkMetric.BenchmarkModel(model, "", ""); } + } } diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmark.java b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmark.java index c29002ea9c..44d1665924 100644 --- a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmark.java +++ b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmark.java @@ -9,112 +9,115 @@ package org.pytorch.minibench; import android.util.Log; - +import java.util.List; import org.json.JSONException; import org.json.JSONObject; -import java.util.List; - public class LlmBenchmark implements LlmModelRunnerCallback { - LlmModelRunner mLlmModelRunner; + LlmModelRunner mLlmModelRunner; - String mPrompt; - StatsInfo mStatsInfo; + String mPrompt; + StatsInfo mStatsInfo; - List mResults; - BenchmarkActivity mActivity; + List mResults; + BenchmarkActivity mActivity; - LlmBenchmark(BenchmarkActivity activity, String modelFile, String tokenizerPath, String prompt, float temperature, List results) { - mResults = results; - mActivity = activity; - mStatsInfo = new StatsInfo(); - mStatsInfo.modelName = modelFile.replace(".pte", ""); - mPrompt = prompt; - mLlmModelRunner = new LlmModelRunner(modelFile, tokenizerPath, temperature, this); - mStatsInfo.loadStart = System.nanoTime(); - } + LlmBenchmark( + BenchmarkActivity activity, + String modelFile, + String tokenizerPath, + String prompt, + float temperature, + List results) { + mResults = results; + mActivity = activity; + mStatsInfo = new StatsInfo(); + mStatsInfo.modelName = modelFile.replace(".pte", ""); + mPrompt = prompt; + mLlmModelRunner = new LlmModelRunner(modelFile, tokenizerPath, temperature, this); + mStatsInfo.loadStart = System.nanoTime(); + } - @Override - public void onModelLoaded(int status) { - mStatsInfo.loadEnd = System.nanoTime(); - mStatsInfo.loadStatus = status; - if (status != 0) { - Log.e("LlmBenchmarkRunner", "Loaded failed: " + status); - onGenerationStopped(); - return; - } - mStatsInfo.generateStart = System.nanoTime(); - mLlmModelRunner.generate(mPrompt); + @Override + public void onModelLoaded(int status) { + mStatsInfo.loadEnd = System.nanoTime(); + mStatsInfo.loadStatus = status; + if (status != 0) { + Log.e("LlmBenchmarkRunner", "Loaded failed: " + status); + onGenerationStopped(); + return; } + mStatsInfo.generateStart = System.nanoTime(); + mLlmModelRunner.generate(mPrompt); + } - @Override - public void onTokenGenerated(String token) { - } + @Override + public void onTokenGenerated(String token) {} - @Override - public void onStats(String stats) { - float tps = 0; - try { - JSONObject jsonObject = new JSONObject(stats); - int numGeneratedTokens = 
jsonObject.getInt("generated_tokens"); - int inferenceEndMs = jsonObject.getInt("inference_end_ms"); - int promptEvalEndMs = jsonObject.getInt("prompt_eval_end_ms"); - tps = (float) numGeneratedTokens / (inferenceEndMs - promptEvalEndMs) * 1000; - mStatsInfo.tps = tps; - } catch (JSONException e) { - Log.e("LLM", "Error parsing JSON: " + e.getMessage()); - } + @Override + public void onStats(String stats) { + float tps = 0; + try { + JSONObject jsonObject = new JSONObject(stats); + int numGeneratedTokens = jsonObject.getInt("generated_tokens"); + int inferenceEndMs = jsonObject.getInt("inference_end_ms"); + int promptEvalEndMs = jsonObject.getInt("prompt_eval_end_ms"); + tps = (float) numGeneratedTokens / (inferenceEndMs - promptEvalEndMs) * 1000; + mStatsInfo.tps = tps; + } catch (JSONException e) { + Log.e("LLM", "Error parsing JSON: " + e.getMessage()); } + } - @Override - public void onGenerationStopped() { - mStatsInfo.generateEnd = System.nanoTime(); + @Override + public void onGenerationStopped() { + mStatsInfo.generateEnd = System.nanoTime(); - final BenchmarkMetric.BenchmarkModel benchmarkModel = - BenchmarkMetric.extractBackendAndQuantization(mStatsInfo.modelName); - // The list of metrics we have atm includes: - // Load status - mResults.add(new BenchmarkMetric(benchmarkModel, "load_status", mStatsInfo.loadStatus, 0)); - // Model load time - mResults.add( - new BenchmarkMetric( - benchmarkModel, - "llm_model_load_time(ms)", - (mStatsInfo.loadEnd - mStatsInfo.loadStart) * 1e-6, - 0.0f)); - // LLM generate time - mResults.add( - new BenchmarkMetric( - benchmarkModel, - "generate_time(ms)", - (mStatsInfo.generateEnd - mStatsInfo.generateStart) * 1e-6, - 0.0f)); - // Token per second - mResults.add(new BenchmarkMetric(benchmarkModel, "token_per_sec", mStatsInfo.tps, 0.0f)); - mActivity.writeResult(); - } + final BenchmarkMetric.BenchmarkModel benchmarkModel = + BenchmarkMetric.extractBackendAndQuantization(mStatsInfo.modelName); + // The list of metrics we have atm includes: + // Load status + mResults.add(new BenchmarkMetric(benchmarkModel, "load_status", mStatsInfo.loadStatus, 0)); + // Model load time + mResults.add( + new BenchmarkMetric( + benchmarkModel, + "llm_model_load_time(ms)", + (mStatsInfo.loadEnd - mStatsInfo.loadStart) * 1e-6, + 0.0f)); + // LLM generate time + mResults.add( + new BenchmarkMetric( + benchmarkModel, + "generate_time(ms)", + (mStatsInfo.generateEnd - mStatsInfo.generateStart) * 1e-6, + 0.0f)); + // Token per second + mResults.add(new BenchmarkMetric(benchmarkModel, "token_per_sec", mStatsInfo.tps, 0.0f)); + mActivity.writeResult(); + } } class StatsInfo { - int loadStatus; - long loadStart; - long loadEnd; - long generateStart; - long generateEnd; - float tps; - String modelName; + int loadStatus; + long loadStart; + long loadEnd; + long generateStart; + long generateEnd; + float tps; + String modelName; - @Override - public String toString() { - return "loadStart: " - + loadStart - + "\nloadEnd: " - + loadEnd - + "\ngenerateStart: " - + generateStart - + "\ngenerateEnd: " - + generateEnd - + "\n" - + tps; - } + @Override + public String toString() { + return "loadStart: " + + loadStart + + "\nloadEnd: " + + loadEnd + + "\ngenerateStart: " + + generateStart + + "\ngenerateEnd: " + + generateEnd + + "\n" + + tps; + } } diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmModelRunner.java b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmModelRunner.java index ffa998665c..a1b434a37b 
100644
--- a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmModelRunner.java
+++ b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmModelRunner.java
@@ -12,89 +12,86 @@
 import android.os.HandlerThread;
 import android.os.Looper;
 import android.os.Message;
-
 import org.pytorch.executorch.extension.llm.LlmCallback;
 import org.pytorch.executorch.extension.llm.LlmModule;
 
-/**
- * A helper class to handle all model running logic within this class.
- */
+/** A helper class to handle all model running logic within this class. */
 public class LlmModelRunner implements LlmCallback {
-    LlmModule mModule = null;
-
-    String mModelFilePath = "";
-    String mTokenizerFilePath = "";
-
-    LlmModelRunnerCallback mCallback = null;
-
-    HandlerThread mHandlerThread = null;
-    Handler mHandler = null;
-
-    /**
-     * Helper class to separate UI logic from model runner logic. Automatically handles
-     * generate() requests on a worker thread.
-     *
-     * @param modelFilePath
-     * @param tokenizerFilePath
-     * @param callback
-     */
-    LlmModelRunner(
-        String modelFilePath,
-        String tokenizerFilePath,
-        float temperature,
-        LlmModelRunnerCallback callback) {
-        mModelFilePath = modelFilePath;
-        mTokenizerFilePath = tokenizerFilePath;
-        mCallback = callback;
-
-        mModule = new LlmModule(mModelFilePath, mTokenizerFilePath, 0.8f);
-        mHandlerThread = new HandlerThread("LlmModelRunner");
-        mHandlerThread.start();
-        mHandler = new LlmModelRunnerHandler(mHandlerThread.getLooper(), this);
-
-        mHandler.sendEmptyMessage(LlmModelRunnerHandler.MESSAGE_LOAD_MODEL);
-    }
-
-    int generate(String prompt) {
-        Message msg = Message.obtain(mHandler, LlmModelRunnerHandler.MESSAGE_GENERATE, prompt);
-        msg.sendToTarget();
-        return 0;
-    }
-
-    void stop() {
-        mModule.stop();
-    }
-
-    @Override
-    public void onResult(String result) {
-        mCallback.onTokenGenerated(result);
-    }
-
-    @Override
-    public void onStats(String result) {
-        mCallback.onStats(result);
-    }
+  LlmModule mModule = null;
+
+  String mModelFilePath = "";
+  String mTokenizerFilePath = "";
+
+  LlmModelRunnerCallback mCallback = null;
+
+  HandlerThread mHandlerThread = null;
+  Handler mHandler = null;
+
+  /**
+   * Helper class to separate UI logic from model runner logic. Automatically handles
+   * generate() requests on a worker thread.
+ * + * @param modelFilePath + * @param tokenizerFilePath + * @param callback + */ + LlmModelRunner( + String modelFilePath, + String tokenizerFilePath, + float temperature, + LlmModelRunnerCallback callback) { + mModelFilePath = modelFilePath; + mTokenizerFilePath = tokenizerFilePath; + mCallback = callback; + + mModule = new LlmModule(mModelFilePath, mTokenizerFilePath, 0.8f); + mHandlerThread = new HandlerThread("LlmModelRunner"); + mHandlerThread.start(); + mHandler = new LlmModelRunnerHandler(mHandlerThread.getLooper(), this); + + mHandler.sendEmptyMessage(LlmModelRunnerHandler.MESSAGE_LOAD_MODEL); + } + + int generate(String prompt) { + Message msg = Message.obtain(mHandler, LlmModelRunnerHandler.MESSAGE_GENERATE, prompt); + msg.sendToTarget(); + return 0; + } + + void stop() { + mModule.stop(); + } + + @Override + public void onResult(String result) { + mCallback.onTokenGenerated(result); + } + + @Override + public void onStats(String result) { + mCallback.onStats(result); + } } class LlmModelRunnerHandler extends Handler { - public static int MESSAGE_LOAD_MODEL = 1; - public static int MESSAGE_GENERATE = 2; - - private final LlmModelRunner mLlmModelRunner; - - public LlmModelRunnerHandler(Looper looper, LlmModelRunner llmModelRunner) { - super(looper); - mLlmModelRunner = llmModelRunner; - } - - @Override - public void handleMessage(android.os.Message msg) { - if (msg.what == MESSAGE_LOAD_MODEL) { - int status = mLlmModelRunner.mModule.load(); - mLlmModelRunner.mCallback.onModelLoaded(status); - } else if (msg.what == MESSAGE_GENERATE) { - mLlmModelRunner.mModule.generate((String) msg.obj, mLlmModelRunner); - mLlmModelRunner.mCallback.onGenerationStopped(); - } + public static int MESSAGE_LOAD_MODEL = 1; + public static int MESSAGE_GENERATE = 2; + + private final LlmModelRunner mLlmModelRunner; + + public LlmModelRunnerHandler(Looper looper, LlmModelRunner llmModelRunner) { + super(looper); + mLlmModelRunner = llmModelRunner; + } + + @Override + public void handleMessage(android.os.Message msg) { + if (msg.what == MESSAGE_LOAD_MODEL) { + int status = mLlmModelRunner.mModule.load(); + mLlmModelRunner.mCallback.onModelLoaded(status); + } else if (msg.what == MESSAGE_GENERATE) { + mLlmModelRunner.mModule.generate((String) msg.obj, mLlmModelRunner); + mLlmModelRunner.mCallback.onGenerationStopped(); } + } } diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java index 190bf3284f..ffd7217167 100644 --- a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java +++ b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java @@ -9,74 +9,70 @@ package org.pytorch.minibench; import android.os.Debug; - -import org.pytorch.executorch.Module; - +import android.util.Log; import java.io.File; import java.util.ArrayList; import java.util.Collections; import java.util.List; +import org.pytorch.executorch.Module; public class ModelRunner { - /** - * @return list of #BenchmarkMetric - */ - public void runBenchmark(File model, int numWarmupIter, int numIter, List results) { - long pssIdle = Debug.getPss(); + /** + * @return list of #BenchmarkMetric + */ + public void runBenchmark( + File model, int numWarmupIter, int numIter, List results) { + long pssIdle = Debug.getPss(); - List latency = new ArrayList<>(); + List latency = new ArrayList<>(); - long loadStart = 
System.nanoTime(); - Module module = Module.load(model.getPath()); - int errorCode = module.loadMethod("forward"); - long loadEnd = System.nanoTime(); + long loadStart = System.nanoTime(); + Module module = Module.load(model.getPath()); + int errorCode = module.loadMethod("forward"); + long loadEnd = System.nanoTime(); - for (int i = 0; i < numWarmupIter; i++) { - module.forward(); - } + for (int i = 0; i < numWarmupIter; i++) { + module.forward(); + } - for (int i = 0; i < numIter; i++) { - long start = System.nanoTime(); - module.forward(); - double forwardMs = (System.nanoTime() - start) * 1e-6; - latency.add(forwardMs); - } + for (int i = 0; i < numIter; i++) { + long start = System.nanoTime(); + module.forward(); + double forwardMs = (System.nanoTime() - start) * 1e-6; + latency.add(forwardMs); + } - final BenchmarkMetric.BenchmarkModel benchmarkModel = - BenchmarkMetric.extractBackendAndQuantization(model.getName().replace(".pte", "")); - // The list of metrics we have atm includes: - // Avg inference latency after N iterations - // Currently the result has large variance from outliers, so only use - // 80% samples in the middle (trimmean 0.2) - Collections.sort(latency); - int resultSize = latency.size(); - List usedLatencyResults = - latency.subList(resultSize / 10, resultSize * 9 / 10); + final BenchmarkMetric.BenchmarkModel benchmarkModel = + BenchmarkMetric.extractBackendAndQuantization(model.getName().replace(".pte", "")); + // The list of metrics we have atm includes: + // Avg inference latency after N iterations + // Currently the result has large variance from outliers, so only use + // 80% samples in the middle (trimmean 0.2) + Collections.sort(latency); + int resultSize = latency.size(); + List usedLatencyResults = latency.subList(resultSize / 10, resultSize * 9 / 10); - results.add( - new BenchmarkMetric( - benchmarkModel, - "avg_inference_latency(ms)", - latency.stream().mapToDouble(l -> l).average().orElse(0.0f), - 0.0f)); - results.add( - new BenchmarkMetric( - benchmarkModel, - "trimmean_inference_latency(ms)", - usedLatencyResults.stream().mapToDouble(l -> l).average().orElse(0.0f), - 0.0f)); - // Model load time - results.add( - new BenchmarkMetric( - benchmarkModel, - "model_load_time(ms)", - (loadEnd - loadStart) * 1e-6, - 0.0f)); - // Load status - results.add(new BenchmarkMetric(benchmarkModel, "load_status", errorCode, 0)); - // RAM PSS usage - results.add( - new BenchmarkMetric( - benchmarkModel, "ram_pss_usage(mb)", (Debug.getPss() - pssIdle) / 1024, 0)); - } + results.add( + new BenchmarkMetric( + benchmarkModel, + "avg_inference_latency(ms)", + latency.stream().mapToDouble(l -> l).average().orElse(0.0f), + 0.0f)); + results.add( + new BenchmarkMetric( + benchmarkModel, + "trimmean_inference_latency(ms)", + usedLatencyResults.stream().mapToDouble(l -> l).average().orElse(0.0f), + 0.0f)); + // Model load time + results.add( + new BenchmarkMetric( + benchmarkModel, "model_load_time(ms)", (loadEnd - loadStart) * 1e-6, 0.0f)); + // Load status + results.add(new BenchmarkMetric(benchmarkModel, "load_status", errorCode, 0)); + // RAM PSS usage + results.add( + new BenchmarkMetric( + benchmarkModel, "ram_pss_usage(mb)", (Debug.getPss() - pssIdle) / 1024, 0)); + } } From 290e947ffc19bc26bf13c6eac203e00642177eb7 Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Tue, 22 Apr 2025 22:46:04 -0700 Subject: [PATCH 3/3] Fix --- .../src/main/java/org/pytorch/minibench/BenchmarkActivity.java | 3 +-- .../app/src/main/java/org/pytorch/minibench/LlmBenchmark.java | 2 +- 
.../app/src/main/java/org/pytorch/minibench/ModelRunner.java   | 1 -
 3 files changed, 2 insertions(+), 4 deletions(-)

diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java
index b956375688..5e1dd48926 100644
--- a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java
+++ b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java
@@ -16,7 +16,6 @@
 import android.os.Looper;
 import android.system.ErrnoException;
 import android.system.Os;
-import android.util.Log;
 import com.google.gson.Gson;
 import java.io.File;
 import java.io.FileWriter;
@@ -127,7 +126,7 @@ public void handleMessage(android.os.Message msg) {
       mLlmBenchmark =
           new LlmBenchmark(
               mBenchmarkActivity,
-              mBenchmarkActivity.mModel.getName(),
+              mBenchmarkActivity.mModel.getPath(),
               mBenchmarkActivity.mTokenizerPath,
               mBenchmarkActivity.mPrompt,
               mBenchmarkActivity.mTemperature,
diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmark.java b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmark.java
index 44d1665924..0c0436d267 100644
--- a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmark.java
+++ b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmark.java
@@ -32,7 +32,7 @@ public class LlmBenchmark implements LlmModelRunnerCallback {
     mResults = results;
     mActivity = activity;
     mStatsInfo = new StatsInfo();
-    mStatsInfo.modelName = modelFile.replace(".pte", "");
+    mStatsInfo.modelName = modelFile.substring(modelFile.lastIndexOf('/') + 1).replace(".pte", "");
     mPrompt = prompt;
     mLlmModelRunner = new LlmModelRunner(modelFile, tokenizerPath, temperature, this);
     mStatsInfo.loadStart = System.nanoTime();
diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java
index ffd7217167..3913a8d76f 100644
--- a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java
+++ b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java
@@ -9,7 +9,6 @@
 package org.pytorch.minibench;
 
 import android.os.Debug;
-import android.util.Log;
 import java.io.File;
 import java.util.ArrayList;
 import java.util.Collections;
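
A note on the resulting control flow: after this series, BenchmarkActivity always runs the generic forward() benchmark first on its background HandlerThread, chains into the LLM benchmark only when a tokenizer_path extra is supplied, and has both stages append to one shared metric list that is written to benchmark_results.json at the end. The sketch below is a rough, self-contained illustration of that sequencing in plain Java; the class name, metric strings, and the standard-library single-thread executor standing in for the app's HandlerThread and message codes are all illustrative, not the app's actual API.

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public class SequencingSketch {
  public static void main(String[] args) {
    // Mirrors the tokenizer_path intent extra: absent means generic-only.
    String tokenizerPath = args.length > 0 ? args[0] : null;
    List<String> results = new ArrayList<>();

    // Single worker thread, standing in for the app's HandlerThread.
    ExecutorService worker = Executors.newSingleThreadExecutor();
    worker.execute(
        () -> {
          // Stage 1: the generic forward() benchmark always runs first
          // (MESSAGE_RUN_BENCHMARK in BenchmarkHandler).
          results.add("avg_inference_latency(ms)");

          if (tokenizerPath == null) {
            // No tokenizer: generic-only run, write results immediately.
            writeResults(results);
          } else {
            // Stage 2: the LLM benchmark (MESSAGE_LLM_RUN_BENCHMARK) appends
            // to the same list and writes the combined output when done.
            results.add("token_per_sec");
            writeResults(results);
          }
        });
    worker.shutdown();
  }

  private static void writeResults(List<String> results) {
    // Stand-in for BenchmarkActivity.writeResult(), which serializes with Gson.
    System.out.println(results);
  }
}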