@@ -78,7 +78,7 @@ inline int measure_perf_aggregate(timer::timer_t &t,
7878 std::vector<perf_function_t > &perf_func_v,
7979 const std::vector<std::vector<dnnl::graph::tensor>> &inputs_v,
8080 const std::vector<std::vector<dnnl::graph::tensor>> &outputs_v) {
81- const int max_batch_times = 10000 ;
81+ const int max_batch_times = 4096 ;
8282 // Nvidia/AMD don't support profiling.
8383 const bool use_profiling = is_gpu () && !is_nvidia_gpu () && !is_amd_gpu ();
8484 const dnnl::stream::flags flags = use_profiling
@@ -101,11 +101,12 @@ inline int measure_perf_aggregate(timer::timer_t &t,
101101 reset_gpu_profiling (((dnnl::stream)stream).get ());
102102
103103 bool is_first_loop = true ;
104+ size_t prim_num = 1 ;
104105 while (true ) {
105- for_ (size_t i = 0 ; i < sz ; i++)
106- for (int j = 0 ; j < cur_batch_times ; j++) {
106+ for_ (int i = 0 ; i < cur_batch_times ; i++)
107+ for (size_t j = 0 ; j < sz ; j++) {
107108 DNN_GRAPH_SAFE (
108- perf_func_v[i ](stream, inputs_v[i ], outputs_v[i ]), WARN);
109+ perf_func_v[j ](stream, inputs_v[j ], outputs_v[j ]), WARN);
109110 }
110111 DNN_GRAPH_SAFE (stream.wait (), WARN);
111112
@@ -115,11 +116,23 @@ inline int measure_perf_aggregate(timer::timer_t &t,
115116 get_gpu_profiling_info (((dnnl::stream)stream).get (), nsecs, cycles);
116117 reset_gpu_profiling (((dnnl::stream)stream).get ());
117118
118- // Profiling should have information to stop the cycle.
119- if (nsecs.empty ()) SAFE (FAIL, WARN);
120-
121- for (size_t i = 0 ; i < nsecs.size (); i++) {
122- t.stop (1 , (int64_t )cycles[i], nsecs[i] / 1e6 );
119+ // Profiling should have information to report, otherwise, stop.
120+ if (nsecs.empty ()) {
121+ BENCHDNN_PRINT (0 , " %s\n " ,
122+ " WARNING: no counters were found during profiling." );
123+ break ;
124+ }
125+ // Calculate the number of primitives in a batch
126+ if (is_first_loop) { prim_num = nsecs.size () / cur_batch_times; }
127+
128+ for (int i = 0 ; i < cur_batch_times; i++) {
129+ int64_t cycles_res = 0 ;
130+ double nsecs_res = 0 ;
131+ for (size_t j = 0 ; j < prim_num; j++) {
132+ cycles_res += cycles[i * prim_num + j];
133+ nsecs_res += nsecs[i * prim_num + j];
134+ }
135+ t.stop (1 , cycles_res, nsecs_res / 1e6 );
123136 }
124137 } else {
125138 t.stamp (cur_batch_times);
0 commit comments