-
Notifications
You must be signed in to change notification settings - Fork 7.1k
Distance IoU #5786
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Distance IoU #5786
Changes from 5 commits
135763c
ec599d2
41703e6
51616ed
ee37c8d
7631ab7
b744d6d
8ceffcc
bc65b83
a4e58b7
4ba5cdc
0ead2c3
a2702f8
27894ef
d4bd825
497a7c1
d8b7f35
a054032
d7baa67
4213ee4
1a2d6ab
2856947
3a9d3d7
13fa495
1b2f1e6
ab44428
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1127,6 +1127,8 @@ def _perform_box_operation(self, box: Tensor, run_as_script: bool = False) -> Te | |
def _run_test(self, test_input: List, dtypes: List[torch.dtype], tolerance: float, expected: List) -> None: | ||
def assert_close(box: Tensor, expected: Tensor, tolerance): | ||
out = self._perform_box_operation(box) | ||
print("The computed box is: ", out) | ||
print("The expected one is: ", expected) | ||
torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=False, atol=tolerance) | ||
|
||
for dtype in dtypes: | ||
|
@@ -1257,6 +1259,84 @@ def test_gen_iou(self, test_input: List, dtypes: List[torch.dtype], tolerance: f | |
def test_giou_jit(self) -> None: | ||
self._run_jit_test([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]) | ||
|
||
class TestDistanceBoxIoU(BoxTestBase): | ||
def _target_fn(self) -> Tuple[bool, Callable]: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No, we don't. See #5563 (comment). The reason is that without also checking the tests with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am removing the type hints then, thanks for the link @pmeier. 👍 |
||
return (True, ops.distance_box_iou) | ||
|
||
def _generate_int_input() -> List[List[int]]: | ||
return [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]] | ||
|
||
# TODO: Update this. | ||
def _generate_int_expected() -> List[List[float]]: | ||
return [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]] | ||
|
||
def _generate_float_input() -> List[List[float]]: | ||
return [ | ||
[285.3538, 185.5758, 1193.5110, 851.4551], | ||
[285.1472, 188.7374, 1192.4984, 851.0669], | ||
[279.2440, 197.9812, 1189.4746, 849.2019], | ||
] | ||
|
||
def _generate_float_expected() -> List[List[float]]: | ||
return [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]] | ||
|
||
@pytest.mark.parametrize( | ||
"test_input, dtypes, tolerance, expected", | ||
[ | ||
pytest.param( | ||
_generate_int_input(), [torch.int16, torch.int32, torch.int64], 1e-4, _generate_int_expected() | ||
), | ||
pytest.param(_generate_float_input(), [torch.float16], 0.002, _generate_float_expected()), | ||
pytest.param(_generate_float_input(), [torch.float32, torch.float64], 0.001, _generate_float_expected()), | ||
], | ||
) | ||
def test_distance_iou(self, test_input: List, dtypes: List[torch.dtype], tolerance: float, expected: List) -> None: | ||
self._run_test(test_input, dtypes, tolerance, expected) | ||
|
||
def test_distance_iou_jit(self) -> None: | ||
self._run_jit_test([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]) | ||
|
||
class TestDistanceIoULoss: | ||
yassineAlouini marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# Inspired and adapted from: | ||
# https://github.com/pytorch/vision/pull/5776/files#diff-d183f2afc51d6a59bc70094e8f476d2468c45e415500f6eb60abad955e065156 | ||
yassineAlouini marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
@staticmethod | ||
def assert_distance_iou_loss(box1, box2, expected_output, dtype, reduction="none"): | ||
output = ops.distance_box_iou_loss(box1, box2, reduction=reduction) | ||
expected_output = torch.tensor(expected_output, dtype=dtype) | ||
tol = 1e-5 if dtype != torch.half else 1e-3 | ||
torch.testing.assert_close(output, expected_output, rtol=tol, atol=tol) | ||
yassineAlouini marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
# TODO: torch.half as a dtype doesn't pass the test, investigate... | ||
@pytest.mark.parametrize("dtype", [torch.float32]) | ||
@pytest.mark.parametrize("device", cpu_and_gpu()) | ||
def test_distance_iou_loss(self, dtype, device): | ||
box1 = torch.tensor([-1, -1, 1, 1], dtype=dtype, device=device) | ||
box2 = torch.tensor([0, 0, 1, 1], dtype=dtype, device=device) | ||
box3 = torch.tensor([0, 1, 1, 2], dtype=dtype, device=device) | ||
box4 = torch.tensor([1, 1, 2, 2], dtype=dtype, device=device) | ||
|
||
box1s = torch.stack( | ||
[box2, box2], | ||
dim=0, | ||
) | ||
box2s = torch.stack( | ||
[box3, box4], | ||
dim=0, | ||
) | ||
|
||
|
||
self.assert_distance_iou_loss(box1, box1, 0.0, dtype) | ||
|
||
self.assert_distance_iou_loss(box1, box2, 0.8125, dtype) | ||
|
||
self.assert_distance_iou_loss(box1, box3, 1.1923, dtype) | ||
|
||
self.assert_distance_iou_loss(box1, box4, 1.2500, dtype) | ||
|
||
self.assert_distance_iou_loss(box1s, box2s, 1.2250, dtype, reduction="mean") | ||
self.assert_distance_iou_loss(box1s, box2s, 2.4500, dtype, reduction="sum") | ||
|
||
|
||
class TestMasksToBoxes: | ||
def test_masks_box(self): | ||
|
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -310,6 +310,50 @@ def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor: | |||||||||||||||||||||
|
||||||||||||||||||||||
return iou - (areai - union) / areai | ||||||||||||||||||||||
|
||||||||||||||||||||||
# Implementation inspired from the generalized_box_iou one. | ||||||||||||||||||||||
# TODO: Some refactoring and homogenization could be done with | ||||||||||||||||||||||
# the loss function in diou_loss. | ||||||||||||||||||||||
def distance_box_iou(boxes1: Tensor, boxes2: Tensor, eps:float= 1e-7) -> Tensor: | ||||||||||||||||||||||
""" | ||||||||||||||||||||||
Return distance intersection-over-union (Jaccard index) between two sets of boxes. | ||||||||||||||||||||||
|
||||||||||||||||||||||
Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with | ||||||||||||||||||||||
``0 <= x1 < x2`` and ``0 <= y1 < y2``. | ||||||||||||||||||||||
|
||||||||||||||||||||||
Args: | ||||||||||||||||||||||
boxes1 (Tensor[N, 4]): first set of boxes | ||||||||||||||||||||||
boxes2 (Tensor[M, 4]): second set of boxes | ||||||||||||||||||||||
eps (float, optional): small number to prevent division by zero. Default: 1e-7 | ||||||||||||||||||||||
|
||||||||||||||||||||||
Returns: | ||||||||||||||||||||||
Tensor[N, M]: the NxM matrix containing the pairwise distance IoU values | ||||||||||||||||||||||
for every element in boxes1 and boxes2 | ||||||||||||||||||||||
""" | ||||||||||||||||||||||
if not torch.jit.is_scripting() and not torch.jit.is_tracing(): | ||||||||||||||||||||||
_log_api_usage_once(distance_box_iou) | ||||||||||||||||||||||
|
||||||||||||||||||||||
inter, union = _box_inter_union(boxes1, boxes2) | ||||||||||||||||||||||
iou = inter / union | ||||||||||||||||||||||
|
||||||||||||||||||||||
lti = torch.min(boxes1[:, None, :2], boxes2[:, :2]) | ||||||||||||||||||||||
rbi = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) | ||||||||||||||||||||||
|
||||||||||||||||||||||
whi = _upcast(rbi - lti).clamp(min=0) # [N,M,2] | ||||||||||||||||||||||
diagonal_distance_squared = (whi[:, :, 0] ** 2) + (whi[:, :, 1] ** 2) + eps | ||||||||||||||||||||||
|
||||||||||||||||||||||
|
||||||||||||||||||||||
# centers of boxes | ||||||||||||||||||||||
x_p = boxes1[:, None, :2].sum() / 2 | ||||||||||||||||||||||
yassineAlouini marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||
y_p = boxes1[:, None, 2:].sum() / 2 | ||||||||||||||||||||||
x_g = boxes2[:, :2].sum() / 2 | ||||||||||||||||||||||
y_g = boxes2[:, 2:].sum() / 2 | ||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hey @yassineAlouini , I think there is a problem with this implementation. The calculation of centre of boxes is not correct acc to me. We should be adding up only This can also be checked by calculating
Last statement returns I suggest you to do following changes.
Suggested change
@datumbox, There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good catch. I haven't yet reviewed the correctness of the implementation as we still discuss the structure/API. I think that's probably a typo and @yassineAlouini intended to write something like:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes indeed, I think I went too quickly over this and thought that the bounding box was in the |
||||||||||||||||||||||
# The distance between boxes' centers squared. | ||||||||||||||||||||||
centers_distance_squared = (_upcast(x_p - x_g) ** 2) + (_upcast(y_p - y_g) ** 2) | ||||||||||||||||||||||
|
||||||||||||||||||||||
# The distance IoU is the IoU penalized by a normalized | ||||||||||||||||||||||
# distance between boxes' centers squared. | ||||||||||||||||||||||
return iou - (centers_distance_squared / diagonal_distance_squared) | ||||||||||||||||||||||
|
||||||||||||||||||||||
|
||||||||||||||||||||||
def masks_to_boxes(masks: torch.Tensor) -> torch.Tensor: | ||||||||||||||||||||||
""" | ||||||||||||||||||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
|
||
import torch | ||
import math | ||
from .boxes import _upcast | ||
from ..utils import _log_api_usage_once | ||
|
||
# TODO: Some parts can be refactored between gIoU, cIoU, and dIoU. | ||
def distance_box_iou_loss( | ||
boxes1: torch.Tensor, | ||
boxes2: torch.Tensor, | ||
reduction: str = "none", | ||
eps: float = 1e-7, | ||
) -> torch.Tensor: | ||
""" | ||
Gradient-friendly IoU loss with an additional penalty that is non-zero when the | ||
distance between boxes' centers isn't zero. Indeed, for two exactly overlapping | ||
boxes, the distance IoU is the same as the IoU loss. | ||
This loss is symmetric, so the boxes1 and boxes2 arguments are interchangeable. | ||
|
||
Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with | ||
``0 <= x1 < x2`` and ``0 <= y1 < y2``, and The two boxes should have the | ||
same dimensions. | ||
|
||
Args: | ||
boxes1 (Tensor[N, 4] or Tensor[4]): first set of boxes | ||
boxes2 (Tensor[N, 4] or Tensor[4]): second set of boxes | ||
reduction (string, optional): Specifies the reduction to apply to the output: | ||
``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: No reduction will be | ||
applied to the output. ``'mean'``: The output will be averaged. | ||
``'sum'``: The output will be summed. Default: ``'none'`` | ||
eps (float, optional): small number to prevent division by zero. Default: 1e-7 | ||
Reference: | ||
Zhaohui Zheng et. al: Distance Intersection over Union Loss: | ||
https://arxiv.org/abs/1911.08287 | ||
|
||
""" | ||
# Original implementation from: | ||
# https://github.com/facebookresearch/detectron2/blob/dfe8d368c8b7cc2be42c5c3faf9bdcc3c08257b1/detectron2/layers/losses.py#L5 | ||
if not torch.jit.is_scripting() and not torch.jit.is_tracing(): | ||
_log_api_usage_once(distance_box_iou_loss) | ||
|
||
boxes1 = _upcast(boxes1) | ||
boxes2 = _upcast(boxes2) | ||
x1, y1, x2, y2 = boxes1.unbind(dim=-1) | ||
x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1) | ||
|
||
# Intersection keypoints | ||
xkis1 = torch.max(x1, x1g) | ||
ykis1 = torch.max(y1, y1g) | ||
xkis2 = torch.min(x2, x2g) | ||
ykis2 = torch.min(y2, y2g) | ||
|
||
intsct = torch.zeros_like(x1) | ||
mask = (ykis2 > ykis1) & (xkis2 > xkis1) | ||
intsct[mask] = (xkis2[mask] - xkis1[mask]) * (ykis2[mask] - ykis1[mask]) | ||
union = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g) - intsct + eps | ||
iou = intsct / union | ||
|
||
# smallest enclosing box | ||
xc1 = torch.min(x1, x1g) | ||
yc1 = torch.min(y1, y1g) | ||
xc2 = torch.max(x2, x2g) | ||
yc2 = torch.max(y2, y2g) | ||
# The diagonal distance of the smallest enclosing box squared | ||
diagonal_distance_squared = ((xc2 - xc1) ** 2) + ((yc2 - yc1) ** 2) + eps | ||
|
||
# centers of boxes | ||
x_p = (x2 + x1) / 2 | ||
y_p = (y2 + y1) / 2 | ||
x_g = (x1g + x2g) / 2 | ||
y_g = (y1g + y2g) / 2 | ||
# The distance between boxes' centers squared. | ||
centers_distance_squared = ((x_p - x_g) ** 2) + ((y_p - y_g) ** 2) | ||
|
||
# The distance IoU is the IoU penalized by a normalized | ||
# distance between boxes' centers squared. | ||
diou = iou - (centers_distance_squared / diagonal_distance_squared) | ||
loss = 1 - diou | ||
if reduction == "mean": | ||
loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum() | ||
elif reduction == "sum": | ||
loss = loss.sum() | ||
|
||
return loss |
Uh oh!
There was an error while loading. Please reload this page.