Skip to content

Commit 9dfe343

Browse files
chuanqi129 authored and vpirogov committed
tests: benchdnn: graph: fix gpu engine performance measurement
1 parent 8d796ef commit 9dfe343

File tree

1 file changed

+22
-9
lines changed

1 file changed

+22
-9
lines changed

tests/benchdnn/graph/utils.cpp

Lines changed: 22 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -78,7 +78,7 @@ inline int measure_perf_aggregate(timer::timer_t &t,
7878
std::vector<perf_function_t> &perf_func_v,
7979
const std::vector<std::vector<dnnl::graph::tensor>> &inputs_v,
8080
const std::vector<std::vector<dnnl::graph::tensor>> &outputs_v) {
81-
const int max_batch_times = 10000;
81+
const int max_batch_times = 4096;
8282
// Nvidia/AMD don't support profiling.
8383
const bool use_profiling = is_gpu() && !is_nvidia_gpu() && !is_amd_gpu();
8484
const dnnl::stream::flags flags = use_profiling
@@ -101,11 +101,12 @@ inline int measure_perf_aggregate(timer::timer_t &t,
101101
reset_gpu_profiling(((dnnl::stream)stream).get());
102102

103103
bool is_first_loop = true;
104+
size_t prim_num = 1;
104105
while (true) {
105-
for_(size_t i = 0; i < sz; i++)
106-
for (int j = 0; j < cur_batch_times; j++) {
106+
for_(int i = 0; i < cur_batch_times; i++)
107+
for (size_t j = 0; j < sz; j++) {
107108
DNN_GRAPH_SAFE(
108-
perf_func_v[i](stream, inputs_v[i], outputs_v[i]), WARN);
109+
perf_func_v[j](stream, inputs_v[j], outputs_v[j]), WARN);
109110
}
110111
DNN_GRAPH_SAFE(stream.wait(), WARN);
111112

@@ -115,11 +116,23 @@ inline int measure_perf_aggregate(timer::timer_t &t,
115116
get_gpu_profiling_info(((dnnl::stream)stream).get(), nsecs, cycles);
116117
reset_gpu_profiling(((dnnl::stream)stream).get());
117118

118-
// Profiling should have information to stop the cycle.
119-
if (nsecs.empty()) SAFE(FAIL, WARN);
120-
121-
for (size_t i = 0; i < nsecs.size(); i++) {
122-
t.stop(1, (int64_t)cycles[i], nsecs[i] / 1e6);
119+
// Profiling should have information to report, otherwise, stop.
120+
if (nsecs.empty()) {
121+
BENCHDNN_PRINT(0, "%s\n",
122+
"WARNING: no counters were found during profiling.");
123+
break;
124+
}
125+
// Calculate the number of primitives in a batch
126+
if (is_first_loop) { prim_num = nsecs.size() / cur_batch_times; }
127+
128+
for (int i = 0; i < cur_batch_times; i++) {
129+
int64_t cycles_res = 0;
130+
double nsecs_res = 0;
131+
for (size_t j = 0; j < prim_num; j++) {
132+
cycles_res += cycles[i * prim_num + j];
133+
nsecs_res += nsecs[i * prim_num + j];
134+
}
135+
t.stop(1, cycles_res, nsecs_res / 1e6);
123136
}
124137
} else {
125138
t.stamp(cur_batch_times);

0 commit comments

Comments
 (0)