diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java
index c0856f3e4fe..4b2ba56099e 100644
--- a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java
+++ b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java
@@ -20,6 +20,7 @@
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collections;
 import java.util.List;
 import java.util.stream.Collectors;
 import org.pytorch.executorch.Module;
@@ -80,11 +81,18 @@ protected void onPostExecute(Void aVoid) {
     final List<BenchmarkMetric> results = new ArrayList<>();
     // The list of metrics we have atm includes:
     // Avg inference latency after N iterations
+    // Currently the result has large variance from outliers, so only use the
+    // middle 80% of samples (trimmean 0.2)
+    Collections.sort(stats.latency);
+    int resultSize = stats.latency.size();
+    List<Double> usedLatencyResults =
+        stats.latency.subList(resultSize / 10, resultSize * 9 / 10);
+
     results.add(
         new BenchmarkMetric(
             benchmarkModel,
             "avg_inference_latency(ms)",
-            stats.latency.stream().mapToDouble(l -> l).average().orElse(0.0f),
+            usedLatencyResults.stream().mapToDouble(l -> l).average().orElse(0.0f),
             0.0f));
     // Model load time
     results.add(
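
For reference, the patch computes a 0.2 trimmed mean: the samples are sorted, the lowest 10% and highest 10% are dropped, and the remaining middle 80% are averaged. Below is a minimal standalone sketch of that same calculation (TrimmedMeanExample is a hypothetical illustration class, not part of this change or of the minibench app).

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

// Illustrative only: mirrors the trimmean-0.2 logic added in the patch.
public class TrimmedMeanExample {
  // Drop the lowest 10% and highest 10% of samples, then average the rest.
  static double trimmedMean(List<Double> samples) {
    List<Double> sorted = new ArrayList<>(samples);
    Collections.sort(sorted);
    int n = sorted.size();
    List<Double> middle = sorted.subList(n / 10, n * 9 / 10);
    return middle.stream().mapToDouble(d -> d).average().orElse(0.0);
  }

  public static void main(String[] args) {
    // 10 latency samples (ms); the 5.0 and 500.0 outliers get trimmed away.
    List<Double> latency =
        Arrays.asList(5.0, 98.0, 99.0, 100.0, 100.0, 101.0, 101.0, 102.0, 103.0, 500.0);
    System.out.println(trimmedMean(latency)); // prints 100.5
  }
}

With a plain average the two outliers above would pull the result to roughly 130.9 ms; trimming keeps the reported latency close to the typical 100 ms iteration, which is the variance reduction this change is after.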