From 8bbcd2c9eadfee5d4cb7d6a79ed3d6b16e879fc9 Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Mon, 16 May 2022 18:22:45 +0530 Subject: [PATCH 1/2] Cleanup ops --- torchvision/ops/__init__.py | 2 +- torchvision/ops/_utils.py | 37 ++++++++++++++++++++++++++++ torchvision/ops/boxes.py | 40 +++++++----------------------- torchvision/ops/ciou_loss.py | 47 ++++++++++-------------------------- torchvision/ops/diou_loss.py | 47 ++++++++++++++++++------------------ torchvision/ops/giou_loss.py | 34 +++++++++----------------- 6 files changed, 95 insertions(+), 112 deletions(-) diff --git a/torchvision/ops/__init__.py b/torchvision/ops/__init__.py index cd711578a6c..d3f27ef1657 100644 --- a/torchvision/ops/__init__.py +++ b/torchvision/ops/__init__.py @@ -5,13 +5,13 @@ remove_small_boxes, clip_boxes_to_image, box_area, + box_convert, box_iou, generalized_box_iou, distance_box_iou, complete_box_iou, masks_to_boxes, ) -from .boxes import box_convert from .ciou_loss import complete_box_iou_loss from .deform_conv import deform_conv2d, DeformConv2d from .diou_loss import distance_box_iou_loss diff --git a/torchvision/ops/_utils.py b/torchvision/ops/_utils.py index 8a02490ab13..a6ca557a98b 100644 --- a/torchvision/ops/_utils.py +++ b/torchvision/ops/_utils.py @@ -67,3 +67,40 @@ def split_normalization_params( else: other_params.extend(p for p in module.parameters() if p.requires_grad) return norm_params, other_params + + +def _upcast(t: Tensor) -> Tensor: + # Protects from numerical overflows in multiplications by upcasting to the equivalent higher type + if t.is_floating_point(): + return t if t.dtype in (torch.float32, torch.float64) else t.float() + else: + return t if t.dtype in (torch.int32, torch.int64) else t.int() + + +def _upcast_non_float(t: Tensor) -> Tensor: + # Protects from numerical overflows in multiplications by upcasting to the equivalent higher type + if t.dtype not in (torch.float32, torch.float64): + return t.float() + return t + + +def _loss_inter_union( + boxes1: torch.Tensor, + boxes2: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + + x1, y1, x2, y2 = boxes1.unbind(dim=-1) + x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1) + + # Intersection keypoints + xkis1 = torch.max(x1, x1g) + ykis1 = torch.max(y1, y1g) + xkis2 = torch.min(x2, x2g) + ykis2 = torch.min(y2, y2g) + + intsctk = torch.zeros_like(x1) + mask = (ykis2 > ykis1) & (xkis2 > xkis1) + intsctk[mask] = (xkis2[mask] - xkis1[mask]) * (ykis2[mask] - ykis1[mask]) + unionk = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g) - intsctk + + return intsctk, unionk diff --git a/torchvision/ops/boxes.py b/torchvision/ops/boxes.py index 3b994879ecf..72c95442b78 100644 --- a/torchvision/ops/boxes.py +++ b/torchvision/ops/boxes.py @@ -7,6 +7,7 @@ from ..utils import _log_api_usage_once from ._box_convert import _box_cxcywh_to_xyxy, _box_xyxy_to_cxcywh, _box_xywh_to_xyxy, _box_xyxy_to_xywh +from ._utils import _upcast def nms(boxes: Tensor, scores: Tensor, iou_threshold: float) -> Tensor: @@ -215,14 +216,6 @@ def box_convert(boxes: Tensor, in_fmt: str, out_fmt: str) -> Tensor: return boxes -def _upcast(t: Tensor) -> Tensor: - # Protects from numerical overflows in multiplications by upcasting to the equivalent higher type - if t.is_floating_point(): - return t if t.dtype in (torch.float32, torch.float64) else t.float() - else: - return t if t.dtype in (torch.int32, torch.int64) else t.int() - - def box_area(boxes: Tensor) -> Tensor: """ Computes the area of a set of bounding boxes, which are specified by their @@ -330,22 +323,7 @@ def 
complete_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso boxes1 = _upcast(boxes1) boxes2 = _upcast(boxes2) - inter, union = _box_inter_union(boxes1, boxes2) - iou = inter / union - - lti = torch.min(boxes1[:, None, :2], boxes2[:, None, :2]) - rbi = torch.max(boxes1[:, None, 2:], boxes2[:, None, 2:]) - - whi = (rbi - lti).clamp(min=0) # [N,M,2] - diagonal_distance_squared = (whi[:, :, 0] ** 2) + (whi[:, :, 1] ** 2) + eps - - # centers of boxes - x_p = (boxes1[:, 0] + boxes1[:, 2]) / 2 - y_p = (boxes1[:, 1] + boxes1[:, 3]) / 2 - x_g = (boxes2[:, 0] + boxes2[:, 2]) / 2 - y_g = (boxes2[:, 1] + boxes2[:, 3]) / 2 - # The distance between boxes' centers squared. - centers_distance_squared = (x_p - x_g) ** 2 + (y_p - y_g) ** 2 + diou, iou = _box_diou_iou(boxes1, boxes2, eps) w_pred = boxes1[:, 2] - boxes1[:, 0] h_pred = boxes1[:, 3] - boxes1[:, 1] @@ -356,7 +334,7 @@ def complete_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso v = (4 / (torch.pi ** 2)) * torch.pow((torch.atan(w_gt / h_gt) - torch.atan(w_pred / h_pred)), 2) with torch.no_grad(): alpha = v / (1 - iou + v + eps) - return iou - (centers_distance_squared / diagonal_distance_squared) - alpha * v + return diou - alpha * v def distance_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tensor: @@ -380,16 +358,17 @@ def distance_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso boxes1 = _upcast(boxes1) boxes2 = _upcast(boxes2) + diou, _ = _box_diou_iou(boxes1, boxes2) + return diou - inter, union = _box_inter_union(boxes1, boxes2) - iou = inter / union +def _box_diou_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tuple[Tensor, Tensor]: + + iou = box_iou(boxes1, boxes2) lti = torch.min(boxes1[:, None, :2], boxes2[:, :2]) rbi = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) - whi = _upcast(rbi - lti).clamp(min=0) # [N,M,2] diagonal_distance_squared = (whi[:, :, 0] ** 2) + (whi[:, :, 1] ** 2) + eps - # centers of boxes x_p = (boxes1[:, 0] + boxes1[:, 2]) / 2 y_p = (boxes1[:, 1] + boxes1[:, 3]) / 2 @@ -397,10 +376,9 @@ def distance_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso y_g = (boxes2[:, 1] + boxes2[:, 3]) / 2 # The distance between boxes' centers squared. centers_distance_squared = (_upcast(x_p - x_g) ** 2) + (_upcast(y_p - y_g) ** 2) - # The distance IoU is the IoU penalized by a normalized # distance between boxes' centers squared. - return iou - (centers_distance_squared / diagonal_distance_squared) + return iou - (centers_distance_squared / diagonal_distance_squared), iou def masks_to_boxes(masks: torch.Tensor) -> torch.Tensor: diff --git a/torchvision/ops/ciou_loss.py b/torchvision/ops/ciou_loss.py index d53e2d6af2a..1f271fb0a1d 100644 --- a/torchvision/ops/ciou_loss.py +++ b/torchvision/ops/ciou_loss.py @@ -1,7 +1,8 @@ import torch from ..utils import _log_api_usage_once -from .giou_loss import _upcast +from ._utils import _upcast_non_float +from .diou_loss import _diou_iou_loss def complete_box_iou_loss( @@ -30,50 +31,28 @@ def complete_box_iou_loss( ``'sum'``: The output will be summed. Default: ``'none'`` eps : (float): small number to prevent division by zero. Default: 1e-7 - Reference: + Returns: + Tensor: Loss tensor with the reduction option applied. - Complete Intersection over Union Loss (Zhaohui Zheng et. al) - https://arxiv.org/abs/1911.08287 + Reference: + Zhaohui Zheng et. 
al: Complete Intersection over Union Loss: + https://arxiv.org/abs/1911.08287 """ - # Original Implementation : https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/losses.py + # Original Implementation from https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/losses.py if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(complete_box_iou_loss) - boxes1 = _upcast(boxes1) - boxes2 = _upcast(boxes2) + boxes1 = _upcast_non_float(boxes1) + boxes2 = _upcast_non_float(boxes2) + + diou_loss, iou = _diou_iou_loss(boxes1, boxes2) x1, y1, x2, y2 = boxes1.unbind(dim=-1) x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1) - # Intersection keypoints - xkis1 = torch.max(x1, x1g) - ykis1 = torch.max(y1, y1g) - xkis2 = torch.min(x2, x2g) - ykis2 = torch.min(y2, y2g) - - intsct = torch.zeros_like(x1) - mask = (ykis2 > ykis1) & (xkis2 > xkis1) - intsct[mask] = (xkis2[mask] - xkis1[mask]) * (ykis2[mask] - ykis1[mask]) - union = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g) - intsct + eps - iou = intsct / union - - # smallest enclosing box - xc1 = torch.min(x1, x1g) - yc1 = torch.min(y1, y1g) - xc2 = torch.max(x2, x2g) - yc2 = torch.max(y2, y2g) - diag_len = ((xc2 - xc1) ** 2) + ((yc2 - yc1) ** 2) + eps - - # centers of boxes - x_p = (x2 + x1) / 2 - y_p = (y2 + y1) / 2 - x_g = (x1g + x2g) / 2 - y_g = (y1g + y2g) / 2 - distance = ((x_p - x_g) ** 2) + ((y_p - y_g) ** 2) - # width and height of boxes w_pred = x2 - x1 h_pred = y2 - y1 @@ -83,7 +62,7 @@ def complete_box_iou_loss( with torch.no_grad(): alpha = v / (1 - iou + v + eps) - loss = 1 - iou + (distance / diag_len) + alpha * v + loss = diou_loss + alpha * v if reduction == "mean": loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum() elif reduction == "sum": diff --git a/torchvision/ops/diou_loss.py b/torchvision/ops/diou_loss.py index ea7ead19344..3fdc92f394a 100644 --- a/torchvision/ops/diou_loss.py +++ b/torchvision/ops/diou_loss.py @@ -1,7 +1,9 @@ +from typing import Tuple + import torch from ..utils import _log_api_usage_once -from .boxes import _upcast +from ._utils import _loss_inter_union, _upcast_non_float def distance_box_iou_loss( @@ -10,6 +12,7 @@ def distance_box_iou_loss( reduction: str = "none", eps: float = 1e-7, ) -> torch.Tensor: + """ Gradient-friendly IoU loss with an additional penalty that is non-zero when the distance between boxes' centers isn't zero. 
Indeed, for two exactly overlapping @@ -37,29 +40,33 @@ def distance_box_iou_loss( https://arxiv.org/abs/1911.08287 """ - # Original Implementation : https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/losses.py + # Original Implementation from https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/losses.py if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(distance_box_iou_loss) - boxes1 = _upcast(boxes1) - boxes2 = _upcast(boxes2) + boxes1 = _upcast_non_float(boxes1) + boxes2 = _upcast_non_float(boxes2) - x1, y1, x2, y2 = boxes1.unbind(dim=-1) - x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1) + loss, _ = _diou_iou_loss(boxes1, boxes2, eps) - # Intersection keypoints - xkis1 = torch.max(x1, x1g) - ykis1 = torch.max(y1, y1g) - xkis2 = torch.min(x2, x2g) - ykis2 = torch.min(y2, y2g) + if reduction == "mean": + loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum() + elif reduction == "sum": + loss = loss.sum() + return loss - intsct = torch.zeros_like(x1) - mask = (ykis2 > ykis1) & (xkis2 > xkis1) - intsct[mask] = (xkis2[mask] - xkis1[mask]) * (ykis2[mask] - ykis1[mask]) - union = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g) - intsct + eps - iou = intsct / union +def _diou_iou_loss( + boxes1: torch.Tensor, + boxes2: torch.Tensor, + eps: float = 1e-7, +) -> Tuple[torch.Tensor, torch.Tensor]: + + x1, y1, x2, y2 = boxes1.unbind(dim=-1) + x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1) + intsct, union = _loss_inter_union(boxes1, boxes2) + iou = intsct / (union + eps) # smallest enclosing box xc1 = torch.min(x1, x1g) yc1 = torch.min(y1, y1g) @@ -67,7 +74,6 @@ def distance_box_iou_loss( yc2 = torch.max(y2, y2g) # The diagonal distance of the smallest enclosing box squared diagonal_distance_squared = ((xc2 - xc1) ** 2) + ((yc2 - yc1) ** 2) + eps - # centers of boxes x_p = (x2 + x1) / 2 y_p = (y2 + y1) / 2 @@ -75,12 +81,7 @@ def distance_box_iou_loss( y_g = (y1g + y2g) / 2 # The distance between boxes' centers squared. centers_distance_squared = ((x_p - x_g) ** 2) + ((y_p - y_g) ** 2) - # The distance IoU is the IoU penalized by a normalized # distance between boxes' centers squared. loss = 1 - iou + (centers_distance_squared / diagonal_distance_squared) - if reduction == "mean": - loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum() - elif reduction == "sum": - loss = loss.sum() - return loss + return loss, iou diff --git a/torchvision/ops/giou_loss.py b/torchvision/ops/giou_loss.py index 4d6f946f5e8..f3d9172f833 100644 --- a/torchvision/ops/giou_loss.py +++ b/torchvision/ops/giou_loss.py @@ -1,14 +1,7 @@ import torch -from torch import Tensor from ..utils import _log_api_usage_once - - -def _upcast(t: Tensor) -> Tensor: - # Protects from numerical overflows in multiplications by upcasting to the equivalent higher type - if t.dtype not in (torch.float32, torch.float64): - return t.float() - return t +from ._utils import _upcast_non_float, _loss_inter_union def generalized_box_iou_loss( @@ -17,10 +10,8 @@ def generalized_box_iou_loss( reduction: str = "none", eps: float = 1e-7, ) -> torch.Tensor: - """ - Original implementation from - https://github.com/facebookresearch/fvcore/blob/bfff2ef/fvcore/nn/giou_loss.py + """ Gradient-friendly IoU loss with an additional penalty that is non-zero when the boxes do not overlap and scales with the size of their smallest enclosing box. This loss is symmetric, so the boxes1 and boxes2 arguments are interchangeable. 
@@ -38,29 +29,26 @@ def generalized_box_iou_loss( ``'sum'``: The output will be summed. Default: ``'none'`` eps (float): small number to prevent division by zero. Default: 1e-7 + Returns: + Tensor: Loss tensor with the reduction option applied. + Reference: Hamid Rezatofighi et. al: Generalized Intersection over Union: A Metric and A Loss for Bounding Box Regression: https://arxiv.org/abs/1902.09630 """ + + # Original implementation from https://github.com/facebookresearch/fvcore/blob/bfff2ef/fvcore/nn/giou_loss.py + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(generalized_box_iou_loss) - boxes1 = _upcast(boxes1) - boxes2 = _upcast(boxes2) + boxes1 = _upcast_non_float(boxes1) + boxes2 = _upcast_non_float(boxes2) x1, y1, x2, y2 = boxes1.unbind(dim=-1) x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1) - # Intersection keypoints - xkis1 = torch.max(x1, x1g) - ykis1 = torch.max(y1, y1g) - xkis2 = torch.min(x2, x2g) - ykis2 = torch.min(y2, y2g) - - intsctk = torch.zeros_like(x1) - mask = (ykis2 > ykis1) & (xkis2 > xkis1) - intsctk[mask] = (xkis2[mask] - xkis1[mask]) * (ykis2[mask] - ykis1[mask]) - unionk = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g) - intsctk + intsctk, unionk = _loss_inter_union(boxes1, boxes2) iouk = intsctk / (unionk + eps) # smallest enclosing box From bc6d0e5899605de572042692ae9611c4c5b271c0 Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Wed, 18 May 2022 00:15:18 +0530 Subject: [PATCH 2/2] Address nits --- torchvision/ops/diou_loss.py | 4 ++-- torchvision/ops/giou_loss.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/torchvision/ops/diou_loss.py b/torchvision/ops/diou_loss.py index 3fdc92f394a..2187aea4cc5 100644 --- a/torchvision/ops/diou_loss.py +++ b/torchvision/ops/diou_loss.py @@ -63,11 +63,11 @@ def _diou_iou_loss( eps: float = 1e-7, ) -> Tuple[torch.Tensor, torch.Tensor]: - x1, y1, x2, y2 = boxes1.unbind(dim=-1) - x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1) intsct, union = _loss_inter_union(boxes1, boxes2) iou = intsct / (union + eps) # smallest enclosing box + x1, y1, x2, y2 = boxes1.unbind(dim=-1) + x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1) xc1 = torch.min(x1, x1g) yc1 = torch.min(y1, y1g) xc2 = torch.max(x2, x2g) diff --git a/torchvision/ops/giou_loss.py b/torchvision/ops/giou_loss.py index f3d9172f833..a7210f5739b 100644 --- a/torchvision/ops/giou_loss.py +++ b/torchvision/ops/giou_loss.py @@ -45,12 +45,12 @@ def generalized_box_iou_loss( boxes1 = _upcast_non_float(boxes1) boxes2 = _upcast_non_float(boxes2) - x1, y1, x2, y2 = boxes1.unbind(dim=-1) - x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1) - intsctk, unionk = _loss_inter_union(boxes1, boxes2) iouk = intsctk / (unionk + eps) + x1, y1, x2, y2 = boxes1.unbind(dim=-1) + x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1) + # smallest enclosing box xc1 = torch.min(x1, x1g) yc1 = torch.min(y1, y1g)
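
A rough consistency check, not part of the patch: after this refactor the CIoU loss is literally the DIoU loss plus the aspect-ratio term alpha * v, so the public ops can be cross-checked against each other. The snippet below assumes a torchvision build that already ships complete_box_iou_loss, distance_box_iou_loss and box_iou; the box values are invented for illustration.

import torch
from torchvision.ops import box_iou, complete_box_iou_loss, distance_box_iou_loss

# Matched (prediction, target) pairs in (x1, y1, x2, y2) format.
boxes1 = torch.tensor([[0.0, 0.0, 10.0, 10.0], [5.0, 5.0, 15.0, 20.0]])
boxes2 = torch.tensor([[2.0, 2.0, 12.0, 12.0], [6.0, 4.0, 14.0, 22.0]])
eps = 1e-7

diou = distance_box_iou_loss(boxes1, boxes2)   # element-wise, shape (2,)
ciou = complete_box_iou_loss(boxes1, boxes2)

# Recompute the aspect-ratio penalty alpha * v by hand, same formula as ciou_loss.py.
w1, h1 = boxes1[:, 2] - boxes1[:, 0], boxes1[:, 3] - boxes1[:, 1]
w2, h2 = boxes2[:, 2] - boxes2[:, 0], boxes2[:, 3] - boxes2[:, 1]
v = (4 / torch.pi ** 2) * (torch.atan(w2 / h2) - torch.atan(w1 / h1)) ** 2
iou = box_iou(boxes1, boxes2).diagonal()       # IoU of each matched pair
alpha = v / (1 - iou + v + eps)

torch.testing.assert_close(ciou, diou + alpha * v, rtol=1e-5, atol=1e-5)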
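
A minimal usage sketch, also not part of the patch: all three element-wise box losses take matched boxes in (x1, y1, x2, y2) format and the usual reduction values, and they stay differentiable with respect to the predictions. Tensor values are made up for the example.

import torch
from torchvision.ops import complete_box_iou_loss, distance_box_iou_loss, generalized_box_iou_loss

pred = torch.tensor([[0.0, 0.0, 8.0, 8.0], [4.0, 4.0, 10.0, 12.0]], requires_grad=True)
target = torch.tensor([[1.0, 1.0, 9.0, 9.0], [5.0, 3.0, 11.0, 13.0]])

for loss_fn in (generalized_box_iou_loss, distance_box_iou_loss, complete_box_iou_loss):
    print(loss_fn.__name__, loss_fn(pred, target, reduction="mean").item())

complete_box_iou_loss(pred, target, reduction="sum").backward()
print(pred.grad.shape)  # torch.Size([2, 4]); gradients flow back to the predicted boxes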
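
One more illustrative note on _upcast_non_float: the patch has each loss call it up front, so, assuming integer box tensors are an intended input (the patch does not say so explicitly), they are cast to float before any arithmetic and the loss always comes back as a floating-point tensor.

import torch
from torchvision.ops import distance_box_iou_loss

int_boxes1 = torch.tensor([[0, 0, 100, 100]], dtype=torch.int32)
int_boxes2 = torch.tensor([[25, 25, 125, 125]], dtype=torch.int32)

loss = distance_box_iou_loss(int_boxes1, int_boxes2)
print(loss.dtype)  # torch.float32, even though the inputs were int32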