Skip to content

Commit efdd3fd

Browse files
datumboxfacebook-github-bot
authored andcommitted
[fbsync] Distance IoU (#5786)
Summary: * [FEAT] Add distance IoU and distance IoU loss + some tests (WIP for tests). * [FIX] Remove URL from docstring + remove assert since it causes a big performance drop. * [FIX] eps isn't None. * [TEST] Update existing box dIoU test + add dIoU loss tests (inspired from cIoU ones). * [ENH] Some pre-commit fixes + remove print + mypy. * [ENH] Pass the device in the assertion for the dIoU loss test. * [FIX] Remove type hints from the dIoU box test. * [ENH] Refactor box and loss for dIoU functions + fix half tests. * [FIX] Precommits fix. * [ENH] Some improvement for the distance IoU tests thanks to code review. * [ENH] Upcast in distance boxes computation to avoid overflow. * [ENH] Revert the refactor of distance IoU loss back since it introduced a bug and can be slow. * Precommit fix. * [FIX] Few changes introduced by merge conflict. * Add code reference * Fix test Reviewed By: YosuaMichael Differential Revision: D36281596 fbshipit-source-id: 70e5102ec6fae9c9795d1895911f94f0a74e42f8 Co-authored-by: Vasilis Vryniotis <[email protected]>
1 parent caf895f commit efdd3fd

File tree

6 files changed

+231
-1
lines changed

6 files changed

+231
-1
lines changed

docs/source/ops.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ Operators
2626
drop_block3d
2727
generalized_box_iou
2828
generalized_box_iou_loss
29+
distance_box_iou
30+
distance_box_iou_loss
2931
masks_to_boxes
3032
nms
3133
ps_roi_align

test/test_ops.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1258,6 +1258,97 @@ def test_giou_jit(self) -> None:
12581258
self._run_jit_test([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]])
12591259

12601260

1261+
class TestDistanceBoxIoU(BoxTestBase):
1262+
def _target_fn(self):
1263+
return (True, ops.distance_box_iou)
1264+
1265+
def _generate_int_input():
1266+
return [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]
1267+
1268+
def _generate_int_expected():
1269+
return [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]]
1270+
1271+
def _generate_float_input():
1272+
return [
1273+
[285.3538, 185.5758, 1193.5110, 851.4551],
1274+
[285.1472, 188.7374, 1192.4984, 851.0669],
1275+
[279.2440, 197.9812, 1189.4746, 849.2019],
1276+
]
1277+
1278+
def _generate_float_expected():
1279+
return [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]
1280+
1281+
@pytest.mark.parametrize(
1282+
"test_input, dtypes, tolerance, expected",
1283+
[
1284+
pytest.param(
1285+
_generate_int_input(), [torch.int16, torch.int32, torch.int64], 1e-4, _generate_int_expected()
1286+
),
1287+
pytest.param(_generate_float_input(), [torch.float16], 0.002, _generate_float_expected()),
1288+
pytest.param(_generate_float_input(), [torch.float32, torch.float64], 0.001, _generate_float_expected()),
1289+
],
1290+
)
1291+
def test_distance_iou(self, test_input, dtypes, tolerance, expected):
1292+
self._run_test(test_input, dtypes, tolerance, expected)
1293+
1294+
def test_distance_iou_jit(self):
1295+
self._run_jit_test([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]])
1296+
1297+
1298+
@pytest.mark.parametrize("device", cpu_and_gpu())
1299+
@pytest.mark.parametrize("dtype", [torch.float32, torch.half])
1300+
def test_distance_iou_loss(dtype, device):
1301+
box1 = torch.tensor([-1, -1, 1, 1], dtype=dtype, device=device)
1302+
box2 = torch.tensor([0, 0, 1, 1], dtype=dtype, device=device)
1303+
box3 = torch.tensor([0, 1, 1, 2], dtype=dtype, device=device)
1304+
box4 = torch.tensor([1, 1, 2, 2], dtype=dtype, device=device)
1305+
1306+
box1s = torch.stack(
1307+
[box2, box2],
1308+
dim=0,
1309+
)
1310+
box2s = torch.stack(
1311+
[box3, box4],
1312+
dim=0,
1313+
)
1314+
1315+
def assert_distance_iou_loss(box1, box2, expected_output, reduction="none"):
1316+
output = ops.distance_box_iou_loss(box1, box2, reduction=reduction)
1317+
# TODO: When passing the dtype, the torch.half fails as usual.
1318+
expected_output = torch.tensor(expected_output, device=device)
1319+
tol = 1e-5 if dtype != torch.half else 1e-3
1320+
torch.testing.assert_close(output, expected_output, rtol=tol, atol=tol)
1321+
1322+
assert_distance_iou_loss(box1, box1, 0.0)
1323+
1324+
assert_distance_iou_loss(box1, box2, 0.8125)
1325+
1326+
assert_distance_iou_loss(box1, box3, 1.1923)
1327+
1328+
assert_distance_iou_loss(box1, box4, 1.2500)
1329+
1330+
assert_distance_iou_loss(box1s, box2s, 1.2250, reduction="mean")
1331+
assert_distance_iou_loss(box1s, box2s, 2.4500, reduction="sum")
1332+
1333+
1334+
@pytest.mark.parametrize("device", cpu_and_gpu())
1335+
@pytest.mark.parametrize("dtype", [torch.float32, torch.half])
1336+
def test_empty_distance_iou_inputs(dtype, device) -> None:
1337+
box1 = torch.randn([0, 4], dtype=dtype, device=device).requires_grad_()
1338+
box2 = torch.randn([0, 4], dtype=dtype, device=device).requires_grad_()
1339+
1340+
loss = ops.distance_box_iou_loss(box1, box2, reduction="mean")
1341+
loss.backward()
1342+
1343+
tol = 1e-3 if dtype is torch.half else 1e-5
1344+
torch.testing.assert_close(loss, torch.tensor(0.0, device=device), rtol=tol, atol=tol)
1345+
assert box1.grad is not None, "box1.grad should not be None after backward is called"
1346+
assert box2.grad is not None, "box2.grad should not be None after backward is called"
1347+
1348+
loss = ops.distance_box_iou_loss(box1, box2, reduction="none")
1349+
assert loss.numel() == 0, "diou_loss for two empty box should be empty"
1350+
1351+
12611352
class TestCompleteBoxIou(BoxTestBase):
12621353
def _target_fn(self) -> Tuple[bool, Callable]:
12631354
return (True, ops.complete_box_iou)
@@ -1676,6 +1767,7 @@ def test_ciou_loss(self, dtype, device):
16761767
def assert_ciou_loss(box1, box2, expected_output, reduction="none"):
16771768

