Skip to content

Distance IoU #5786

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 26 commits into from
May 9, 2022
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
135763c
[FEAT] Add distance IoU and distance IoU loss + some tests (WIP for t…
Apr 7, 2022
ec599d2
[FIX] Remove URL from docstring + remove assert since it causes a big…
Apr 7, 2022
41703e6
[FIX] eps isn't None.
Apr 7, 2022
51616ed
[TEST] Update existing box dIoU test + add dIoU loss tests (inspired …
Apr 13, 2022
ee37c8d
Merge branch 'main' into dIoU
Apr 13, 2022
7631ab7
[ENH] Some pre-commit fixes + remove print + mypy.
Apr 13, 2022
b744d6d
Merge branch 'dIoU' of github.com:yassineAlouini/vision-1 into dIoU
Apr 13, 2022
8ceffcc
[ENH] Pass the device in the assertion for the dIoU loss test.
Apr 13, 2022
bc65b83
Merge branch 'main' into dIoU
Apr 14, 2022
a4e58b7
[FIX] Remove type hints from the dIoU box test.
Apr 14, 2022
4ba5cdc
Merge branch 'dIoU' of github.com:yassineAlouini/vision-1 into dIoU
Apr 14, 2022
0ead2c3
[ENH] Refactor box and loss for dIoU functions + fix half tests.
Apr 21, 2022
a2702f8
[FIX] Precommits fix.
Apr 21, 2022
27894ef
Merge branch 'main' into dIoU
Apr 21, 2022
d4bd825
Merge branch 'main' of github.com:yassineAlouini/vision-1 into dIoU
Apr 26, 2022
497a7c1
[ENH] Some improvement for the distance IoU tests thanks to code review.
Apr 26, 2022
d8b7f35
Merge branch 'dIoU' of github.com:yassineAlouini/vision-1 into dIoU
Apr 26, 2022
a054032
[ENH] Upcast in distance boxes computation to avoid overflow.
Apr 26, 2022
d7baa67
[ENH] Revert the refactor of distance IoU loss back since it introduc…
Apr 26, 2022
4213ee4
Precommit fix.
Apr 26, 2022
1a2d6ab
Merge main and fix conflicts + make code iso with cIoU.
May 2, 2022
2856947
[FIX] Few changes introduced by merge conflict.
May 2, 2022
3a9d3d7
Add code reference
datumbox May 9, 2022
13fa495
Merge branch 'main' into dIoU
datumbox May 9, 2022
1b2f1e6
Fix test
datumbox May 9, 2022
ab44428
Merge branch 'main' into dIoU
datumbox May 9, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/ops.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ Operators
drop_block3d
generalized_box_iou
generalized_box_iou_loss
distance_box_iou
distance_box_iou_loss
masks_to_boxes
nms
ps_roi_align
Expand Down
80 changes: 80 additions & 0 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1127,6 +1127,8 @@ def _perform_box_operation(self, box: Tensor, run_as_script: bool = False) -> Te
def _run_test(self, test_input: List, dtypes: List[torch.dtype], tolerance: float, expected: List) -> None:
def assert_close(box: Tensor, expected: Tensor, tolerance):
out = self._perform_box_operation(box)
print("The computed box is: ", out)
print("The expected one is: ", expected)
torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=False, atol=tolerance)

for dtype in dtypes:
Expand Down Expand Up @@ -1257,6 +1259,84 @@ def test_gen_iou(self, test_input: List, dtypes: List[torch.dtype], tolerance: f
def test_giou_jit(self) -> None:
self._run_jit_test([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]])

class TestDistanceBoxIoU(BoxTestBase):
def _target_fn(self) -> Tuple[bool, Callable]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A general question to torchvision maintainers
Do we type hint in tests?
It's not bad to type hint. But it isn't something which we follow either 😅

cc @pmeier @datumbox

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, we don't. See #5563 (comment). The reason is that without also checking the tests with mypy they might go out of date and that is usually more harmful than not having them at all.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am removing the type hints then, thanks for the link @pmeier. 👍

return (True, ops.distance_box_iou)

def _generate_int_input() -> List[List[int]]:
return [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]

# TODO: Update this.
def _generate_int_expected() -> List[List[float]]:
return [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]]

