-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathgradcam.py
58 lines (50 loc) · 2.2 KB
/
gradcam.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import cv2
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
class GradCam:
    """Grad-CAM visualizer: builds class-activation heat maps for a CNN.

    NOTE(review): assumes the wrapped model exposes child modules named
    'conv1', 'features', and 'classifier' via named_children() -- the model
    definition is not visible in this file; confirm against it.
    """

    def __init__(self, model):
        # Switch the model to eval mode; feature/gradient are populated
        # per-call by the manual forward pass and the backward hook.
        self.model = model.eval()
        self.feature = None   # activations captured from the 'features' module
        self.gradient = None  # gradient w.r.t. those activations (set by hook)

    def save_gradient(self, grad):
        # Backward hook: store the gradient flowing back through the
        # hooked 'features' activations for use as Grad-CAM weights.
        self.gradient = grad

    def __call__(self, images):
        # images: batch tensor, presumably (N, C, H, W) -- TODO confirm.
        # NOTE(review): one_hot.backward() below implies images must
        # require grad (or the model params do) for the hook to fire.
        image_size = (images.size(-1), images.size(-2))  # (W, H) order for cv2.resize
        for name, module in self.model.named_children():
            if name == 'conv1':
                out = module(images)
                # Channel-mean of the first conv's output: a coarse low-level map.
                conv1_heat_maps = out.mean(dim=1, keepdim=True)
        features_heat_maps = []
        for i in range(images.size(0)):
            # Normalize the input image to [0, 1] for overlaying later.
            img = images[i].detach().cpu().numpy()
            img = img - np.min(img)
            if np.max(img) != 0:
                img = img / np.max(img)
            feature = images[i].unsqueeze(0)
            # Manual layer-by-layer forward pass so the 'features' output
            # can be hooked before the classifier consumes it.
            for name, module in self.model.named_children():
                if name == 'classifier':
                    # Flatten spatial activations before the FC head.
                    feature = feature.view(feature.size(0), -1)
                feature = module(feature)
                if name == 'features':
                    feature.register_hook(self.save_gradient)
                    self.feature = feature
            classes = torch.sigmoid(feature)
            # Backpropagate from the highest-scoring class only.
            one_hot, _ = classes.max(dim=-1)
            self.model.zero_grad()
            one_hot.backward()
            # Grad-CAM channel weights: spatial average of the gradients.
            weight = self.gradient.mean(dim=-1, keepdim=True).mean(dim=-2, keepdim=True)
            # Weighted sum over channels, ReLU to keep positive influence.
            mask = F.relu((weight * self.feature).sum(dim=1)).squeeze(0)
            mask = cv2.resize(mask.detach().cpu().numpy(), image_size)
            mask = mask - np.min(mask)
            if np.max(mask) != 0:
                mask = mask / np.max(mask)
            # Colorize the mask (OpenCV gives BGR) and overlay on the image.
            heat_map = np.float32(cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET))
            cam = heat_map + np.float32((np.uint8(img.transpose((1, 2, 0)) * 255)))
            cam = cam - np.min(cam)
            if np.max(cam) != 0:
                cam = cam / np.max(cam)
            # Convert BGR -> RGB and back to a tensor for stacking.
            features_heat_maps.append(transforms.ToTensor()(cv2.cvtColor(np.uint8(255 * cam), cv2.COLOR_BGR2RGB)))
        features_heat_maps = torch.stack(features_heat_maps)
        return conv1_heat_maps, features_heat_maps