Skip to content

Commit 9c79934

Browse files
authored
Simplify the setup for AnchorGenerator in unittest (#3023)
1 parent 8c28175 commit 9c79934

File tree

1 file changed

+22
-16
lines changed

1 file changed

+22
-16
lines changed
Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
from collections import OrderedDict
22
import torch
3-
import unittest
3+
from common_utils import TestCase
44
from torchvision.models.detection.anchor_utils import AnchorGenerator
55
from torchvision.models.detection.image_list import ImageList
66

77

8-
class Tester(unittest.TestCase):
8+
class Tester(TestCase):
99
def test_incorrect_anchors(self):
1010
incorrect_sizes = ((2, 4, 8), (32, 8), )
1111
incorrect_aspects = (0.5, 1.0)
@@ -16,40 +16,46 @@ def test_incorrect_anchors(self):
1616
self.assertRaises(ValueError, anc, image_list, feature_maps)
1717

1818
def _init_test_anchor_generator(self):
19-
anchor_sizes = tuple((x,) for x in [32, 64, 128])
20-
aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
19+
anchor_sizes = ((10,),)
20+
aspect_ratios = ((1,),)
2121
anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)
2222

2323
return anchor_generator
2424

2525
def get_features(self, images):
2626
s0, s1 = images.shape[-2:]
27-
features = [
28-
('0', torch.rand(2, 8, s0 // 4, s1 // 4)),
29-
('1', torch.rand(2, 16, s0 // 8, s1 // 8)),
30-
('2', torch.rand(2, 32, s0 // 16, s1 // 16)),
31-
]
32-
features = OrderedDict(features)
27+
features = [torch.rand(2, 8, s0 // 5, s1 // 5)]
3328
return features
3429

3530
def test_anchor_generator(self):
36-
images = torch.randn(2, 3, 16, 32)
31+
images = torch.randn(2, 3, 15, 15)
3732
features = self.get_features(images)
38-
features = list(features.values())
3933
image_shapes = [i.shape[-2:] for i in images]
4034
images = ImageList(images, image_shapes)
4135

4236
model = self._init_test_anchor_generator()
4337
model.eval()
4438
anchors = model(images, features)
4539

46-
# Compute target anchors numbers
40+
# Estimate the number of target anchors
4741
grid_sizes = [f.shape[-2:] for f in features]
4842
num_anchors_estimated = 0
4943
for sizes, num_anchors_per_loc in zip(grid_sizes, model.num_anchors_per_location()):
5044
num_anchors_estimated += sizes[0] * sizes[1] * num_anchors_per_loc
5145

52-
self.assertEqual(num_anchors_estimated, 126)
46+
anchors_output = torch.tensor([[-5., -5., 5., 5.],
47+
[0., -5., 10., 5.],
48+
[5., -5., 15., 5.],
49+
[-5., 0., 5., 10.],
50+
[0., 0., 10., 10.],
51+
[5., 0., 15., 10.],
52+
[-5., 5., 5., 15.],
53+
[0., 5., 10., 15.],
54+
[5., 5., 15., 15.]])
55+
56+
self.assertEqual(num_anchors_estimated, 9)
5357
self.assertEqual(len(anchors), 2)
54-
self.assertEqual(tuple(anchors[0].shape), (num_anchors_estimated, 4))
55-
self.assertEqual(tuple(anchors[1].shape), (num_anchors_estimated, 4))
58+
self.assertEqual(tuple(anchors[0].shape), (9, 4))
59+
self.assertEqual(tuple(anchors[1].shape), (9, 4))
60+
self.assertEqual(anchors[0], anchors_output)
61+
self.assertEqual(anchors[1], anchors_output)

0 commit comments

Comments
 (0)