def _generate_float_input() -> List[List[float]]:
return [
[285.3538, 185.5758, 1193.5110, 851.4551],
[285.1472, 188.7374, 1192.4984, 851.0669],
[279.2440, 197.9812, 1189.4746, 849.2019],
]

def _generate_float_expected() -> List[List[float]]:
return [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]

@pytest.mark.parametrize(
"test_input, dtypes, tolerance, expected",
[
pytest.param(
_generate_int_input(), [torch.int16, torch.int32, torch.int64], 1e-4, _generate_int_expected()
),
pytest.param(_generate_float_input(), [torch.float16], 0.002, _generate_float_expected()),
pytest.param(_generate_float_input(), [torch.float32, torch.float64], 0.001, _generate_float_expected()),
],
)
def test_distance_iou(self, test_input: List, dtypes: List[torch.dtype], tolerance: float, expected: List) -> None:
self._run_test(test_input, dtypes, tolerance, expected)

def test_distance_iou_jit(self) -> None:
self._run_jit_test([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]])

class TestDistanceIoULoss:
# Inspired and adapted from:
# https://github.com/pytorch/vision/pull/5776/files#diff-d183f2afc51d6a59bc70094e8f476d2468c45e415500f6eb60abad955e065156

@staticmethod
def assert_distance_iou_loss(box1, box2, expected_output, dtype, reduction="none"):
output = ops.distance_box_iou_loss(box1, box2, reduction=reduction)
expected_output = torch.tensor(expected_output, dtype=dtype)
tol = 1e-5 if dtype != torch.half else 1e-3
torch.testing.assert_close(output, expected_output, rtol=tol, atol=tol)

# TODO: torch.half as a dtype doesn't pass the test, investigate...
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("device", cpu_and_gpu())
def test_distance_iou_loss(self, dtype, device):
box1 = torch.tensor([-1, -1, 1, 1], dtype=dtype, device=device)
box2 = torch.tensor([0, 0, 1, 1], dtype=dtype, device=device)
box3 = torch.tensor([0, 1, 1, 2], dtype=dtype, device=device)
box4 = torch.tensor([1, 1, 2, 2], dtype=dtype, device=device)

box1s = torch.stack(
[box2, box2],
dim=0,
)
box2s = torch.stack(
[box3, box4],
dim=0,
)


self.assert_distance_iou_loss(box1, box1, 0.0, dtype)

self.assert_distance_iou_loss(box1, box2, 0.8125, dtype)

self.assert_distance_iou_loss(box1, box3, 1.1923, dtype)

self.assert_distance_iou_loss(box1, box4, 1.2500, dtype)

self.assert_distance_iou_loss(box1s, box2s, 1.2250, dtype, reduction="mean")
self.assert_distance_iou_loss(box1s, box2s, 2.4500, dtype, reduction="sum")


class TestMasksToBoxes:
def test_masks_box(self):
Expand Down
4 changes: 4 additions & 0 deletions torchvision/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
box_area,
box_iou,
generalized_box_iou,
distance_box_iou,
masks_to_boxes,
)
from .boxes import box_convert
Expand All @@ -15,6 +16,7 @@
from .feature_pyramid_network import FeaturePyramidNetwork
from .focal_loss import sigmoid_focal_loss
from .giou_loss import generalized_box_iou_loss
from .diou_loss import distance_box_iou_loss
from .misc import FrozenBatchNorm2d, Conv2dNormActivation, Conv3dNormActivation, SqueezeExcitation
from .poolers import MultiScaleRoIAlign
from .ps_roi_align import ps_roi_align, PSRoIAlign
Expand All @@ -38,6 +40,7 @@
"box_area",
"box_iou",
"generalized_box_iou",
"distance_box_iou",
"roi_align",
"RoIAlign",
"roi_pool",
Expand All @@ -56,6 +59,7 @@
"Conv3dNormActivation",
"SqueezeExcitation",
"generalized_box_iou_loss",
"distance_box_iou_loss",
"drop_block2d",
"DropBlock2d",
"drop_block3d",
Expand Down
44 changes: 44 additions & 0 deletions torchvision/ops/boxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,50 @@ def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:

