@@ -325,15 +325,13 @@ def complete_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso
325325
326326 diou , iou = _box_diou_iou (boxes1 , boxes2 , eps )
327327
328- w_pred = boxes1 [:, 2 ] - boxes1 [:, 0 ]
329- h_pred = boxes1 [:, 3 ] - boxes1 [:, 1 ]
328+ w_pred = boxes1 [:, None , 2 ] - boxes1 [:, None , 0 ]
329+ h_pred = boxes1 [:, None , 3 ] - boxes1 [:, None , 1 ]
330330
331- w_gt = boxes2 [:, 2 ] - boxes2 [:, 0 ]
332- h_gt = boxes2 [:, 3 ] - boxes2 [:, 1 ]
331+ w_gt = boxes2 [:, None , 2 ] - boxes2 [:, None , 0 ]
332+ h_gt = boxes2 [:, None , 3 ] - boxes2 [:, None , 1 ]
333333
334- aspect_gt = torch .atan (w_gt / h_gt )
335- aspect_pred = torch .atan (w_pred / h_pred )
336- v = (4 / (torch .pi ** 2 )) * torch .pow ((aspect_gt - aspect_pred [:, None ]), 2 )
334+ v = (4 / (torch .pi ** 2 )) * torch .pow (torch .atan (w_pred / h_pred ) - torch .atan (w_gt / h_gt ).t (), 2 )
337335 with torch .no_grad ():
338336 alpha = v / (1 - iou + v + eps )
339337 return diou - alpha * v
@@ -360,7 +358,7 @@ def distance_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso
360358
361359 boxes1 = _upcast (boxes1 )
362360 boxes2 = _upcast (boxes2 )
363- diou , _ = _box_diou_iou (boxes1 , boxes2 , eps )
361+ diou , _ = _box_diou_iou (boxes1 , boxes2 , eps = eps )
364362 return diou
365363
366364
@@ -372,19 +370,15 @@ def _box_diou_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tuple[Te
372370 whi = _upcast (rbi - lti ).clamp (min = 0 ) # [N,M,2]
373371 diagonal_distance_squared = (whi [:, :, 0 ] ** 2 ) + (whi [:, :, 1 ] ** 2 ) + eps
374372 # centers of boxes
375- x_p = (boxes1 [:, 0 ] + boxes1 [:, 2 ]) / 2
376- y_p = (boxes1 [:, 1 ] + boxes1 [:, 3 ]) / 2
377- x_g = (boxes2 [:, 0 ] + boxes2 [:, 2 ]) / 2
378- y_g = (boxes2 [:, 1 ] + boxes2 [:, 3 ]) / 2
373+ x_p = (boxes1 [:, None , 0 ] + boxes1 [:, None , 2 ]) / 2
374+ y_p = (boxes1 [:, None , 1 ] + boxes1 [:, None , 3 ]) / 2
375+ x_g = (boxes2 [:, None , 0 ] + boxes2 [:, None , 2 ]) / 2
376+ y_g = (boxes2 [:, None , 1 ] + boxes2 [:, None , 3 ]) / 2
379377 # The distance between boxes' centers squared.
380- centers_distance_squared = (_upcast (( x_p - x_g [:, None ]). diag ()) ** 2 ) + (_upcast (( y_p - y_g [:, None ]). diag ()) ** 2 )
378+ centers_distance_squared = (_upcast (x_p - x_g . t ()) ** 2 ) + (_upcast (y_p - y_g . t ()) ** 2 )
381379 # The distance IoU is the IoU penalized by a normalized
382380 # distance between boxes' centers squared.
383- if boxes1 .size (0 ) > boxes2 .size (0 ):
384- center_distance_ratio = centers_distance_squared [None , :] / diagonal_distance_squared
385- else :
386- center_distance_ratio = centers_distance_squared [:, None ] / diagonal_distance_squared
387- return iou - center_distance_ratio , iou
381+ return iou - (centers_distance_squared / diagonal_distance_squared ), iou
388382
389383
390384def masks_to_boxes (masks : torch .Tensor ) -> torch .Tensor :
0 commit comments