Cleanup ops #6024

Merged: 6 commits, May 18, 2022
2 changes: 1 addition & 1 deletion torchvision/ops/__init__.py
@@ -5,13 +5,13 @@
remove_small_boxes,
clip_boxes_to_image,
box_area,
box_convert,
box_iou,
generalized_box_iou,
distance_box_iou,
complete_box_iou,
masks_to_boxes,
)
from .boxes import box_convert
from .ciou_loss import complete_box_iou_loss
from .deform_conv import deform_conv2d, DeformConv2d
from .diou_loss import distance_box_iou_loss
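Not part of the diff: a quick sanity check that box_convert is still re-exported at the package level after the import shuffle above. The box values are arbitrary and only illustrate the xyxy to cxcywh conversion.

import torch
from torchvision.ops import box_convert  # still exposed from torchvision.ops

boxes = torch.tensor([[0.0, 0.0, 10.0, 4.0]])  # one box in xyxy format
# xyxy -> cxcywh: center (5, 2), width 10, height 4
print(box_convert(boxes, in_fmt="xyxy", out_fmt="cxcywh"))  # tensor([[ 5.,  2., 10.,  4.]])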
37 changes: 37 additions & 0 deletions torchvision/ops/_utils.py
@@ -67,3 +67,40 @@ def split_normalization_params(
else:
other_params.extend(p for p in module.parameters() if p.requires_grad)
return norm_params, other_params


def _upcast(t: Tensor) -> Tensor:
# Protects from numerical overflows in multiplications by upcasting to the equivalent higher type
if t.is_floating_point():
return t if t.dtype in (torch.float32, torch.float64) else t.float()
else:
return t if t.dtype in (torch.int32, torch.int64) else t.int()


def _upcast_non_float(t: Tensor) -> Tensor:
# Protects from numerical overflows in multiplications by upcasting to the equivalent higher type
if t.dtype not in (torch.float32, torch.float64):
return t.float()
return t


def _loss_inter_union(
boxes1: torch.Tensor,
boxes2: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:

x1, y1, x2, y2 = boxes1.unbind(dim=-1)
x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1)

# Intersection keypoints
xkis1 = torch.max(x1, x1g)
ykis1 = torch.max(y1, y1g)
xkis2 = torch.min(x2, x2g)
ykis2 = torch.min(y2, y2g)

intsctk = torch.zeros_like(x1)
mask = (ykis2 > ykis1) & (xkis2 > xkis1)
intsctk[mask] = (xkis2[mask] - xkis1[mask]) * (ykis2[mask] - ykis1[mask])
unionk = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g) - intsctk

return intsctk, unionk
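Not part of the PR, just a minimal sketch of what the new private helpers compute, assuming the _utils.py code above stays importable from torchvision.ops._utils (a private module, so the path may change):

import torch
from torchvision.ops._utils import _upcast, _loss_inter_union  # private helpers

# _upcast widens low-precision dtypes so the area products cannot overflow
boxes_int16 = torch.tensor([[0, 0, 300, 300]], dtype=torch.int16)
print(_upcast(boxes_int16).dtype)  # torch.int32

# _loss_inter_union is elementwise: boxes1[i] is paired with boxes2[i] (xyxy format)
b1 = torch.tensor([[0.0, 0.0, 10.0, 10.0]])
b2 = torch.tensor([[5.0, 5.0, 15.0, 15.0]])
inter, union = _loss_inter_union(b1, b2)
print(inter, union)  # tensor([25.]) tensor([175.])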
40 changes: 9 additions & 31 deletions torchvision/ops/boxes.py
@@ -7,6 +7,7 @@

from ..utils import _log_api_usage_once
from ._box_convert import _box_cxcywh_to_xyxy, _box_xyxy_to_cxcywh, _box_xywh_to_xyxy, _box_xyxy_to_xywh
from ._utils import _upcast


def nms(boxes: Tensor, scores: Tensor, iou_threshold: float) -> Tensor:
@@ -215,14 +216,6 @@ def box_convert(boxes: Tensor, in_fmt: str, out_fmt: str) -> Tensor:
return boxes


def _upcast(t: Tensor) -> Tensor:
# Protects from numerical overflows in multiplications by upcasting to the equivalent higher type
if t.is_floating_point():
return t if t.dtype in (torch.float32, torch.float64) else t.float()
else:
return t if t.dtype in (torch.int32, torch.int64) else t.int()


def box_area(boxes: Tensor) -> Tensor:
"""
Computes the area of a set of bounding boxes, which are specified by their
@@ -330,22 +323,7 @@ def complete_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tensor:
boxes1 = _upcast(boxes1)
boxes2 = _upcast(boxes2)

inter, union = _box_inter_union(boxes1, boxes2)
iou = inter / union

lti = torch.min(boxes1[:, None, :2], boxes2[:, None, :2])
rbi = torch.max(boxes1[:, None, 2:], boxes2[:, None, 2:])

whi = (rbi - lti).clamp(min=0) # [N,M,2]
diagonal_distance_squared = (whi[:, :, 0] ** 2) + (whi[:, :, 1] ** 2) + eps

# centers of boxes
x_p = (boxes1[:, 0] + boxes1[:, 2]) / 2
y_p = (boxes1[:, 1] + boxes1[:, 3]) / 2
x_g = (boxes2[:, 0] + boxes2[:, 2]) / 2
y_g = (boxes2[:, 1] + boxes2[:, 3]) / 2
# The distance between boxes' centers squared.
centers_distance_squared = (x_p - x_g) ** 2 + (y_p - y_g) ** 2
diou, iou = _box_diou_iou(boxes1, boxes2, eps)

w_pred = boxes1[:, 2] - boxes1[:, 0]
h_pred = boxes1[:, 3] - boxes1[:, 1]
@@ -356,7 +334,7 @@ def complete_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tensor:
v = (4 / (torch.pi ** 2)) * torch.pow((torch.atan(w_gt / h_gt) - torch.atan(w_pred / h_pred)), 2)
with torch.no_grad():
alpha = v / (1 - iou + v + eps)
return iou - (centers_distance_squared / diagonal_distance_squared) - alpha * v
return diou - alpha * v


def distance_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tensor:
@@ -380,27 +358,27 @@ def distance_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tensor:

boxes1 = _upcast(boxes1)
boxes2 = _upcast(boxes2)
diou, _ = _box_diou_iou(boxes1, boxes2)
return diou

inter, union = _box_inter_union(boxes1, boxes2)
iou = inter / union

def _box_diou_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tuple[Tensor, Tensor]:

iou = box_iou(boxes1, boxes2)
lti = torch.min(boxes1[:, None, :2], boxes2[:, :2])
rbi = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])

whi = _upcast(rbi - lti).clamp(min=0) # [N,M,2]
diagonal_distance_squared = (whi[:, :, 0] ** 2) + (whi[:, :, 1] ** 2) + eps

# centers of boxes
x_p = (boxes1[:, 0] + boxes1[:, 2]) / 2
y_p = (boxes1[:, 1] + boxes1[:, 3]) / 2
x_g = (boxes2[:, 0] + boxes2[:, 2]) / 2
y_g = (boxes2[:, 1] + boxes2[:, 3]) / 2
# The distance between boxes' centers squared.
centers_distance_squared = (_upcast(x_p - x_g) ** 2) + (_upcast(y_p - y_g) ** 2)

# The distance IoU is the IoU penalized by a normalized
# distance between boxes' centers squared.
return iou - (centers_distance_squared / diagonal_distance_squared)
return iou - (centers_distance_squared / diagonal_distance_squared), iou


def masks_to_boxes(masks: torch.Tensor) -> torch.Tensor:
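For context (not in the diff), the public ops that now route through the shared _box_diou_iou helper can be compared directly. A rough check under the current pairwise [N, M] semantics, with arbitrary box values:

import torch
from torchvision.ops import box_iou, distance_box_iou, complete_box_iou

boxes1 = torch.tensor([[0.0, 0.0, 10.0, 10.0]])
boxes2 = torch.tensor([[5.0, 5.0, 15.0, 15.0]])

iou = box_iou(boxes1, boxes2)            # plain IoU: 25 / 175
diou = distance_box_iou(boxes1, boxes2)  # IoU minus the normalized center-distance penalty
ciou = complete_box_iou(boxes1, boxes2)  # DIoU minus the aspect-ratio term alpha * v
# Both boxes are square here, so the aspect-ratio term is zero and ciou == diou <= iou.
print(iou, diou, ciou)  # each has shape [1, 1]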
47 changes: 13 additions & 34 deletions torchvision/ops/ciou_loss.py
@@ -1,7 +1,8 @@
import torch

from ..utils import _log_api_usage_once
from .giou_loss import _upcast
from ._utils import _upcast_non_float
from .diou_loss import _diou_iou_loss


def complete_box_iou_loss(
@@ -30,50 +31,28 @@ def complete_box_iou_loss(
``'sum'``: The output will be summed. Default: ``'none'``
eps : (float): small number to prevent division by zero. Default: 1e-7

Reference:
Returns:
Tensor: Loss tensor with the reduction option applied.

Complete Intersection over Union Loss (Zhaohui Zheng et. al)
https://arxiv.org/abs/1911.08287
Reference:
Zhaohui Zheng et. al: Complete Intersection over Union Loss:
https://arxiv.org/abs/1911.08287

"""

# Original Implementation : https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/losses.py
# Original Implementation from https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/losses.py

if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(complete_box_iou_loss)

boxes1 = _upcast(boxes1)
boxes2 = _upcast(boxes2)
boxes1 = _upcast_non_float(boxes1)
boxes2 = _upcast_non_float(boxes2)

diou_loss, iou = _diou_iou_loss(boxes1, boxes2)

x1, y1, x2, y2 = boxes1.unbind(dim=-1)
x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1)

# Intersection keypoints
xkis1 = torch.max(x1, x1g)
ykis1 = torch.max(y1, y1g)
xkis2 = torch.min(x2, x2g)
ykis2 = torch.min(y2, y2g)

intsct = torch.zeros_like(x1)
mask = (ykis2 > ykis1) & (xkis2 > xkis1)
intsct[mask] = (xkis2[mask] - xkis1[mask]) * (ykis2[mask] - ykis1[mask])
union = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g) - intsct + eps
iou = intsct / union

# smallest enclosing box
xc1 = torch.min(x1, x1g)
yc1 = torch.min(y1, y1g)
xc2 = torch.max(x2, x2g)
yc2 = torch.max(y2, y2g)
diag_len = ((xc2 - xc1) ** 2) + ((yc2 - yc1) ** 2) + eps

# centers of boxes
x_p = (x2 + x1) / 2
y_p = (y2 + y1) / 2
x_g = (x1g + x2g) / 2
y_g = (y1g + y2g) / 2
distance = ((x_p - x_g) ** 2) + ((y_p - y_g) ** 2)

# width and height of boxes
w_pred = x2 - x1
h_pred = y2 - y1
@@ -83,7 +62,7 @@ def complete_box_iou_loss(
with torch.no_grad():
alpha = v / (1 - iou + v + eps)

loss = 1 - iou + (distance / diag_len) + alpha * v
loss = diou_loss + alpha * v
if reduction == "mean":
loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum()
elif reduction == "sum":
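A short usage sketch (not from the PR) showing that the refactored loss remains differentiable end to end; the box values are arbitrary:

import torch
from torchvision.ops import complete_box_iou_loss

pred = torch.tensor([[0.0, 0.0, 10.0, 10.0]], requires_grad=True)
target = torch.tensor([[2.0, 2.0, 12.0, 12.0]])

loss = complete_box_iou_loss(pred, target, reduction="mean")
loss.backward()  # gradients flow through _diou_iou_loss and the alpha * v term
print(loss.item(), pred.grad)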
47 changes: 24 additions & 23 deletions torchvision/ops/diou_loss.py
@@ -1,7 +1,9 @@
from typing import Tuple

import torch

from ..utils import _log_api_usage_once
from .boxes import _upcast
from ._utils import _loss_inter_union, _upcast_non_float


def distance_box_iou_loss(
@@ -10,6 +12,7 @@ def distance_box_iou_loss(
reduction: str = "none",
eps: float = 1e-7,
) -> torch.Tensor:

"""
Gradient-friendly IoU loss with an additional penalty that is non-zero when the
distance between boxes' centers isn't zero. Indeed, for two exactly overlapping
@@ -37,50 +40,48 @@ def distance_box_iou_loss(
https://arxiv.org/abs/1911.08287
"""

# Original Implementation : https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/losses.py
# Original Implementation from https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/losses.py

if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(distance_box_iou_loss)

boxes1 = _upcast(boxes1)
boxes2 = _upcast(boxes2)
boxes1 = _upcast_non_float(boxes1)
boxes2 = _upcast_non_float(boxes2)

x1, y1, x2, y2 = boxes1.unbind(dim=-1)
x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1)
loss, _ = _diou_iou_loss(boxes1, boxes2, eps)

# Intersection keypoints
xkis1 = torch.max(x1, x1g)
ykis1 = torch.max(y1, y1g)
xkis2 = torch.min(x2, x2g)
ykis2 = torch.min(y2, y2g)
if reduction == "mean":
loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum()
elif reduction == "sum":
loss = loss.sum()
return loss

intsct = torch.zeros_like(x1)
mask = (ykis2 > ykis1) & (xkis2 > xkis1)
intsct[mask] = (xkis2[mask] - xkis1[mask]) * (ykis2[mask] - ykis1[mask])
union = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g) - intsct + eps
iou = intsct / union

def _diou_iou_loss(
boxes1: torch.Tensor,
boxes2: torch.Tensor,
eps: float = 1e-7,
) -> Tuple[torch.Tensor, torch.Tensor]:

intsct, union = _loss_inter_union(boxes1, boxes2)
iou = intsct / (union + eps)
# smallest enclosing box
x1, y1, x2, y2 = boxes1.unbind(dim=-1)
x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1)
xc1 = torch.min(x1, x1g)
yc1 = torch.min(y1, y1g)
xc2 = torch.max(x2, x2g)
yc2 = torch.max(y2, y2g)
# The diagonal distance of the smallest enclosing box squared
diagonal_distance_squared = ((xc2 - xc1) ** 2) + ((yc2 - yc1) ** 2) + eps

# centers of boxes
x_p = (x2 + x1) / 2
y_p = (y2 + y1) / 2
x_g = (x1g + x2g) / 2
y_g = (y1g + y2g) / 2
# The distance between boxes' centers squared.
centers_distance_squared = ((x_p - x_g) ** 2) + ((y_p - y_g) ** 2)

# The distance IoU is the IoU penalized by a normalized
# distance between boxes' centers squared.
loss = 1 - iou + (centers_distance_squared / diagonal_distance_squared)
if reduction == "mean":
loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum()
elif reduction == "sum":
loss = loss.sum()
return loss
return loss, iou
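One consequence of routing both losses through _diou_iou_loss, sketched here as a sanity check rather than a test from the PR: when the two boxes share the same aspect ratio, the alpha * v term is zero and the CIoU loss collapses to the DIoU loss (assuming the default eps of 1e-7 in both public functions).

import torch
from torchvision.ops import complete_box_iou_loss, distance_box_iou_loss

b1 = torch.tensor([[0.0, 0.0, 10.0, 10.0]])
b2 = torch.tensor([[3.0, 3.0, 13.0, 13.0]])

# Both boxes are square, so the aspect-ratio term vanishes and the CIoU loss
# reduces to the DIoU loss produced by the shared _diou_iou_loss helper.
print(distance_box_iou_loss(b1, b2))
print(complete_box_iou_loss(b1, b2))  # same value as the line above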
38 changes: 13 additions & 25 deletions torchvision/ops/giou_loss.py
@@ -1,14 +1,7 @@
import torch
from torch import Tensor

from ..utils import _log_api_usage_once


def _upcast(t: Tensor) -> Tensor:
# Protects from numerical overflows in multiplications by upcasting to the equivalent higher type
if t.dtype not in (torch.float32, torch.float64):
return t.float()
return t
from ._utils import _upcast_non_float, _loss_inter_union


def generalized_box_iou_loss(
@@ -17,10 +10,8 @@ def generalized_box_iou_loss(
reduction: str = "none",
eps: float = 1e-7,
) -> torch.Tensor:
"""
Original implementation from
https://github.com/facebookresearch/fvcore/blob/bfff2ef/fvcore/nn/giou_loss.py

"""
Gradient-friendly IoU loss with an additional penalty that is non-zero when the
boxes do not overlap and scales with the size of their smallest enclosing box.
This loss is symmetric, so the boxes1 and boxes2 arguments are interchangeable.
@@ -38,31 +29,28 @@ def generalized_box_iou_loss(
``'sum'``: The output will be summed. Default: ``'none'``
eps (float): small number to prevent division by zero. Default: 1e-7

Returns:
Tensor: Loss tensor with the reduction option applied.

Reference:
Hamid Rezatofighi et. al: Generalized Intersection over Union:
A Metric and A Loss for Bounding Box Regression:
https://arxiv.org/abs/1902.09630
"""

# Original implementation from https://github.com/facebookresearch/fvcore/blob/bfff2ef/fvcore/nn/giou_loss.py

if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(generalized_box_iou_loss)

boxes1 = _upcast(boxes1)
boxes2 = _upcast(boxes2)
boxes1 = _upcast_non_float(boxes1)
boxes2 = _upcast_non_float(boxes2)
intsctk, unionk = _loss_inter_union(boxes1, boxes2)
iouk = intsctk / (unionk + eps)

x1, y1, x2, y2 = boxes1.unbind(dim=-1)
x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1)

# Intersection keypoints
xkis1 = torch.max(x1, x1g)
ykis1 = torch.max(y1, y1g)
xkis2 = torch.min(x2, x2g)
ykis2 = torch.min(y2, y2g)

intsctk = torch.zeros_like(x1)
mask = (ykis2 > ykis1) & (xkis2 > xkis1)
intsctk[mask] = (xkis2[mask] - xkis1[mask]) * (ykis2[mask] - ykis1[mask])
unionk = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g) - intsctk
iouk = intsctk / (unionk + eps)

# smallest enclosing box
xc1 = torch.min(x1, x1g)
yc1 = torch.min(y1, y1g)
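Finally, a usage sketch (not in the diff) for the GIoU loss, which this cleanup only reroutes through the shared helpers. Disjoint boxes are the interesting case, since plain IoU alone would give no training signal there:

import torch
from torchvision.ops import generalized_box_iou_loss

# Disjoint boxes: plain IoU is 0, but the enclosing-box penalty still gives a signal.
pred = torch.tensor([[0.0, 0.0, 2.0, 2.0]], requires_grad=True)
target = torch.tensor([[4.0, 4.0, 6.0, 6.0]])

loss = generalized_box_iou_loss(pred, target, reduction="sum")
loss.backward()
print(loss.item())  # about 1.78 = 1 - GIoU, with GIoU = 0 - 28/36 here
print(pred.grad)    # non-zero even though the boxes do not overlap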