16781769
output = ops.complete_box_iou_loss(box1, box2, reduction=reduction)
1770+
# TODO: When passing the dtype, the torch.half test doesn't pass...
16791771
expected_output = torch.tensor(expected_output, device=device)
16801772
tol = 1e-5 if dtype != torch.half else 1e-3
16811773
torch.testing.assert_close(output, expected_output, rtol=tol, atol=tol)

torchvision/ops/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,14 @@
77
box_area,
88
box_iou,
99
generalized_box_iou,
10+
distance_box_iou,
1011
complete_box_iou,
1112
masks_to_boxes,
1213
)
1314
from .boxes import box_convert
1415
from .ciou_loss import complete_box_iou_loss
1516
from .deform_conv import deform_conv2d, DeformConv2d
17+
from .diou_loss import distance_box_iou_loss
1618
from .drop_block import drop_block2d, DropBlock2d, drop_block3d, DropBlock3d
1719
from .feature_pyramid_network import FeaturePyramidNetwork
1820
from .focal_loss import sigmoid_focal_loss
@@ -40,6 +42,8 @@
4042
"box_area",
4143
"box_iou",
4244
"generalized_box_iou",
45+
"distance_box_iou",
46+
"complete_box_iou",
4347
"roi_align",
4448
"RoIAlign",
4549
"roi_pool",
@@ -58,6 +62,8 @@
5862
"Conv3dNormActivation",
5963
"SqueezeExcitation",
6064
"generalized_box_iou_loss",
65+
"distance_box_iou_loss",
66+
"complete_box_iou_loss",
6167
"drop_block2d",
6268
"DropBlock2d",
6369
"drop_block3d",

torchvision/ops/boxes.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,50 @@ def complete_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso
359359
return iou - (centers_distance_squared / diagonal_distance_squared) - alpha * v
360360

361361

362+
def distance_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tensor:
363+
"""
364+
Return distance intersection-over-union (Jaccard index) between two sets of boxes.
365+
366+
Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with
367+
``0 <= x1 < x2`` and ``0 <= y1 < y2``.
368+
369+
Args:
370+
boxes1 (Tensor[N, 4]): first set of boxes
371+
boxes2 (Tensor[M, 4]): second set of boxes
372+
eps (float, optional): small number to prevent division by zero. Default: 1e-7
373+
374+
Returns:
375+
Tensor[N, M]: the NxM matrix containing the pairwise distance IoU values
376+
for every element in boxes1 and boxes2
377+
"""
378+
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
379+
_log_api_usage_once(distance_box_iou)
380+
381+
boxes1 = _upcast(boxes1)
382+
boxes2 = _upcast(boxes2)
383+
384+
inter, union = _box_inter_union(boxes1, boxes2)
385+
iou = inter / union
386+
387+
lti = torch.min(boxes1[:, None, :2], boxes2[:, :2])
388+
rbi = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
389+
390+
whi = _upcast(rbi - lti).clamp(min=0) # [N,M,2]
391+
diagonal_distance_squared = (whi[:, :, 0] ** 2) + (whi[:, :, 1] ** 2) + eps
392+
393+
# centers of boxes
394+
x_p = (boxes1[:, 0] + boxes1[:, 2]) / 2
395+
y_p = (boxes1[:, 1] + boxes1[:, 3]) / 2
396+
x_g = (boxes2[:, 0] + boxes2[:, 2]) / 2
397+
y_g = (boxes2[:, 1] + boxes2[:, 3]) / 2
398+
# The distance between boxes' centers squared.
399+
centers_distance_squared = (_upcast(x_p - x_g) ** 2) + (_upcast(y_p - y_g) ** 2)
400+
401+
# The distance IoU is the IoU penalized by a normalized
402+
# distance between boxes' centers squared.
403+
return iou - (centers_distance_squared / diagonal_distance_squared)
404+
405+
362406
def masks_to_boxes(masks: torch.Tensor) -> torch.Tensor:
363407
"""
364408
Compute the bounding boxes around the provided masks.

torchvision/ops/diou_loss.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
import torch
2+
3+
from ..utils import _log_api_usage_once
4+
from .boxes import _upcast
5+
6+
7+
def distance_box_iou_loss(
8+
boxes1: torch.Tensor,
9+
boxes2: torch.Tensor,
10+
reduction: str = "none",
11+
eps: float = 1e-7,
12+
) -> torch.Tensor:
13+
"""
14+
Gradient-friendly IoU loss with an additional penalty that is non-zero when the
15+
distance between boxes' centers isn't zero. Indeed, for two exactly overlapping
16+
boxes, the distance IoU is the same as the IoU loss.
17+
This loss is symmetric, so the boxes1 and boxes2 arguments are interchangeable.
18+
19+
Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with
20+
``0 <= x1 < x2`` and ``0 <= y1 < y2``, and The two boxes should have the
21+
same dimensions.
22+
23+
Args:
24+
boxes1 (Tensor[N, 4]): first set of boxes
25+
boxes2 (Tensor[N, 4]): second set of boxes
26+
reduction (string, optional): Specifies the reduction to apply to the output:
27+
``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: No reduction will be
28+
applied to the output. ``'mean'``: The output will be averaged.
29+
``'sum'``: The output will be summed. Default: ``'none'``
30+
eps (float, optional): small number to prevent division by zero. Default: 1e-7
31+
32+
Returns:
33+
Tensor: Loss tensor with the reduction option applied.
34+
35+
Reference:
36+
Zhaohui Zheng et. al: Distance Intersection over Union Loss:
37+
https://arxiv.org/abs/1911.08287
38+
"""
39+
40+
# Original Implementation : https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/losses.py
41+
42+
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
43+
_log_api_usage_once(distance_box_iou_loss)
44+
45+
boxes1 = _upcast(boxes1)
46+
boxes2 = _upcast(boxes2)
47+
48+
x1, y1, x2, y2 = boxes1.unbind(dim=-1)
49+
x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1)
50+
51+
# Intersection keypoints
52+
xkis1 = torch.max(x1, x1g)
53+
ykis1 = torch.max(y1, y1g)
54+
xkis2 = torch.min(x2, x2g)
55+
ykis2 = torch.min(y2, y2g)
56+
57+
intsct = torch.zeros_like(x1)
58+
mask = (ykis2 > ykis1) & (xkis2 > xkis1)
59+
intsct[mask] = (xkis2[mask] - xkis1[mask]) * (ykis2[mask] - ykis1[mask])
60+
union = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g) - intsct + eps
61+
iou = intsct / union
62+
63+
# smallest enclosing box
64+
xc1 = torch.min(x1, x1g)
65+
yc1 = torch.min(y1, y1g)
66+
xc2 = torch.max(x2, x2g)
67+
yc2 = torch.max(y2, y2g)
68+
# The diagonal distance of the smallest enclosing box squared
69+
diagonal_distance_squared = ((xc2 - xc1) ** 2) + ((yc2 - yc1) ** 2) + eps
70+
71+
# centers of boxes
72+
x_p = (x2 + x1) / 2
73+
y_p = (y2 + y1) / 2
74+
x_g = (x1g + x2g) / 2
75+
y_g = (y1g + y2g) / 2
76+
# The distance between boxes' centers squared.
77+
centers_distance_squared = ((x_p - x_g) ** 2) + ((y_p - y_g) ** 2)
78+
79+
# The distance IoU is the IoU penalized by a normalized
80+
# distance between boxes' centers squared.
81+
loss = 1 - iou + (centers_distance_squared / diagonal_distance_squared)
82+
if reduction == "mean":
83+
loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum()
84+
elif reduction == "sum":
85+
loss = loss.sum()
86+
return loss

torchvision/ops/giou_loss.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def generalized_box_iou_loss(
3636
``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: No reduction will be
3737
applied to the output. ``'mean'``: The output will be averaged.
3838
``'sum'``: The output will be summed. Default: ``'none'``
39-
eps (float, optional): small number to prevent division by zero. Default: 1e-7
39+
eps (float): small number to prevent division by zero. Default: 1e-7
4040
4141
Reference:
4242
Hamid Rezatofighi et. al: Generalized Intersection over Union:

0 commit comments

Comments
 (0)