-
Notifications
You must be signed in to change notification settings - Fork 7.1k
Added CIOU loss function #5776
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added CIOU loss function #5776
Changes from 6 commits
d868ec1
abb09eb
9c2ee2e
f3f1d92
2d0f627
a158ca3
5760487
56147d2
38e7a19
9a1cf90
755fa07
1a6b59a
99a3951
d89dbec
c9b0cab
8c2feee
b1d33fa
c531b1d
19b23d1
916418f
96c6dda
844e0da
38f9ede
2422913
ada4471
b8a7d96
c8a18ce
9b4803a
5cf1591
d25a5a0
14add84
03ecb91
1c4ae7f
2cbc6a2
9c88d92
e36fb15
1e57b6b
7e244fb
47c7e09
f5a352c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
import math | ||
|
||
import torch | ||
|
||
|
||
def ciou_loss( | ||
abhi-glitchhg marked this conversation as resolved.
Show resolved
Hide resolved
|
||
boxes1: torch.Tensor, | ||
boxes2: torch.Tensor, | ||
reduction: str = "none", | ||
eps: float = 1e-7, | ||
) -> torch.Tensor: | ||
|
||
""" | ||
abhi-glitchhg marked this conversation as resolved.
Show resolved
Hide resolved
|
||
Original Implementation from | ||
https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/losses.py | ||
abhi-glitchhg marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
abhi-glitchhg marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
Args: | ||
boxes1 : (Tensor[N, 4] or Tensor[4]) first set of boxes | ||
boxes2 : (Tensor[N, 4] or Tensor[4]) second set of boxes | ||
reduction : (string, optional) Specifies the reduction to apply to the output: | ||
``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: No reduction will be | ||
applied to the output. ``'mean'``: The output will be averaged. | ||
``'sum'``: The output will be summed. Default: ``'none'`` | ||
eps : (float, optional): small number to prevent division by zero. Default: 1e-7 | ||
abhi-glitchhg marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
Reference: | ||
|
||
Complete Intersection over Union Loss (Zhaohui Zheng et. al) | ||
https://arxiv.org/abs/1911.08287 | ||
|
||
""" | ||
|
||
abhi-glitchhg marked this conversation as resolved.
Show resolved
Hide resolved
|
||
x1, y1, x2, y2 = boxes1.unbind(dim=-1) | ||
x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1) | ||
|
||
if (x2 < x1).all(): | ||
raise ValueError("x1 is larger than x2") | ||
abhi-glitchhg marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if (y2 < y1).all(): | ||
abhi-glitchhg marked this conversation as resolved.
Show resolved
Hide resolved
|
||
raise ValueError("y1 is larger than y2") | ||
abhi-glitchhg marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
# Intersection keypoints | ||
xkis1 = torch.max(x1, x1g) | ||
ykis1 = torch.max(y1, y1g) | ||
xkis2 = torch.min(x2, x2g) | ||
ykis2 = torch.min(y2, y2g) | ||
|
||
intsct = torch.zeros_like(x1) | ||
mask = (ykis2 > ykis1) & (xkis2 > xkis1) | ||
intsct[mask] = (xkis2[mask] - xkis1[mask]) * (ykis2[mask] - ykis1[mask]) | ||
union = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g) - intsct + eps | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: Moreover it's worth noting that cIoU and dIoU share a large number of common code that could be shared. |
||
iou = intsct / union | ||
|
||
# smallest enclosing box | ||
xc1 = torch.min(x1, x1g) | ||
yc1 = torch.min(y1, y1g) | ||
xc2 = torch.max(x2, x2g) | ||
yc2 = torch.max(y2, y2g) | ||
diag_len = ((xc2 - xc1) ** 2) + ((yc2 - yc1) ** 2) + eps | ||
|
||
# centers of boxes | ||
x_p = (x2 + x1) / 2 | ||
y_p = (y2 + y1) / 2 | ||
x_g = (x1g + x2g) / 2 | ||
y_g = (y1g + y2g) / 2 | ||
distance = ((x_p - x_g) ** 2) + ((y_p - y_g) ** 2) | ||
|
||
# width and height of boxes | ||
w_pred = x2 - x1 | ||
h_pred = y2 - y1 | ||
w_gt = x2g - x1g | ||
h_gt = y2g - y1g | ||
v = (4 / (math.pi ** 2)) * torch.pow((torch.atan(w_gt / h_gt) - torch.atan(w_pred / h_pred)), 2) | ||
abhi-glitchhg marked this conversation as resolved.
Show resolved
Hide resolved
|
||
with torch.no_grad(): | ||
alpha = v / (1 - iou + v + eps) | ||
|
||
# Eqn. (10) | ||
loss = 1 - iou + (distance / diag_len) + alpha * v | ||
if reduction == "mean": | ||
loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum() | ||
elif reduction == "sum": | ||
loss = loss.sum() | ||
|
||
return loss |
Uh oh!
There was an error while loading. Please reload this page.