return iou - (areai - union) / areai

# Implementation inspired from the generalized_box_iou one.
# TODO: Some refactoring and homogenization could be done with
# the loss function in diou_loss.
def distance_box_iou(boxes1: Tensor, boxes2: Tensor, eps:float= 1e-7) -> Tensor:
"""
Return distance intersection-over-union (Jaccard index) between two sets of boxes.

Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with
``0 <= x1 < x2`` and ``0 <= y1 < y2``.

Args:
boxes1 (Tensor[N, 4]): first set of boxes
boxes2 (Tensor[M, 4]): second set of boxes
eps (float, optional): small number to prevent division by zero. Default: 1e-7

Returns:
Tensor[N, M]: the NxM matrix containing the pairwise distance IoU values
for every element in boxes1 and boxes2
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(distance_box_iou)

inter, union = _box_inter_union(boxes1, boxes2)
iou = inter / union

lti = torch.min(boxes1[:, None, :2], boxes2[:, :2])
rbi = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])

whi = _upcast(rbi - lti).clamp(min=0) # [N,M,2]
diagonal_distance_squared = (whi[:, :, 0] ** 2) + (whi[:, :, 1] ** 2) + eps


# centers of boxes
x_p = boxes1[:, None, :2].sum() / 2
y_p = boxes1[:, None, 2:].sum() / 2
x_g = boxes2[:, :2].sum() / 2
y_g = boxes2[:, 2:].sum() / 2
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey @yassineAlouini , I think there is a problem with this implementation. The calculation of centre of boxes is not correct acc to me. We should be adding up only x1 x2 and y1 y2, ref .
But in current implementation, we are adding x1,y1 and x2, y2. (BBox shape is in form [x1,y1,x2,y2]) .

This can also be checked by calculating distance_box_iou_loss and distance_box_iou on a sample tensors.

import torch
from torchvision.ops import distance_box_iou, distance_box_iou_loss

box1 = torch.tensor([[-1, -1, 1, 1]], )
box2 = torch.tensor([[0, 0, 1, 1]],)

1-distance_box_iou(box1, box2)[0] == distance_box_iou_loss(box1, box2)

Last statement returns False. Ideally it should return True.

I suggest you to do following changes.

Suggested change
# centers of boxes
x_p = boxes1[:, None, :2].sum() / 2
y_p = boxes1[:, None, 2:].sum() / 2
x_g = boxes2[:, :2].sum() / 2
y_g = boxes2[:, 2:].sum() / 2
# centers of boxes
x_p = (boxes1[:, 0] + boxes1[:, 2]) / 2
y_p = (boxes1[:, 1] + boxes1[:, 3]) / 2
x_g = (boxes2[:, 0] + boxes2[:, 2]) / 2
y_g = (boxes2[:, 1] + boxes2[:, 3]) / 2

@datumbox,
Please correct me if I'm wrong.
Thanks.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch. I haven't yet reviewed the correctness of the implementation as we still discuss the structure/API.

I think that's probably a typo and @yassineAlouini intended to write something like:

x_p = boxes1[:, 0::2].sum() / 2
y_p = boxes1[:, 1::2].sum() / 2
...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes indeed, I think I went too quickly over this and thought that the bounding box was in the x1x2y1y2 format. Thanks for pointing this out and your suggestions. 👍

# The distance between boxes' centers squared.
centers_distance_squared = (_upcast(x_p - x_g) ** 2) + (_upcast(y_p - y_g) ** 2)

# The distance IoU is the IoU penalized by a normalized
# distance between boxes' centers squared.
return iou - (centers_distance_squared / diagonal_distance_squared)


def masks_to_boxes(masks: torch.Tensor) -> torch.Tensor:
"""
Expand Down
84 changes: 84 additions & 0 deletions torchvision/ops/diou_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@

