Add GIOU loss #477
Hi @fsx950223, thank you for this contribution and #493. Just wanted to let you know that we'll review ASAP, but we are a bit backed up given the work that has been going into building packages for the newly released TF2 packages. Also, just for the future: if you wouldn't mind opening an issue to discuss before submitting a PR, that'd be much appreciated.
OK
@Squadrick Hi, Dheeraj, could you take a look if you have time? Thanks :-)
Hi @fsx950223, sorry for the late reply. I have taken a deeper look into the paper, but want to give you some general feedback on this. Also, if possible, could you elaborate on how GIoU is computed, in the test cases or anywhere else? Thanks!
I will modify the code on the weekend.
Changed. @WindQAQ
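For reference, the GIoU computation asked about above can be worked by hand for a single pair of matched boxes. The following is a minimal pure-Python sketch, not the code under review: the helper name `giou_pair` is hypothetical, the box encoding `[y_min, x_min, y_max, x_max]` follows the docstring in this PR, and the boxes are the ones from the docstring example.

```python
def giou_pair(b1, b2):
    """GIoU of two boxes encoded as [y_min, x_min, y_max, x_max].

    Hypothetical helper for illustration only.
    """
    area1 = (b1[2] - b1[0]) * (b1[3] - b1[1])
    area2 = (b2[2] - b2[0]) * (b2[3] - b2[1])

    # Intersection extents, clamped to zero when the boxes do not overlap.
    inter_h = max(0.0, min(b1[2], b2[2]) - max(b1[0], b2[0]))
    inter_w = max(0.0, min(b1[3], b2[3]) - max(b1[1], b2[1]))
    inter = inter_h * inter_w

    union = area1 + area2 - inter
    iou = inter / union if union > 0 else 0.0

    # Smallest axis-aligned box enclosing both inputs.
    enclose_h = max(b1[2], b2[2]) - min(b1[0], b2[0])
    enclose_w = max(b1[3], b2[3]) - min(b1[1], b2[1])
    enclose = enclose_h * enclose_w
    if enclose <= 0:
        return iou

    # GIoU = IoU - (enclosing area not covered by the union) / enclosing area
    return iou - (enclose - union) / enclose


pairs = [
    ([4.0, 3.0, 7.0, 5.0], [3.0, 4.0, 6.0, 8.0]),
    ([5.0, 6.0, 10.0, 7.0], [14.0, 14.0, 15.0, 15.0]),
]
# Per-pair loss is 1 - GIoU: approximately 1.075 and 1.9333.
losses = [1.0 - giou_pair(a, b) for a, b in pairs]
```

For the first pair the intersection is 2, the union is 16 and the enclosing box has area 20, so GIoU = 0.125 - 4/20 = -0.075. The second pair does not overlap, so its IoU is 0 and only the enclosing-box penalty remains.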
LGTM. Only some documentation issues need to be addressed. Thank you!
"""
Args
b1: bbox.
b2: the other bbox.
Could we rename bbox to bounding box here? It should be clearer in docstring.
Changed.
It only works on matched bboxes; there is still a lot of work to do. Does it need to support the unmatched case (N bboxes x M bboxes)?
There are some tricks that use broadcasting to calculate unmatched GIoUs.
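One way the broadcasting idea mentioned above could look is sketched below with NumPy rather than TensorFlow, so it is an illustration of the trick and not the PR's implementation. The function name `pairwise_giou` is hypothetical, boxes are assumed to be encoded as `[y_min, x_min, y_max, x_max]`, and every box is assumed to have positive area so that no division by zero occurs.

```python
import numpy as np


def pairwise_giou(boxes1, boxes2):
    """GIoU between every box in boxes1 (N, 4) and every box in boxes2 (M, 4).

    Hypothetical helper; returns an (N, M) array.
    """
    # Insert singleton axes so that (N, 1, 4) broadcasts against (1, M, 4).
    b1 = np.asarray(boxes1, dtype=np.float64)[:, None, :]
    b2 = np.asarray(boxes2, dtype=np.float64)[None, :, :]

    area1 = (b1[..., 2] - b1[..., 0]) * (b1[..., 3] - b1[..., 1])  # (N, 1)
    area2 = (b2[..., 2] - b2[..., 0]) * (b2[..., 3] - b2[..., 1])  # (1, M)

    # Intersection extents, clamped to zero for non-overlapping pairs.
    inter_h = np.maximum(
        0.0, np.minimum(b1[..., 2], b2[..., 2]) - np.maximum(b1[..., 0], b2[..., 0])
    )
    inter_w = np.maximum(
        0.0, np.minimum(b1[..., 3], b2[..., 3]) - np.maximum(b1[..., 1], b2[..., 1])
    )
    inter = inter_h * inter_w  # (N, M)
    union = area1 + area2 - inter  # (N, M)

    # Smallest enclosing box for every pair.
    enclose_h = np.maximum(b1[..., 2], b2[..., 2]) - np.minimum(b1[..., 0], b2[..., 0])
    enclose_w = np.maximum(b1[..., 3], b2[..., 3]) - np.minimum(b1[..., 1], b2[..., 1])
    enclose = enclose_h * enclose_w

    return inter / union - (enclose - union) / enclose


boxes1 = [[4.0, 3.0, 7.0, 5.0], [5.0, 6.0, 10.0, 7.0]]
boxes2 = [[3.0, 4.0, 6.0, 8.0], [14.0, 14.0, 15.0, 15.0]]
giou_matrix = pairwise_giou(boxes1, boxes2)  # shape (2, 2)
```

The diagonal of `giou_matrix` corresponds to the matched case (box i against box i); the off-diagonal entries are the unmatched pairs.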
Very nice! Some docs need to change. Thank you very much for the contribution!
    return 1 - giou


def do_giou_calculate(b1, b2, mode='giou'):
Could we make this function private? Like _calculate_giou.
It's a common function, maybe I should migrate it to tfa.image module?
Umm, I would suggest leaving it private here; let's move it in the future if there is any request!
"""Implements the GIoU loss function.

GIoU loss was first introduced in the
[Generalized Intersection over Union paper](https://giou.stanford.edu/GIoU.pdf).
GIoU loss was first introduced in the
[Generalized Intersection over Union paper](https://giou.stanford.edu/GIoU.pdf).
GIoU is a enhance for model which use IOU in object detection.
NIT: ... which uses IoU in ...
gl = tfa.losses.GIoU()
boxes1 = tf.constant([[4.0, 3.0, 7.0, 5.0], [5.0, 6.0, 10.0, 7.0]])
boxes2 = tf.constant([[3.0, 4.0, 6.0, 8.0], [14.0, 14.0, 15.0, 15.0]])
loss = gl(boxes1,boxes2)
- loss = gl(boxes1,boxes2)
+ loss = gl(boxes1, boxes2)
```python
model = tf.keras.Model(inputs, outputs)
model.compile('sgd', loss=tf.keras.losses.GIoULoss())
- model.compile('sgd', loss=tf.keras.losses.GIoULoss())
+ model.compile('sgd', loss=tfa.losses.GIoULoss())
```

Args:
mode: one of ['giou', 'iou'], decided to calculate giou loss or iou loss.
NIT: ... calculate GIoU or IoU loss.
@fsx950223 What happens when y_pred is integer instead of floating point?
tensorflow_addons/losses/README.md
Outdated
@@ -5,6 +5,7 @@
|:---------- |:----------- |:------------- |
| contrastive | @WindQAQ | [email protected] |
| focal_loss | @SSaishruthi | [email protected] |
| giou_loss | @who who who | [email protected] |
Could you put your GitHub username @fsx950223 instead of @who who who?
[Generalized Intersection over Union:
A Metric and A Loss for Bounding Box Regression]
(https://giou.stanford.edu/GIoU.pdf).
GIoU is a enhance for model which use IoU in object detection.
Nit: a enhance -> an enhancement
Nit: model -> models
"""
Args:
y_true: true targets tensor. The coordinates of the each bounding
box in boxes are encoded as [y_min, x_min, y_max, x_max].
Add four extra spaces at the start of this line.
y_true: true targets tensor. The coordinates of the each bounding
box in boxes are encoded as [y_min, x_min, y_max, x_max].
y_pred: predictions tensor. The coordinates of the each bounding
box in boxes are encoded as [y_min, x_min, y_max, x_max].
ditto
    y_true = tf.cast(y_true, y_pred.dtype)
    giou = _calculate_giou(y_pred, y_true, mode)

    # compute the final loss and return
Comment not required. It's obvious what the code is doing.
"""
Args:
b1: bounding box. The coordinates of the each bounding box in boxes are
encoded as [y_min, x_min, y_max, x_max].
4 spaces at start.
b1: bounding box. The coordinates of the each bounding box in boxes are
encoded as [y_min, x_min, y_max, x_max].
b2: the other bounding box. The coordinates of the each bounding box
in boxes are encoded as [y_min, x_min, y_max, x_max].
Ditto.
@fsx950223 Thanks for the changes. Just a few small nits.
@@ -86,20 +86,21 @@ def giou_loss(y_true, y_pred, mode='giou'):
    if mode not in ['giou', 'iou']:
        raise ValueError("Value of mode should be 'iou' or 'giou'")
    y_pred = tf.convert_to_tensor(y_pred)
    if y_pred.dtype.is_floating is not True:
This can be written as:
if not y_pred.dtype.is_floating:
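The input handling discussed in this thread (validate `mode`, then make sure integer predictions end up as floating point before the targets are cast to the same dtype) can be sketched as follows. This is a NumPy analogue and not the TensorFlow code under review: `prepare_inputs` is a hypothetical name, and the cast to `float32` inside the branch is an assumption, since the diff above is cut off before the body of the `if`.

```python
import numpy as np


def prepare_inputs(y_true, y_pred, mode="giou"):
    """Hypothetical NumPy analogue of the checks at the top of giou_loss."""
    if mode not in ("giou", "iou"):
        raise ValueError("Value of mode should be 'iou' or 'giou'")

    y_pred = np.asarray(y_pred)
    # Idiomatic truthiness check, mirroring `if not y_pred.dtype.is_floating:`.
    if not np.issubdtype(y_pred.dtype, np.floating):
        # Assumption: integer predictions are promoted to float32.
        y_pred = y_pred.astype(np.float32)

    # Targets follow the dtype of the predictions, as in the tf.cast call above.
    y_true = np.asarray(y_true).astype(y_pred.dtype)
    return y_true, y_pred


# Integer inputs come back as float32; float64 predictions are left untouched.
y_true, y_pred = prepare_inputs([[3, 4, 6, 8]], [[4, 3, 7, 5]])
```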
@fsx950223 LGTM!
cc: @WindQAQ
Thanks for the contribution!
Implementation of GIOU loss