Skip to content

Commit e2e1db6

Browse files
datumboxabhi-glitchhg
authored andcommitted
[fbsync] Fix d/c IoU for different batch sizes (#6338)
Summary: * Fix bug in calculating cIoU for unequal sizes * Remove comment * what the epsilon? * Fixing DIoU * Optimization by Francisco. * Fix the expected values on CompleteBoxIoU * Apply suggestions from code review * Adding cartesian product test. * remove static Reviewed By: NicolasHug Differential Revision: D38351751 fbshipit-source-id: 097e5f7048c650767e275fbb2c30ed0c800b1314 Co-authored-by: Abhijit Deo <[email protected]> Co-authored-by: Vasilis Vryniotis <[email protected]> Co-authored-by: Vasilis Vryniotis <[email protected]> Co-authored-by: Abhijit Deo <[email protected]>
1 parent d893469 commit e2e1db6

File tree

3 files changed

+99
-49
lines changed

3 files changed

+99
-49
lines changed

test/test_ops.py

Lines changed: 89 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1111,14 +1111,6 @@ def test_bbox_convert_jit(self):
11111111
torch.testing.assert_close(scripted_cxcywh, box_cxcywh)
11121112

11131113

1114-
INT_BOXES = [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]
1115-
FLOAT_BOXES = [
1116-
[285.3538, 185.5758, 1193.5110, 851.4551],
1117-
[285.1472, 188.7374, 1192.4984, 851.0669],
1118-
[279.2440, 197.9812, 1189.4746, 849.2019],
1119-
]
1120-
1121-
11221114
class TestBoxArea:
11231115
def area_check(self, box, expected, atol=1e-4):
11241116
out = ops.box_area(box)
@@ -1152,99 +1144,155 @@ def test_box_area_jit(self):
11521144
torch.testing.assert_close(scripted_area, expected)
11531145

11541146

1147+
INT_BOXES = [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300], [0, 0, 25, 25]]
1148+
INT_BOXES2 = [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]
1149+
FLOAT_BOXES = [
1150+
[285.3538, 185.5758, 1193.5110, 851.4551],
1151+
[285.1472, 188.7374, 1192.4984, 851.0669],
1152+
[279.2440, 197.9812, 1189.4746, 849.2019],
1153+
]
1154+
1155+
1156+
def gen_box(size, dtype=torch.float):
1157+
xy1 = torch.rand((size, 2), dtype=dtype)
1158+
xy2 = xy1 + torch.rand((size, 2), dtype=dtype)
1159+
return torch.cat([xy1, xy2], axis=-1)
1160+
1161+
11551162
class TestIouBase:
11561163
@staticmethod
1157-
def _run_test(target_fn: Callable, test_input: List, dtypes: List[torch.dtype], atol: float, expected: List):
1164+
def _run_test(target_fn: Callable, actual_box1, actual_box2, dtypes, atol, expected):
11581165
for dtype in dtypes:
1159-
actual_box = torch.tensor(test_input, dtype=dtype)
1166+
actual_box1 = torch.tensor(actual_box1, dtype=dtype)
1167+
actual_box2 = torch.tensor(actual_box2, dtype=dtype)
11601168
expected_box = torch.tensor(expected)
1161-
out = target_fn(actual_box, actual_box)
1169+
out = target_fn(actual_box1, actual_box2)
11621170
torch.testing.assert_close(out, expected_box, rtol=0.0, check_dtype=False, atol=atol)
11631171

11641172
@staticmethod
1165-
def _run_jit_test(target_fn: Callable, test_input: List):
1166-
box_tensor = torch.tensor(test_input, dtype=torch.float)
1173+
def _run_jit_test(target_fn: Callable, actual_box: List):
1174+
box_tensor = torch.tensor(actual_box, dtype=torch.float)
11671175
expected = target_fn(box_tensor, box_tensor)
11681176
scripted_fn = torch.jit.script(target_fn)
11691177
scripted_out = scripted_fn(box_tensor, box_tensor)
11701178
torch.testing.assert_close(scripted_out, expected)
11711179

1180+
@staticmethod
1181+
def _cartesian_product(boxes1, boxes2, target_fn: Callable):
1182+
N = boxes1.size(0)
1183+
M = boxes2.size(0)
1184+
result = torch.zeros((N, M))
1185+
for i in range(N):
1186+
for j in range(M):
1187+
result[i, j] = target_fn(boxes1[i].unsqueeze(0), boxes2[j].unsqueeze(0))
1188+
return result
1189+
1190+
@staticmethod
1191+
def _run_cartesian_test(target_fn: Callable):
1192+
boxes1 = gen_box(5)
1193+
boxes2 = gen_box(7)
1194+
a = TestIouBase._cartesian_product(boxes1, boxes2, target_fn)
1195+
b = target_fn(boxes1, boxes2)
1196+
assert torch.allclose(a, b)
1197+
11721198

11731199
class TestBoxIou(TestIouBase):
1174-
int_expected = [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]]
1200+
int_expected = [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0], [0.0625, 0.25, 0.0]]
11751201
float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]
11761202

