diff --git a/torchvision/ops/giou_loss.py b/torchvision/ops/giou_loss.py index c43a788063e..455c9af34e9 100644 --- a/torchvision/ops/giou_loss.py +++ b/torchvision/ops/giou_loss.py @@ -2,6 +2,7 @@ from torch import Tensor from ..utils import _log_api_usage_once +from .boxes import generalized_box_iou def _upcast(t: Tensor) -> Tensor: @@ -48,30 +49,14 @@ def generalized_box_iou_loss( boxes1 = _upcast(boxes1) boxes2 = _upcast(boxes2) - x1, y1, x2, y2 = boxes1.unbind(dim=-1) - x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1) - - # Intersection keypoints - xkis1 = torch.max(x1, x1g) - ykis1 = torch.max(y1, y1g) - xkis2 = torch.min(x2, x2g) - ykis2 = torch.min(y2, y2g) - - intsctk = torch.zeros_like(x1) - mask = (ykis2 > ykis1) & (xkis2 > xkis1) - intsctk[mask] = (xkis2[mask] - xkis1[mask]) * (ykis2[mask] - ykis1[mask]) - unionk = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g) - intsctk - iouk = intsctk / (unionk + eps) - - # smallest enclosing box - xc1 = torch.min(x1, x1g) - yc1 = torch.min(y1, y1g) - xc2 = torch.max(x2, x2g) - yc2 = torch.max(y2, y2g) - - area_c = (xc2 - xc1) * (yc2 - yc1) - miouk = iouk - ((area_c - unionk) / (area_c + eps)) + # `generalized_box_iou` returns an N x N matrix of pairwise GIoU values, so its diagonal holds the GIoU of each matched pair (boxes1[i], boxes2[i]) + if boxes1.shape == torch.Size([4]): + boxes1 = boxes1[None, :] + boxes2 = boxes2[None, :] + miouk = generalized_box_iou(boxes1, boxes2)[0][0] + else: + miouk = generalized_box_iou(boxes1, boxes2).diagonal() loss = 1 - miouk if reduction == "mean":