@@ -449,6 +449,18 @@ def test_autocast(self):
449
449
with torch .cuda .amp .autocast ():
450
450
self .test_nms_cuda (dtype = dtype )
451
451
452
+ @unittest .skipIf (not torch .cuda .is_available (), "CUDA unavailable" )
453
+ def test_nms_cuda_float16 (self ):
454
+ boxes = torch .tensor ([[285.3538 , 185.5758 , 1193.5110 , 851.4551 ],
455
+ [285.1472 , 188.7374 , 1192.4984 , 851.0669 ],
456
+ [279.2440 , 197.9812 , 1189.4746 , 849.2019 ]]).cuda ()
457
+ scores = torch .tensor ([0.6370 , 0.7569 , 0.3966 ]).cuda ()
458
+
459
+ iou_thres = 0.2
460
+ keep32 = ops .nms (boxes , scores , iou_thres )
461
+ keep16 = ops .nms (boxes .to (torch .float16 ), scores .to (torch .float16 ), iou_thres )
462
+ self .assertTrue (torch .all (torch .eq (keep32 , keep16 )))
463
+
452
464
453
465
class DeformConvTester (OpTester , unittest .TestCase ):
454
466
def expected_fn (self , x , weight , offset , mask , bias , stride = 1 , padding = 0 , dilation = 1 ):
@@ -829,48 +841,75 @@ def test_bbox_convert_jit(self):
829
841
830
842
class BoxAreaTester (unittest .TestCase ):
831
843
def test_box_area (self ):
832
- # A bounding box of area 10000 and a degenerate case
833
- box_tensor = torch .tensor ([[0 , 0 , 100 , 100 ], [0 , 0 , 0 , 0 ]], dtype = torch .float )
834
- expected = torch .tensor ([10000 , 0 ])
835
- calc_area = ops .box_area (box_tensor )
836
- assert calc_area .size () == torch .Size ([2 ])
837
- assert calc_area .dtype == box_tensor .dtype
838
- assert torch .all (torch .eq (calc_area , expected )).item () is True
844
+ def area_check (box , expected , tolerance = 1e-4 ):
845
+ out = ops .box_area (box )
846
+ assert out .size () == expected .size ()
847
+ assert ((out - expected ).abs ().max () < tolerance ).item ()
848
+
849
+ # Check for int boxes
850
+ for dtype in [torch .int8 , torch .int16 , torch .int32 , torch .int64 ]:
851
+ box_tensor = torch .tensor ([[0 , 0 , 100 , 100 ], [0 , 0 , 0 , 0 ]], dtype = dtype )
852
+ expected = torch .tensor ([10000 , 0 ])
853
+ area_check (box_tensor , expected )
854
+
855
+ # Check for float32 and float64 boxes
856
+ for dtype in [torch .float32 , torch .float64 ]:
857
+ box_tensor = torch .tensor ([[285.3538 , 185.5758 , 1193.5110 , 851.4551 ],
858
+ [285.1472 , 188.7374 , 1192.4984 , 851.0669 ],
859
+ [279.2440 , 197.9812 , 1189.4746 , 849.2019 ]], dtype = dtype )
860
+ expected = torch .tensor ([604723.0806 , 600965.4666 , 592761.0085 ], dtype = torch .float64 )
861
+ area_check (box_tensor , expected , tolerance = 0.05 )
862
+
863
+ # Check for float16 box
864
+ box_tensor = torch .tensor ([[285.25 , 185.625 , 1194.0 , 851.5 ],
865
+ [285.25 , 188.75 , 1192.0 , 851.0 ],
866
+ [279.25 , 198.0 , 1189.0 , 849.0 ]], dtype = torch .float16 )
867
+ expected = torch .tensor ([605113.875 , 600495.1875 , 592247.25 ])
868
+ area_check (box_tensor , expected )
839
869
840
870
841
871
class BoxIouTester (unittest .TestCase ):
842
872
def test_iou (self ):
843
- # Boxes to test Iou
844
- boxes1 = torch .tensor ([[0 , 0 , 100 , 100 ], [0 , 0 , 50 , 50 ], [200 , 200 , 300 , 300 ]], dtype = torch .float )
845
- boxes2 = torch .tensor ([[0 , 0 , 100 , 100 ], [0 , 0 , 50 , 50 ], [200 , 200 , 300 , 300 ]], dtype = torch .float )
846
-
847
- # Expected IoU matrix for these boxes
848
- expected = torch .tensor ([[1.0 , 0.25 , 0.0 ], [0.25 , 1.0 , 0.0 ], [0.0 , 0.0 , 1.0 ]])
849
-
850
- out = ops .box_iou (boxes1 , boxes2 )
851
-
852
- # Check if all elements of tensor are as expected.
853
- assert out .size () == torch .Size ([3 , 3 ])
854
- tolerance = 1e-4
855
- assert ((out - expected ).abs ().max () < tolerance ).item () is True
873
+ def iou_check (box , expected , tolerance = 1e-4 ):
874
+ out = ops .box_iou (box , box )
875
+ assert out .size () == expected .size ()
876
+ assert ((out - expected ).abs ().max () < tolerance ).item ()
877
+
878
+ # Check for int boxes
879
+ for dtype in [torch .int16 , torch .int32 , torch .int64 ]:
880
+ box = torch .tensor ([[0 , 0 , 100 , 100 ], [0 , 0 , 50 , 50 ], [200 , 200 , 300 , 300 ]], dtype = dtype )
881
+ expected = torch .tensor ([[1.0 , 0.25 , 0.0 ], [0.25 , 1.0 , 0.0 ], [0.0 , 0.0 , 1.0 ]])
882
+ iou_check (box , expected )
883
+
884
+ # Check for float boxes
885
+ for dtype in [torch .float16 , torch .float32 , torch .float64 ]:
886
+ box_tensor = torch .tensor ([[285.3538 , 185.5758 , 1193.5110 , 851.4551 ],
887
+ [285.1472 , 188.7374 , 1192.4984 , 851.0669 ],
888
+ [279.2440 , 197.9812 , 1189.4746 , 849.2019 ]], dtype = dtype )
889
+ expected = torch .tensor ([[1.0 , 0.9933 , 0.9673 ], [0.9933 , 1.0 , 0.9737 ], [0.9673 , 0.9737 , 1.0 ]])
890
+ iou_check (box_tensor , expected , tolerance = 0.002 if dtype == torch .float16 else 1e-4 )
856
891
857
892
858
893
class GenBoxIouTester (unittest .TestCase ):
859
894
def test_gen_iou (self ):
860
- # Test Generalized IoU
861
- boxes1 = torch .tensor ([[0 , 0 , 100 , 100 ], [0 , 0 , 50 , 50 ], [200 , 200 , 300 , 300 ]], dtype = torch .float )
862
- boxes2 = torch .tensor ([[0 , 0 , 100 , 100 ], [0 , 0 , 50 , 50 ], [200 , 200 , 300 , 300 ]], dtype = torch .float )
863
-
864
- # Expected gIoU matrix for these boxes
865
- expected = torch .tensor ([[1.0 , 0.25 , - 0.7778 ], [0.25 , 1.0 , - 0.8611 ],
866
- [- 0.7778 , - 0.8611 , 1.0 ]])
867
-
868
- out = ops .generalized_box_iou (boxes1 , boxes2 )
869
-
870
- # Check if all elements of tensor are as expected.
871
- assert out .size () == torch .Size ([3 , 3 ])
872
- tolerance = 1e-4
873
- assert ((out - expected ).abs ().max () < tolerance ).item () is True
895
+ def gen_iou_check (box , expected , tolerance = 1e-4 ):
896
+ out = ops .generalized_box_iou (box , box )
897
+ assert out .size () == expected .size ()
898
+ assert ((out - expected ).abs ().max () < tolerance ).item ()
899
+
900
+ # Check for int boxes
901
+ for dtype in [torch .int16 , torch .int32 , torch .int64 ]:
902
+ box = torch .tensor ([[0 , 0 , 100 , 100 ], [0 , 0 , 50 , 50 ], [200 , 200 , 300 , 300 ]], dtype = dtype )
903
+ expected = torch .tensor ([[1.0 , 0.25 , - 0.7778 ], [0.25 , 1.0 , - 0.8611 ], [- 0.7778 , - 0.8611 , 1.0 ]])
904
+ gen_iou_check (box , expected )
905
+
906
+ # Check for float boxes
907
+ for dtype in [torch .float16 , torch .float32 , torch .float64 ]:
908
+ box_tensor = torch .tensor ([[285.3538 , 185.5758 , 1193.5110 , 851.4551 ],
909
+ [285.1472 , 188.7374 , 1192.4984 , 851.0669 ],
910
+ [279.2440 , 197.9812 , 1189.4746 , 849.2019 ]], dtype = dtype )
911
+ expected = torch .tensor ([[1.0 , 0.9933 , 0.9673 ], [0.9933 , 1.0 , 0.9737 ], [0.9673 , 0.9737 , 1.0 ]])
912
+ gen_iou_check (box_tensor , expected , tolerance = 0.002 if dtype == torch .float16 else 1e-3 )
874
913
875
914
876
915
if __name__ == '__main__' :
0 commit comments