Minibench refactor #10376

Merged · 3 commits · Apr 23, 2025
4 changes: 2 additions & 2 deletions extension/benchmark/android/benchmark/README.md
@@ -43,13 +43,13 @@ adb push tokenizer.bin /data/local/tmp/minibench
 
 ### Generic model
 ```
-adb shell am start -W -S -n org.pytorch.minibench/org.pytorch.minibench.LlmBenchmarkActivity \
+adb shell am start -W -S -n org.pytorch.minibench/org.pytorch.minibench.BenchmarkActivity \
   --es model_dir /data/local/tmp/minibench
 ```
 
 ### LLM
 ```
-adb shell am start -W -S -n org.pytorch.minibench/org.pytorch.minibench.LlmBenchmarkActivity \
+adb shell am start -W -S -n org.pytorch.minibench/org.pytorch.minibench.BenchmarkActivity \
   --es model_dir /data/local/tmp/minibench --es tokenizer_path /data/local/tmp/minibench/tokenizer.bin
 ```
AWS Device Farm test spec (file path not shown)

@@ -114,11 +114,11 @@ phases:
       adb -s $DEVICEFARM_DEVICE_UDID shell sleep 180
 
       if [ -n "$BIN_FOUND" ]; then
-        adb -s $DEVICEFARM_DEVICE_UDID shell am start -W -n org.pytorch.minibench/.LlmBenchmarkActivity \
+        adb -s $DEVICEFARM_DEVICE_UDID shell am start -W -n org.pytorch.minibench/.BenchmarkActivity \
           --es "model_dir" "/data/local/tmp/minibench" \
           --es "tokenizer_path" "/data/local/tmp/minibench/tokenizer.bin"
       elif [ -n "$MODEL_FOUND" ]; then
-        adb -s $DEVICEFARM_DEVICE_UDID shell am start -W -n org.pytorch.minibench/.LlmBenchmarkActivity \
+        adb -s $DEVICEFARM_DEVICE_UDID shell am start -W -n org.pytorch.minibench/.BenchmarkActivity \
          --es "model_dir" "/data/local/tmp/minibench" \
          --es "tokenizer_path" "/data/local/tmp/minibench/tokenizer.model"
       else
12 changes: 9 additions & 3 deletions extension/benchmark/android/benchmark/app/build.gradle.kts
@@ -6,7 +6,9 @@
  * LICENSE file in the root directory of this source tree.
  */
 
-plugins { id("com.android.application") }
+plugins { id("com.android.application")
+  id("org.jetbrains.kotlin.android")
+}
 
 android {
   namespace = "org.pytorch.minibench"
@@ -29,8 +31,11 @@ android {
     }
   }
   compileOptions {
-    sourceCompatibility = JavaVersion.VERSION_1_8
-    targetCompatibility = JavaVersion.VERSION_1_8
+    sourceCompatibility = JavaVersion.VERSION_17
+    targetCompatibility = JavaVersion.VERSION_17
   }
+  kotlinOptions {
+    jvmTarget = "17"
+  }
 }

@@ -40,6 +45,7 @@ dependencies {
   implementation("com.facebook.fbjni:fbjni:0.5.1")
   implementation("com.google.code.gson:gson:2.8.6")
   implementation("org.json:json:20250107")
+  implementation("androidx.core:core-ktx:1.13.1")
   testImplementation("junit:junit:4.13.2")
   androidTestImplementation("androidx.test.ext:junit:1.2.1")
   androidTestImplementation("androidx.test.espresso:espresso-core:3.6.1")
AndroidManifest.xml

@@ -21,14 +21,6 @@
             </intent-filter>
         </activity>
 
-        <activity
-            android:name=".LlmBenchmarkActivity"
-            android:exported="true">
-            <intent-filter>
-                <action android:name="org.pytorch.minibench.BENCHMARK" />
-            </intent-filter>
-        </activity>
-
     </application>
 
 </manifest>
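
With the dedicated LLM activity removed, BenchmarkActivity is the single entry point for both the generic and the LLM flow. Besides the adb commands in the README, the activity can also be launched from code. A minimal sketch, assuming the surviving BenchmarkActivity keeps the org.pytorch.minibench.BENCHMARK intent filter seen in the removed block; the helper class and its name are illustrative, not part of this PR:

```java
import android.content.Context;
import android.content.Intent;

public final class BenchmarkLauncher {
  /** Hypothetical helper: launches the unified benchmark activity. */
  public static void launch(Context context, String modelDir, String tokenizerPath) {
    Intent intent = new Intent("org.pytorch.minibench.BENCHMARK");
    // Target the one remaining activity explicitly.
    intent.setClassName("org.pytorch.minibench", "org.pytorch.minibench.BenchmarkActivity");
    intent.putExtra("model_dir", modelDir);
    if (tokenizerPath != null) {
      // A non-null tokenizer_path makes the handler chain into the LLM benchmark.
      intent.putExtra("tokenizer_path", tokenizerPath);
    }
    intent.addFlags(Intent.FLAG_ACTIVITY_NEW_TASK);
    context.startActivity(intent);
  }
}
```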
BenchmarkActivity.java

@@ -10,9 +10,10 @@

 import android.app.Activity;
 import android.content.Intent;
-import android.os.AsyncTask;
 import android.os.Bundle;
 import android.os.Debug;
+import android.os.Handler;
+import android.os.HandlerThread;
+import android.os.Looper;
 import android.system.ErrnoException;
 import android.system.Os;
 import com.google.gson.Gson;
@@ -21,12 +22,22 @@
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Arrays;
-import java.util.Collections;
 import java.util.List;
-import java.util.stream.Collectors;
 import org.pytorch.executorch.Module;
 
 public class BenchmarkActivity extends Activity {
 
+  File mModel;
+  int mNumIter;
+  int mNumWarmupIter;
+  String mTokenizerPath;
+  float mTemperature;
+  String mPrompt;
+
+  HandlerThread mHandlerThread;
+  BenchmarkHandler mHandler;
+
+  List<BenchmarkMetric> mResult;
+
   @Override
   protected void onCreate(Bundle savedInstanceState) {
     super.onCreate(savedInstanceState);
@@ -47,95 +58,79 @@ protected void onCreate(Bundle savedInstanceState) {
 
     int numIter = intent.getIntExtra("num_iter", 50);
     int numWarmupIter = intent.getIntExtra("num_warm_up_iter", 10);
+    String tokenizerPath = intent.getStringExtra("tokenizer_path");
+    float temperature = intent.getFloatExtra("temperature", 0.8f);
+    String prompt = intent.getStringExtra("prompt");
 
+    mModel = model;
+    mNumIter = numIter;
+    mNumWarmupIter = numWarmupIter;
+    mTokenizerPath = tokenizerPath;
+    mTemperature = temperature;
+    mPrompt = prompt;
+    if (mPrompt == null) {
+      mPrompt = "The ultimate answer";
+    }
+    mResult = new ArrayList<>();
 
-    long pssIdle = Debug.getPss();
+    mHandlerThread = new HandlerThread("ModelRunner");
+    mHandlerThread.start();
+    mHandler = new BenchmarkHandler(mHandlerThread.getLooper(), this);
 
-    // TODO: Format the string with a parsable format
-    Stats stats = new Stats();
+    mHandler.sendEmptyMessage(BenchmarkHandler.MESSAGE_RUN_BENCHMARK);
+  }
 
-    new AsyncTask<Void, Void, Void>() {
-      @Override
-      protected Void doInBackground(Void... voids) {
+  void writeResult() {
+    try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) {
+      Gson gson = new Gson();
+      writer.write(gson.toJson(mResult));
+    } catch (IOException e) {
+      e.printStackTrace();
+    } finally {
+      finish();
+    }
+  }
+}
 
-        // Record the time it takes to load the model and the forward method
-        stats.loadStart = System.nanoTime();
-        Module module = Module.load(model.getPath());
-        stats.errorCode = module.loadMethod("forward");
-        stats.loadEnd = System.nanoTime();
+class BenchmarkHandler extends Handler {
+  public static int MESSAGE_RUN_BENCHMARK = 1;
+  public static int MESSAGE_LLM_RUN_BENCHMARK = 2;
 
-        for (int i = 0; i < numWarmupIter; i++) {
-          module.forward();
-        }
+  ModelRunner mModelRunner;
+  BenchmarkActivity mBenchmarkActivity;
 
-        for (int i = 0; i < numIter; i++) {
-          long start = System.nanoTime();
-          module.forward();
-          double forwardMs = (System.nanoTime() - start) * 1e-6;
-          stats.latency.add(forwardMs);
-        }
-        return null;
-      }
+  LlmModelRunner mLlmModelRunner;
+  LlmBenchmark mLlmBenchmark;
 
-      @Override
-      protected void onPostExecute(Void aVoid) {
-
-        final BenchmarkMetric.BenchmarkModel benchmarkModel =
-            BenchmarkMetric.extractBackendAndQuantization(model.getName().replace(".pte", ""));
-        final List<BenchmarkMetric> results = new ArrayList<>();
-        // The list of metrics we have atm includes:
-        // Avg inference latency after N iterations
-        // Currently the result has large variance from outliers, so only use
-        // 80% samples in the middle (trimmean 0.2)
-        Collections.sort(stats.latency);
-        int resultSize = stats.latency.size();
-        List<Double> usedLatencyResults =
-            stats.latency.subList(resultSize / 10, resultSize * 9 / 10);
-
-        results.add(
-            new BenchmarkMetric(
-                benchmarkModel,
-                "avg_inference_latency(ms)",
-                stats.latency.stream().mapToDouble(l -> l).average().orElse(0.0f),
-                0.0f));
-        results.add(
-            new BenchmarkMetric(
-                benchmarkModel,
-                "trimmean_inference_latency(ms)",
-                usedLatencyResults.stream().mapToDouble(l -> l).average().orElse(0.0f),
-                0.0f));
-        // Model load time
-        results.add(
-            new BenchmarkMetric(
-                benchmarkModel,
-                "model_load_time(ms)",
-                (stats.loadEnd - stats.loadStart) * 1e-6,
-                0.0f));
-        // Load status
-        results.add(new BenchmarkMetric(benchmarkModel, "load_status", stats.errorCode, 0));
-        // RAM PSS usage
-        results.add(
-            new BenchmarkMetric(
-                benchmarkModel, "ram_pss_usage(mb)", (Debug.getPss() - pssIdle) / 1024, 0));
-
-        try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) {
-          Gson gson = new Gson();
-          writer.write(gson.toJson(results));
-        } catch (IOException e) {
-          e.printStackTrace();
-        }
-      }
-    }.execute();
+  public BenchmarkHandler(Looper looper, BenchmarkActivity benchmarkActivity) {
+    super(looper);
+    mModelRunner = new ModelRunner();
+    mBenchmarkActivity = benchmarkActivity;
   }
-}
 
-class Stats {
-  long loadStart;
-  long loadEnd;
-  List<Double> latency = new ArrayList<>();
-  int errorCode = 0;
-
   @Override
-  public String toString() {
-    return "latency: " + latency.stream().map(Object::toString).collect(Collectors.joining(""));
+  public void handleMessage(android.os.Message msg) {
+    if (msg.what == MESSAGE_RUN_BENCHMARK) {
+      mModelRunner.runBenchmark(
+          mBenchmarkActivity.mModel,
+          mBenchmarkActivity.mNumWarmupIter,
+          mBenchmarkActivity.mNumIter,
+          mBenchmarkActivity.mResult);
+
+      if (mBenchmarkActivity.mTokenizerPath == null) {
+        mBenchmarkActivity.writeResult();
+      } else {
+        this.sendEmptyMessage(MESSAGE_LLM_RUN_BENCHMARK);
+      }
+    } else if (msg.what == MESSAGE_LLM_RUN_BENCHMARK) {
+      mLlmBenchmark =
+          new LlmBenchmark(
+              mBenchmarkActivity,
+              mBenchmarkActivity.mModel.getPath(),
+              mBenchmarkActivity.mTokenizerPath,
+              mBenchmarkActivity.mPrompt,
+              mBenchmarkActivity.mTemperature,
+              mBenchmarkActivity.mResult);
+    }
   }
 }
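
The measurement logic itself leaves this file: the load/forward timing, trimmean aggregation, and metric assembly move into ModelRunner, which now executes on the HandlerThread instead of an AsyncTask, so the main thread never blocks. ModelRunner is not shown in this diff; below is a plausible sketch of its runBenchmark, reconstructed from the code removed above (the signature and details may differ from the actual class):

```java
import android.os.Debug;
import java.io.File;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.pytorch.executorch.Module;

class ModelRunner {
  // Sketch reconstructed from the logic removed from BenchmarkActivity;
  // the real ModelRunner in this PR may differ.
  void runBenchmark(File model, int numWarmupIter, int numIter, List<BenchmarkMetric> results) {
    long pssIdle = Debug.getPss();

    // Load the model and its forward method, timing both together.
    long loadStart = System.nanoTime();
    Module module = Module.load(model.getPath());
    int errorCode = module.loadMethod("forward");
    long loadEnd = System.nanoTime();

    for (int i = 0; i < numWarmupIter; i++) {
      module.forward();
    }

    List<Double> latency = new ArrayList<>();
    for (int i = 0; i < numIter; i++) {
      long start = System.nanoTime();
      module.forward();
      latency.add((System.nanoTime() - start) * 1e-6);
    }

    // Drop the top and bottom 10% of samples to damp outliers (trimmean 0.2).
    Collections.sort(latency);
    List<Double> trimmed = latency.subList(latency.size() / 10, latency.size() * 9 / 10);

    BenchmarkMetric.BenchmarkModel benchmarkModel =
        BenchmarkMetric.extractBackendAndQuantization(model.getName().replace(".pte", ""));
    results.add(
        new BenchmarkMetric(
            benchmarkModel,
            "avg_inference_latency(ms)",
            latency.stream().mapToDouble(l -> l).average().orElse(0.0),
            0.0f));
    results.add(
        new BenchmarkMetric(
            benchmarkModel,
            "trimmean_inference_latency(ms)",
            trimmed.stream().mapToDouble(l -> l).average().orElse(0.0),
            0.0f));
    results.add(
        new BenchmarkMetric(
            benchmarkModel, "model_load_time(ms)", (loadEnd - loadStart) * 1e-6, 0.0f));
    results.add(new BenchmarkMetric(benchmarkModel, "load_status", errorCode, 0));
    results.add(
        new BenchmarkMetric(
            benchmarkModel, "ram_pss_usage(mb)", (Debug.getPss() - pssIdle) / 1024, 0));
  }
}
```

Because the handler drives this call, a blocking loop is fine here; when it returns, the handler either writes results and finishes, or chains into the LLM pass when a tokenizer_path was supplied.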