Skip to content

Commit 24999db

Browse files
author
Mr.Li
committed
模型统计:不改变模型原状态
1 parent b528967 commit 24999db

File tree

2 files changed

+10
-5
lines changed

2 files changed

+10
-5
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# bobotools
2-
2+
[![OSCS Status](https://www.oscs1024.com/platform/badge/bobo0810/bobotools.svg?size=small)](https://www.oscs1024.com/project/bobo0810/bobotools?ref=badge_small)
33
- 收录到[PytorchNetHub](https://github.com/bobo0810/PytorchNetHub)
44

55
## 安装

bobotools/com.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def get_model_size(model):
1313
torch.save(model, model_path)
1414
model_size = os.path.getsize(model_path) / float(1024 * 1024)
1515
os.remove(model_path)
16-
return {"model_size":str(round(model_size, 2))+"MB"}
16+
return {"model_size":str(round(model_size, 2))+" MB"}
1717

1818
def get_model_complexity(input_shape,model):
1919
'''
@@ -32,7 +32,10 @@ def get_model_time(input_shape,model,warmup_nums=100, iter_nums=300):
3232

3333
img = torch.ones(input_shape)
3434

35-
flag = model.training # 记录模型是训练模式或评估模式
35+
# 记录原模型状态
36+
ori_flag = model.training
37+
ori_device= next(model.parameters()).device
38+
3639
model.eval()
3740

3841
device_list = ["cuda:0","cpu"] if torch.cuda.is_available() else ["cpu"]
@@ -54,6 +57,8 @@ def get_model_time(input_shape,model,warmup_nums=100, iter_nums=300):
5457
torch.cuda.synchronize()
5558
end = time.time()
5659
total_time = ((end - start) * 1000) / float(iter_nums)
57-
time_dict[device]=str(round(total_time, 2))+"ms"
58-
model.training = flag # 模型恢复为原状态
60+
time_dict[device]=str(round(total_time, 2))+" ms"
61+
# 恢复到原状态
62+
model.training = ori_flag
63+
model.to(ori_device)
5964
return time_dict

0 commit comments

Comments
 (0)