Skip to content

Commit 1e3bc2c

Browse files
authored
获取模型信息
1 parent f81a687 commit 1e3bc2c

File tree

2 files changed

+58
-35
lines changed

2 files changed

+58
-35
lines changed

bobotools/common/tools.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import torch
2+
import time
3+
import os
4+
cur_path = os.path.abspath(os.path.dirname(__file__))
5+
6+
7+
def get_model_size(model):
8+
'''
9+
获取模型大小(MB)
10+
'''
11+
model_path = os.path.join(cur_path, str(time.time())+"_temp.pt")
12+
torch.save(model, model_path)
13+
model_size = os.path.getsize(model_path) / float(1024 * 1024)
14+
os.remove(model_path)
15+
return {"model_size(MB)":round(model_size, 2)}
16+
17+
@torch.no_grad()
18+
def get_model_time(input_shape,model,warmup_nums=100, iter_nums=300):
19+
'''
20+
获取模型前向耗时
21+
'''
22+
time_dict={}
23+
24+
img = torch.ones(input_shape)
25+
device_list = ["cuda:0","cpu"] if torch.cuda.is_available() else ["cpu"]
26+
for device in device_list:
27+
img = img.to(device)
28+
model.to(device)
29+
30+
# 预热
31+
for _ in range(warmup_nums):
32+
model(img)
33+
if "cuda" in device:
34+
torch.cuda.synchronize()
35+
# 正式
36+
start = time.time()
37+
for _ in range(iter_nums):
38+
model(img)
39+
# 每次推理,均同步一次。算均值
40+
if "cuda" in device:
41+
torch.cuda.synchronize()
42+
end = time.time()
43+
total_time = ((end - start) * 1000) / float(iter_nums)
44+
time_dict[device+"(ms)"]=round(total_time, 2)
45+
return time_dict

bobotools/torch_tools.py

Lines changed: 13 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
import torch
2-
import time
3-
from tqdm import tqdm
1+
from .common.tools import get_model_size,get_model_time
42
class Torch_Tools(object):
53
"""
64
Pytorch操作
@@ -10,41 +8,21 @@ def __init__(self):
108
pass
119

1210
@staticmethod
13-
def cal_model_time(input_shape, model, warmup_nums=100, iter_nums=300):
11+
def get_model_info(input_shape, model):
1412
"""
15-
统计 模型前向耗时
13+
获取模型信息,包括模型大小、前向推理耗时等
1614
1715
input_shape: 输入形状 eg:[1,3,224,224]
1816
model: 模型
19-
warmup_nums: 预热次数
20-
iter_nums: 总迭代次数,计算平均耗时
2117
"""
22-
img = torch.ones(input_shape)
18+
result_dict = {"input_shape": input_shape}
2319

24-
flag = model.training # 记录模型是训练模式或评估模式
25-
model.eval()
26-
27-
device_list = ["cpu", "cuda:0"] if torch.cuda.is_available() else ["cpu"]
28-
time_dict = {"input_shape": input_shape}
29-
30-
for device in device_list:
31-
img = img.to(device)
32-
model.to(device)
33-
34-
# 预热
35-
for _ in range(warmup_nums):
36-
model(img)
37-
if "cuda" in device:
38-
torch.cuda.synchronize()
39-
# 正式
40-
start = time.time()
41-
for _ in tqdm(range(iter_nums)):
42-
model(img)
43-
# 每次推理,均同步一次。算均值
44-
if "cuda" in device:
45-
torch.cuda.synchronize()
46-
end = time.time()
47-
total_time = ((end - start) * 1000) / float(iter_nums)
48-
time_dict[device] = "%.2f ms/img" % total_time
49-
model.training = flag # 模型恢复为原状态
50-
return time_dict
20+
# 获取模型大小
21+
size_dict=get_model_size(model)
22+
result_dict.update(size_dict)
23+
24+
# 前向推理耗时
25+
time_dict=get_model_time(input_shape, model)
26+
result_dict.update(time_dict)
27+
28+
return result_dict

0 commit comments

Comments
 (0)