import torch
import math
from .boxes import _upcast
from ..utils import _log_api_usage_once

# TODO: Some parts can be refactored between gIoU, cIoU, and dIoU.
def distance_box_iou_loss(
boxes1: torch.Tensor,
boxes2: torch.Tensor,
reduction: str = "none",
eps: float = 1e-7,
) -> torch.Tensor:
"""
Gradient-friendly IoU loss with an additional penalty that is non-zero when the
distance between boxes' centers isn't zero. Indeed, for two exactly overlapping
boxes, the distance IoU is the same as the IoU loss.
This loss is symmetric, so the boxes1 and boxes2 arguments are interchangeable.

Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with
``0 <= x1 < x2`` and ``0 <= y1 < y2``, and The two boxes should have the
same dimensions.

Args:
boxes1 (Tensor[N, 4] or Tensor[4]): first set of boxes
boxes2 (Tensor[N, 4] or Tensor[4]): second set of boxes
reduction (string, optional): Specifies the reduction to apply to the output:
``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: No reduction will be
applied to the output. ``'mean'``: The output will be averaged.
``'sum'``: The output will be summed. Default: ``'none'``
eps (float, optional): small number to prevent division by zero. Default: 1e-7
Reference:
Zhaohui Zheng et. al: Distance Intersection over Union Loss:
https://arxiv.org/abs/1911.08287

"""
# Original implementation from:
# https://github.com/facebookresearch/detectron2/blob/dfe8d368c8b7cc2be42c5c3faf9bdcc3c08257b1/detectron2/layers/losses.py#L5
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(distance_box_iou_loss)

boxes1 = _upcast(boxes1)
boxes2 = _upcast(boxes2)
x1, y1, x2, y2 = boxes1.unbind(dim=-1)
x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1)

# Intersection keypoints
xkis1 = torch.max(x1, x1g)
ykis1 = torch.max(y1, y1g)
xkis2 = torch.min(x2, x2g)
ykis2 = torch.min(y2, y2g)

intsct = torch.zeros_like(x1)
mask = (ykis2 > ykis1) & (xkis2 > xkis1)
intsct[mask] = (xkis2[mask] - xkis1[mask]) * (ykis2[mask] - ykis1[mask])
union = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g) - intsct + eps
iou = intsct / union

# smallest enclosing box
xc1 = torch.min(x1, x1g)
yc1 = torch.min(y1, y1g)
xc2 = torch.max(x2, x2g)
yc2 = torch.max(y2, y2g)
# The diagonal distance of the smallest enclosing box squared
diagonal_distance_squared = ((xc2 - xc1) ** 2) + ((yc2 - yc1) ** 2) + eps

# centers of boxes
x_p = (x2 + x1) / 2
y_p = (y2 + y1) / 2
x_g = (x1g + x2g) / 2
y_g = (y1g + y2g) / 2
# The distance between boxes' centers squared.
centers_distance_squared = ((x_p - x_g) ** 2) + ((y_p - y_g) ** 2)

# The distance IoU is the IoU penalized by a normalized
# distance between boxes' centers squared.
diou = iou - (centers_distance_squared / diagonal_distance_squared)
loss = 1 - diou
if reduction == "mean":
loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum()
elif reduction == "sum":
loss = loss.sum()

return loss
2 changes: 1 addition & 1 deletion torchvision/ops/giou_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def generalized_box_iou_loss(
``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: No reduction will be
applied to the output. ``'mean'``: The output will be averaged.
``'sum'``: The output will be summed. Default: ``'none'``
eps (float, optional): small number to prevent division by zero. Default: 1e-7
eps (float): small number to prevent division by zero. Default: 1e-7

Reference:
Hamid Rezatofighi et. al: Generalized Intersection over Union:
Expand Down