diff --git a/test/test_ops.py b/test/test_ops.py index 96cfb630e8d..7f9af319ed0 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1266,7 +1266,11 @@ def _generate_int_input(): return [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]] def _generate_int_expected(): - return [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]] + return [ + [1.0000, 0.1875, -0.4444], + [0.1875, 1.0000, -0.5625], + [-0.4444, -0.5625, 1.0000], + ] def _generate_float_input(): return [ @@ -1357,7 +1361,11 @@ def _generate_int_input() -> List[List[int]]: return [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]] def _generate_int_expected() -> List[List[float]]: - return [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]] + return [ + [1.0000, 0.1875, -0.4444], + [0.1875, 1.0000, -0.5625], + [-0.4444, -0.5625, 1.0000], + ] def _generate_float_input() -> List[List[float]]: return [ diff --git a/torchvision/ops/boxes.py b/torchvision/ops/boxes.py index 72c95442b78..189d0a07bf3 100644 --- a/torchvision/ops/boxes.py +++ b/torchvision/ops/boxes.py @@ -325,13 +325,13 @@ def complete_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso diou, iou = _box_diou_iou(boxes1, boxes2, eps) - w_pred = boxes1[:, 2] - boxes1[:, 0] - h_pred = boxes1[:, 3] - boxes1[:, 1] + w_pred = boxes1[:, None, 2] - boxes1[:, None, 0] + h_pred = boxes1[:, None, 3] - boxes1[:, None, 1] w_gt = boxes2[:, 2] - boxes2[:, 0] h_gt = boxes2[:, 3] - boxes2[:, 1] - v = (4 / (torch.pi ** 2)) * torch.pow((torch.atan(w_gt / h_gt) - torch.atan(w_pred / h_pred)), 2) + v = (4 / (torch.pi ** 2)) * torch.pow(torch.atan(w_pred / h_pred) - torch.atan(w_gt / h_gt), 2) with torch.no_grad(): alpha = v / (1 - iou + v + eps) return diou - alpha * v @@ -358,7 +358,7 @@ def distance_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso boxes1 = _upcast(boxes1) boxes2 = _upcast(boxes2) - diou, _ = _box_diou_iou(boxes1, boxes2) + diou, _ = _box_diou_iou(boxes1, boxes2, eps=eps) return diou @@ -375,7 +375,9 @@ def _box_diou_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tuple[Te x_g = (boxes2[:, 0] + boxes2[:, 2]) / 2 y_g = (boxes2[:, 1] + boxes2[:, 3]) / 2 # The distance between boxes' centers squared. - centers_distance_squared = (_upcast(x_p - x_g) ** 2) + (_upcast(y_p - y_g) ** 2) + centers_distance_squared = (_upcast((x_p[:, None] - x_g[None, :])) ** 2) + ( + _upcast((y_p[:, None] - y_g[None, :])) ** 2 + ) # The distance IoU is the IoU penalized by a normalized # distance between boxes' centers squared. return iou - (centers_distance_squared / diagonal_distance_squared), iou diff --git a/torchvision/ops/ciou_loss.py b/torchvision/ops/ciou_loss.py index 1f271fb0a1d..c30c54dfc9b 100644 --- a/torchvision/ops/ciou_loss.py +++ b/torchvision/ops/ciou_loss.py @@ -14,8 +14,8 @@ def complete_box_iou_loss( """ Gradient-friendly IoU loss with an additional penalty that is non-zero when the - boxes do not overlap overlap area, This loss function considers important geometrical - factors such as overlap area, normalized central point distance and aspect ratio. + boxes do not overlap. This loss function considers important geometrical + factors such as overlap area, normalized central point distance and aspect ratio. This loss is symmetric, so the boxes1 and boxes2 arguments are interchangeable. Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with @@ -35,7 +35,7 @@ def complete_box_iou_loss( Tensor: Loss tensor with the reduction option applied. Reference: - Zhaohui Zheng et. al: Complete Intersection over Union Loss: + Zhaohui Zheng et al.: Complete Intersection over Union Loss: https://arxiv.org/abs/1911.08287 """