1- import torch
2- import time
3- from tqdm import tqdm
1+ from .common .tools import get_model_size ,get_model_time
42class Torch_Tools (object ):
53 """
64 Pytorch操作
@@ -10,41 +8,21 @@ def __init__(self):
108 pass
119
1210 @staticmethod
13- def cal_model_time (input_shape , model , warmup_nums = 100 , iter_nums = 300 ):
11+ def get_model_info (input_shape , model ):
1412 """
15- 统计 模型前向耗时
13+ 获取模型信息,包括模型大小、前向推理耗时等
1614
1715 input_shape: 输入形状 eg:[1,3,224,224]
1816 model: 模型
19- warmup_nums: 预热次数
20- iter_nums: 总迭代次数,计算平均耗时
2117 """
22- img = torch . ones ( input_shape )
18+ result_dict = { " input_shape" : input_shape }
2319
24- flag = model .training # 记录模型是训练模式或评估模式
25- model .eval ()
26-
27- device_list = ["cpu" , "cuda:0" ] if torch .cuda .is_available () else ["cpu" ]
28- time_dict = {"input_shape" : input_shape }
29-
30- for device in device_list :
31- img = img .to (device )
32- model .to (device )
33-
34- # 预热
35- for _ in range (warmup_nums ):
36- model (img )
37- if "cuda" in device :
38- torch .cuda .synchronize ()
39- # 正式
40- start = time .time ()
41- for _ in tqdm (range (iter_nums )):
42- model (img )
43- # 每次推理,均同步一次。算均值
44- if "cuda" in device :
45- torch .cuda .synchronize ()
46- end = time .time ()
47- total_time = ((end - start ) * 1000 ) / float (iter_nums )
48- time_dict [device ] = "%.2f ms/img" % total_time
49- model .training = flag # 模型恢复为原状态
50- return time_dict
20+ # 获取模型大小
21+ size_dict = get_model_size (model )
22+ result_dict .update (size_dict )
23+
24+ # 前向推理耗时
25+ time_dict = get_model_time (input_shape , model )
26+ result_dict .update (time_dict )
27+
28+ return result_dict
0 commit comments