Skip to content

Add GIOU loss #477

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 22 commits into from
Nov 21, 2019
Merged

Add GIOU loss #477

merged 22 commits into from
Nov 21, 2019

Conversation

fsx950223
Copy link
Member

@fsx950223 fsx950223 commented Sep 4, 2019

Implementation of GIOU loss

@fsx950223 fsx950223 changed the title WIP: Add GIOU loss [WIP] Add GIOU loss Sep 5, 2019
@fsx950223 fsx950223 changed the title [WIP] Add GIOU loss Add GIOU loss Sep 5, 2019
@fsx950223 fsx950223 force-pushed the loss branch 2 times, most recently from 54732a3 to 4568ba4 Compare September 9, 2019 01:35
@seanpmorgan
Copy link
Member

Hi @fsx950223 thank you for this contribution and #493 . Just wanted to let you know that we'll review ASAP, but are a bit backed up given the work that has been going into building packages for newly released TF2 packages.

Also, just for the future.. if you wouldn't mind opening an issue to discuss before submitting a PR that'd be much appreciated.

@fsx950223
Copy link
Member Author

Hi @fsx950223 thank you for this contribution and #493 . Just wanted to let you know that we'll review ASAP, but are a bit backed up given the work that has been going into building packages for newly released TF2 packages.

Also, just for the future.. if you wouldn't mind opening an issue to discuss before submitting a PR that'd be much appreciated.

OK

@facaiy facaiy requested a review from a team September 18, 2019 06:13
@facaiy
Copy link
Member

facaiy commented Sep 18, 2019

@Squadrick Hi, Dheeraj, could you take a look if you have time? Thanks :-)

Copy link
Member

@WindQAQ WindQAQ left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @fsx950223 , sorry for the late reply. I have taken a deeper look into the paper, but want to give you some general feedback on this. Also, if it's possible, could you elaborate how to compute GIoU in test cases or any other places? Thanks!

@fsx950223
Copy link
Member Author

Hi @fsx950223 , sorry for the late reply. I have taken a deeper look into the paper, but want to give you some general feedback on this. Also, if it's possible, could you elaborate how to compute GIoU in test cases or any other places? Thanks!

I will modify the code on the weekend.

@fsx950223
Copy link
Member Author

Changed. @WindQAQ

Copy link
Member

@WindQAQ WindQAQ left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Only some documentation issues need to be addressed. Thank you!

"""
Args
b1: bbox.
b2: the other bbox.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we rename bbox to bounding box here? It should be clearer in docstring.

@fsx950223
Copy link
Member Author

Changed.

@fsx950223
Copy link
Member Author

fsx950223 commented Nov 13, 2019

Only works on matched bboxs, still have a lot of work to do. Does it need to support unmatched situation(N bboxs x M bboxs)?

@fsx950223 fsx950223 changed the title [WIP]: Add GIOU loss Add GIOU loss Nov 13, 2019
@fsx950223
Copy link
Member Author

There are some tricks by using broadcasts to calculate unmatched gious.

Copy link
Member

@WindQAQ WindQAQ left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very nice! Some doc needed to change. Thank you very much for the contribution!

return 1 - giou


def do_giou_calculate(b1, b2, mode='giou'):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we make this function private? Like _calculate_giou.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a common function, maybe I should migrate it to tfa.image module?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Umm, I would suggest leave it private here; let's move it in the future if there is any request!

"""Implements the GIoU loss function.

GIoU loss was first introduced in the
[Generalized Intersection over Union paper](https://giou.stanford.edu/GIoU.pdf).
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


GIoU loss was first introduced in the
[Generalized Intersection over Union paper](https://giou.stanford.edu/GIoU.pdf).
GIoU is a enhance for model which use IOU in object detection.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIT: ... which uses IoU in ...

gl = tfa.losses.GIoU()
boxes1 = tf.constant([[4.0, 3.0, 7.0, 5.0], [5.0, 6.0, 10.0, 7.0]])
boxes2 = tf.constant([[3.0, 4.0, 6.0, 8.0], [14.0, 14.0, 15.0, 15.0]])
loss = gl(boxes1,boxes2)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
loss = gl(boxes1,boxes2)
loss = gl(boxes1, boxes2)


```python
model = tf.keras.Model(inputs, outputs)
model.compile('sgd', loss=tf.keras.losses.GIoULoss())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
model.compile('sgd', loss=tf.keras.losses.GIoULoss())
model.compile('sgd', loss=tfa.losses.GIoULoss())

```

Args:
mode: one of ['giou', 'iou'], decided to calculate giou loss or iou loss.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIT: ... calculate GIoU or IoU loss.

Copy link
Member

@Squadrick Squadrick left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@fsx950223 What happens when y_pred is integer instead of floating point?

@@ -5,6 +5,7 @@
|:---------- |:----------- |:------------- |
| contrastive | @WindQAQ | [email protected] |
| focal_loss | @SSaishruthi | [email protected] |
| giou_loss | @who who who | [email protected] |
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you put your GitHub username @fsx950223 instead of @who who who?

[Generalized Intersection over Union:
A Metric and A Loss for Bounding Box Regression]
(https://giou.stanford.edu/GIoU.pdf).
GIoU is a enhance for model which use IoU in object detection.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: a enhance -> an enhacement
Nit: model -> models

"""
Args:
y_true: true targets tensor. The coordinates of the each bounding
box in boxes are encoded as [y_min, x_min, y_max, x_max].
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add extra four space at the start of this line.

y_true: true targets tensor. The coordinates of the each bounding
box in boxes are encoded as [y_min, x_min, y_max, x_max].
y_pred: predictions tensor. The coordinates of the each bounding
box in boxes are encoded as [y_min, x_min, y_max, x_max].
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

y_true = tf.cast(y_true, y_pred.dtype)
giou = _calculate_giou(y_pred, y_true, mode)

# compute the final loss and return
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment not required. It's obvious what the code is doing.

"""
Args:
b1: bounding box. The coordinates of the each bounding box in boxes are
encoded as [y_min, x_min, y_max, x_max].
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

4 spaces at start.

b1: bounding box. The coordinates of the each bounding box in boxes are
encoded as [y_min, x_min, y_max, x_max].
b2: the other bounding box. The coordinates of the each bounding box
in boxes are encoded as [y_min, x_min, y_max, x_max].
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto.

Copy link
Member

@Squadrick Squadrick left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@fsx950223 Thanks for the changes. Just a few small nits.

@@ -86,20 +86,21 @@ def giou_loss(y_true, y_pred, mode='giou'):
if mode not in ['giou', 'iou']:
raise ValueError("Value of mode should be 'iou' or 'giou'")
y_pred = tf.convert_to_tensor(y_pred)
if y_pred.dtype.is_floating is not True:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can be written as:

if not y_pred.dtype.is_floating:

Copy link
Member

@Squadrick Squadrick left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@fsx950223 LGTM!

cc: @WindQAQ

Copy link
Member

@WindQAQ WindQAQ left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the contribution!

@WindQAQ WindQAQ merged commit 5d7dee6 into tensorflow:master Nov 21, 2019
@fsx950223 fsx950223 mentioned this pull request Jan 22, 2020
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants