Skip to content

Commit b88d1bb

Browse files
authored
Alternative approach for fixing losses.
1 parent 5777bfb commit b88d1bb

File tree

1 file changed

+12
-18
lines changed

1 file changed

+12
-18
lines changed

torchvision/ops/boxes.py

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -325,15 +325,13 @@ def complete_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso
325325

326326
diou, iou = _box_diou_iou(boxes1, boxes2, eps)
327327

328-
w_pred = boxes1[:, 2] - boxes1[:, 0]
329-
h_pred = boxes1[:, 3] - boxes1[:, 1]
328+
w_pred = boxes1[:, None, 2] - boxes1[:, None, 0]
329+
h_pred = boxes1[:, None, 3] - boxes1[:, None, 1]
330330

331-
w_gt = boxes2[:, 2] - boxes2[:, 0]
332-
h_gt = boxes2[:, 3] - boxes2[:, 1]
331+
w_gt = boxes2[:, None, 2] - boxes2[:, None, 0]
332+
h_gt = boxes2[:, None, 3] - boxes2[:, None, 1]
333333

334-
aspect_gt = torch.atan(w_gt / h_gt)
335-
aspect_pred = torch.atan(w_pred / h_pred)
336-
v = (4 / (torch.pi**2)) * torch.pow((aspect_gt - aspect_pred[:, None]), 2)
334+
v = (4 / (torch.pi**2)) * torch.pow(torch.atan(w_pred / h_pred) - torch.atan(w_gt / h_gt).t(), 2)
337335
with torch.no_grad():
338336
alpha = v / (1 - iou + v + eps)
339337
return diou - alpha * v
@@ -360,7 +358,7 @@ def distance_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso
360358

361359
boxes1 = _upcast(boxes1)
362360
boxes2 = _upcast(boxes2)
363-
diou, _ = _box_diou_iou(boxes1, boxes2, eps)
361+
diou, _ = _box_diou_iou(boxes1, boxes2, eps=eps)
364362
return diou
365363

366364

@@ -372,19 +370,15 @@ def _box_diou_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tuple[Te
372370
whi = _upcast(rbi - lti).clamp(min=0) # [N,M,2]
373371
diagonal_distance_squared = (whi[:, :, 0] ** 2) + (whi[:, :, 1] ** 2) + eps
374372
# centers of boxes
375-
x_p = (boxes1[:, 0] + boxes1[:, 2]) / 2
376-
y_p = (boxes1[:, 1] + boxes1[:, 3]) / 2
377-
x_g = (boxes2[:, 0] + boxes2[:, 2]) / 2
378-
y_g = (boxes2[:, 1] + boxes2[:, 3]) / 2
373+
x_p = (boxes1[:, None, 0] + boxes1[:, None, 2]) / 2
374+
y_p = (boxes1[:, None, 1] + boxes1[:, None, 3]) / 2
375+
x_g = (boxes2[:, None, 0] + boxes2[:, None, 2]) / 2
376+
y_g = (boxes2[:, None, 1] + boxes2[:, None, 3]) / 2
379377
# The distance between boxes' centers squared.
380-
centers_distance_squared = (_upcast((x_p - x_g[:, None]).diag()) ** 2) + (_upcast((y_p - y_g[:, None]).diag()) ** 2)
378+
centers_distance_squared = (_upcast(x_p - x_g.t()) ** 2) + (_upcast(y_p - y_g.t()) ** 2)
381379
# The distance IoU is the IoU penalized by a normalized
382380
# distance between boxes' centers squared.
383-
if boxes1.size(0) > boxes2.size(0):
384-
center_distance_ratio = centers_distance_squared[None, :] / diagonal_distance_squared
385-
else:
386-
center_distance_ratio = centers_distance_squared[:, None] / diagonal_distance_squared
387-
return iou - center_distance_ratio, iou
381+
return iou - (centers_distance_squared / diagonal_distance_squared), iou
388382

389383

390384
def masks_to_boxes(masks: torch.Tensor) -> torch.Tensor:

0 commit comments

Comments
 (0)