11771203
@pytest.mark.parametrize(
1178-
"test_input, dtypes, atol, expected",
1204+
"actual_box1, actual_box2, dtypes, atol, expected",
11791205
[
1180-
pytest.param(INT_BOXES, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected),
1181-
pytest.param(FLOAT_BOXES, [torch.float16], 0.002, float_expected),
1182-
pytest.param(FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected),
1206+
pytest.param(INT_BOXES, INT_BOXES2, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected),
1207+
pytest.param(FLOAT_BOXES, FLOAT_BOXES, [torch.float16], 0.002, float_expected),
1208+
pytest.param(FLOAT_BOXES, FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected),
11831209
],
11841210
)
1185-
def test_iou(self, test_input, dtypes, atol, expected):
1186-
self._run_test(ops.box_iou, test_input, dtypes, atol, expected)
1211+
def test_iou(self, actual_box1, actual_box2, dtypes, atol, expected):
1212+
self._run_test(ops.box_iou, actual_box1, actual_box2, dtypes, atol, expected)
11871213

11881214
def test_iou_jit(self):
11891215
self._run_jit_test(ops.box_iou, INT_BOXES)
11901216

1217+
def test_iou_cartesian(self):
1218+
self._run_cartesian_test(ops.box_iou)
1219+
11911220

11921221
class TestGeneralizedBoxIou(TestIouBase):
1193-
int_expected = [[1.0, 0.25, -0.7778], [0.25, 1.0, -0.8611], [-0.7778, -0.8611, 1.0]]
1222+
int_expected = [[1.0, 0.25, -0.7778], [0.25, 1.0, -0.8611], [-0.7778, -0.8611, 1.0], [0.0625, 0.25, -0.8819]]
11941223
float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]
11951224

11961225
@pytest.mark.parametrize(
1197-
"test_input, dtypes, atol, expected",
1226+
"actual_box1, actual_box2, dtypes, atol, expected",
11981227
[
1199-
pytest.param(INT_BOXES, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected),
1200-
pytest.param(FLOAT_BOXES, [torch.float16], 0.002, float_expected),
1201-
pytest.param(FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected),
1228+
pytest.param(INT_BOXES, INT_BOXES2, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected),
1229+
pytest.param(FLOAT_BOXES, FLOAT_BOXES, [torch.float16], 0.002, float_expected),
1230+
pytest.param(FLOAT_BOXES, FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected),
12021231
],
12031232
)
1204-
def test_iou(self, test_input, dtypes, atol, expected):
1205-
self._run_test(ops.generalized_box_iou, test_input, dtypes, atol, expected)
1233+
def test_iou(self, actual_box1, actual_box2, dtypes, atol, expected):
1234+
self._run_test(ops.generalized_box_iou, actual_box1, actual_box2, dtypes, atol, expected)
12061235

12071236
def test_iou_jit(self):
12081237
self._run_jit_test(ops.generalized_box_iou, INT_BOXES)
12091238

1239+
def test_iou_cartesian(self):
1240+
self._run_cartesian_test(ops.generalized_box_iou)
1241+
12101242

12111243
class TestDistanceBoxIoU(TestIouBase):
1212-
int_expected = [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]]
1244+
int_expected = [
1245+
[1.0000, 0.1875, -0.4444],
1246+
[0.1875, 1.0000, -0.5625],
1247+
[-0.4444, -0.5625, 1.0000],
1248+
[-0.0781, 0.1875, -0.6267],
1249+
]
12131250
float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]
12141251

12151252
@pytest.mark.parametrize(
1216-
"test_input, dtypes, atol, expected",
1253+
"actual_box1, actual_box2, dtypes, atol, expected",
12171254
[
1218-
pytest.param(INT_BOXES, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected),
1219-
pytest.param(FLOAT_BOXES, [torch.float16], 0.002, float_expected),
1220-
pytest.param(FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected),
1255+
pytest.param(INT_BOXES, INT_BOXES2, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected),
1256+
pytest.param(FLOAT_BOXES, FLOAT_BOXES, [torch.float16], 0.002, float_expected),
1257+
pytest.param(FLOAT_BOXES, FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected),
12211258
],
12221259
)
1223-
def test_iou(self, test_input, dtypes, atol, expected):
1224-
self._run_test(ops.distance_box_iou, test_input, dtypes, atol, expected)
1260+
def test_iou(self, actual_box1, actual_box2, dtypes, atol, expected):
1261+
self._run_test(ops.distance_box_iou, actual_box1, actual_box2, dtypes, atol, expected)
12251262

12261263
def test_iou_jit(self):
12271264
self._run_jit_test(ops.distance_box_iou, INT_BOXES)
12281265

1266+
def test_iou_cartesian(self):
1267+
self._run_cartesian_test(ops.distance_box_iou)
1268+
12291269

