1
1
from collections import OrderedDict
2
2
import torch
3
- import unittest
3
+ from common_utils import TestCase
4
4
from torchvision .models .detection .anchor_utils import AnchorGenerator
5
5
from torchvision .models .detection .image_list import ImageList
6
6
7
7
8
- class Tester (unittest . TestCase ):
8
+ class Tester (TestCase ):
9
9
def test_incorrect_anchors (self ):
10
10
incorrect_sizes = ((2 , 4 , 8 ), (32 , 8 ), )
11
11
incorrect_aspects = (0.5 , 1.0 )
@@ -16,40 +16,46 @@ def test_incorrect_anchors(self):
16
16
self .assertRaises (ValueError , anc , image_list , feature_maps )
17
17
18
18
def _init_test_anchor_generator (self ):
19
- anchor_sizes = tuple (( x ,) for x in [ 32 , 64 , 128 ] )
20
- aspect_ratios = ((0.5 , 1.0 , 2.0 ),) * len ( anchor_sizes )
19
+ anchor_sizes = (( 10 ,), )
20
+ aspect_ratios = ((1 ,), )
21
21
anchor_generator = AnchorGenerator (anchor_sizes , aspect_ratios )
22
22
23
23
return anchor_generator
24
24
25
25
def get_features (self , images ):
26
26
s0 , s1 = images .shape [- 2 :]
27
- features = [
28
- ('0' , torch .rand (2 , 8 , s0 // 4 , s1 // 4 )),
29
- ('1' , torch .rand (2 , 16 , s0 // 8 , s1 // 8 )),
30
- ('2' , torch .rand (2 , 32 , s0 // 16 , s1 // 16 )),
31
- ]
32
- features = OrderedDict (features )
27
+ features = [torch .rand (2 , 8 , s0 // 5 , s1 // 5 )]
33
28
return features
34
29
35
30
def test_anchor_generator (self ):
36
- images = torch .randn (2 , 3 , 16 , 32 )
31
+ images = torch .randn (2 , 3 , 15 , 15 )
37
32
features = self .get_features (images )
38
- features = list (features .values ())
39
33
image_shapes = [i .shape [- 2 :] for i in images ]
40
34
images = ImageList (images , image_shapes )
41
35
42
36
model = self ._init_test_anchor_generator ()
43
37
model .eval ()
44
38
anchors = model (images , features )
45
39
46
- # Compute target anchors numbers
40
+ # Estimate the number of target anchors
47
41
grid_sizes = [f .shape [- 2 :] for f in features ]
48
42
num_anchors_estimated = 0
49
43
for sizes , num_anchors_per_loc in zip (grid_sizes , model .num_anchors_per_location ()):
50
44
num_anchors_estimated += sizes [0 ] * sizes [1 ] * num_anchors_per_loc
51
45
52
- self .assertEqual (num_anchors_estimated , 126 )
46
+ anchors_output = torch .tensor ([[- 5. , - 5. , 5. , 5. ],
47
+ [0. , - 5. , 10. , 5. ],
48
+ [5. , - 5. , 15. , 5. ],
49
+ [- 5. , 0. , 5. , 10. ],
50
+ [0. , 0. , 10. , 10. ],
51
+ [5. , 0. , 15. , 10. ],
52
+ [- 5. , 5. , 5. , 15. ],
53
+ [0. , 5. , 10. , 15. ],
54
+ [5. , 5. , 15. , 15. ]])
55
+
56
+ self .assertEqual (num_anchors_estimated , 9 )
53
57
self .assertEqual (len (anchors ), 2 )
54
- self .assertEqual (tuple (anchors [0 ].shape ), (num_anchors_estimated , 4 ))
55
- self .assertEqual (tuple (anchors [1 ].shape ), (num_anchors_estimated , 4 ))
58
+ self .assertEqual (tuple (anchors [0 ].shape ), (9 , 4 ))
59
+ self .assertEqual (tuple (anchors [1 ].shape ), (9 , 4 ))
60
+ self .assertEqual (anchors [0 ], anchors_output )
61
+ self .assertEqual (anchors [1 ], anchors_output )
0 commit comments