12301270
class TestCompleteBoxIou(TestIouBase):
1231-
int_expected = [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]]
1271+
int_expected = [
1272+
[1.0000, 0.1875, -0.4444],
1273+
[0.1875, 1.0000, -0.5625],
1274+
[-0.4444, -0.5625, 1.0000],
1275+
[-0.0781, 0.1875, -0.6267],
1276+
]
12321277
float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]
12331278

12341279
@pytest.mark.parametrize(
1235-
"test_input, dtypes, atol, expected",
1280+
"actual_box1, actual_box2, dtypes, atol, expected",
12361281
[
1237-
pytest.param(INT_BOXES, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected),
1238-
pytest.param(FLOAT_BOXES, [torch.float16], 0.002, float_expected),
1239-
pytest.param(FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected),
1282+
pytest.param(INT_BOXES, INT_BOXES2, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected),
1283+
pytest.param(FLOAT_BOXES, FLOAT_BOXES, [torch.float16], 0.002, float_expected),
1284+
pytest.param(FLOAT_BOXES, FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected),
12401285
],
12411286
)
1242-
def test_iou(self, test_input, dtypes, atol, expected):
1243-
self._run_test(ops.complete_box_iou, test_input, dtypes, atol, expected)
1287+
def test_iou(self, actual_box1, actual_box2, dtypes, atol, expected):
1288+
self._run_test(ops.complete_box_iou, actual_box1, actual_box2, dtypes, atol, expected)
12441289

12451290
def test_iou_jit(self):
12461291
self._run_jit_test(ops.complete_box_iou, INT_BOXES)
12471292

1293+
def test_iou_cartesian(self):
1294+
self._run_cartesian_test(ops.complete_box_iou)
1295+
12481296

12491297
def get_boxes(dtype, device):
12501298
box1 = torch.tensor([-1, -1, 1, 1], dtype=dtype, device=device)

torchvision/ops/boxes.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -325,13 +325,13 @@ def complete_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso
325325

326326
diou, iou = _box_diou_iou(boxes1, boxes2, eps)
327327

328-
w_pred = boxes1[:, 2] - boxes1[:, 0]
329-
h_pred = boxes1[:, 3] - boxes1[:, 1]
328+
w_pred = boxes1[:, None, 2] - boxes1[:, None, 0]
329+
h_pred = boxes1[:, None, 3] - boxes1[:, None, 1]
330330

331331
w_gt = boxes2[:, 2] - boxes2[:, 0]
332332
h_gt = boxes2[:, 3] - boxes2[:, 1]
333333

334-
v = (4 / (torch.pi**2)) * torch.pow((torch.atan(w_gt / h_gt) - torch.atan(w_pred / h_pred)), 2)
334+
v = (4 / (torch.pi**2)) * torch.pow(torch.atan(w_pred / h_pred) - torch.atan(w_gt / h_gt), 2)
335335
with torch.no_grad():
336336
alpha = v / (1 - iou + v + eps)
337337
return diou - alpha * v
@@ -358,7 +358,7 @@ def distance_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso
358358

359359
boxes1 = _upcast(boxes1)
360360
boxes2 = _upcast(boxes2)
361-
diou, _ = _box_diou_iou(boxes1, boxes2)
361+
diou, _ = _box_diou_iou(boxes1, boxes2, eps=eps)
362362
return diou
363363

364364

@@ -375,7 +375,9 @@ def _box_diou_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tuple[Te
375375
x_g = (boxes2[:, 0] + boxes2[:, 2]) / 2
376376
y_g = (boxes2[:, 1] + boxes2[:, 3]) / 2
377377
# The distance between boxes' centers squared.
378-
centers_distance_squared = (_upcast(x_p - x_g) ** 2) + (_upcast(y_p - y_g) ** 2)
378+
centers_distance_squared = (_upcast((x_p[:, None] - x_g[None, :])) ** 2) + (
379+
_upcast((y_p[:, None] - y_g[None, :])) ** 2
380+
)
379381
# The distance IoU is the IoU penalized by a normalized
380382
# distance between boxes' centers squared.
381383
return iou - (centers_distance_squared / diagonal_distance_squared), iou

torchvision/ops/ciou_loss.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ def complete_box_iou_loss(
1414

1515
"""
1616
Gradient-friendly IoU loss with an additional penalty that is non-zero when the
17-
boxes do not overlap overlap area, This loss function considers important geometrical
18-
factors such as overlap area, normalized central point distance and aspect ratio.
17+
boxes do not overlap. This loss function considers important geometrical
18+
factors such as overlap area, normalized central point distance and aspect ratio.
1919
This loss is symmetric, so the boxes1 and boxes2 arguments are interchangeable.
2020
2121
Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with
@@ -35,7 +35,7 @@ def complete_box_iou_loss(
3535
Tensor: Loss tensor with the reduction option applied.
3636
3737
Reference:
38-
Zhaohui Zheng et. al: Complete Intersection over Union Loss:
38+
Zhaohui Zheng et al.: Complete Intersection over Union Loss:
3939
https://arxiv.org/abs/1911.08287
4040
4141
"""

0 commit comments

Comments
 (0)