-
Notifications
You must be signed in to change notification settings - Fork 614
Add GIOU loss #477
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Add GIOU loss #477
Changes from all commits
Commits
Show all changes
22 commits
Select commit
Hold shift + click to select a range
ca4197d
Add GIOU loss
fsx950223 27ae8ea
Refact giou calculate
fsx950223 ae1a020
Fix doc
fsx950223 1606430
Update Readme.md
fsx950223 eb7458d
Format code
fsx950223 be8dcf0
refact calculate
fsx950223 7f72cab
fix document
fsx950223 099298c
fix readme
fsx950223 825c255
fix docs
fsx950223 9beec61
Change to official api
fsx950223 72fc21a
format code
fsx950223 b35dff9
enhance robust
fsx950223 310850d
add box format
fsx950223 f85f692
add keras test
fsx950223 ec11713
add one bbox test
fsx950223 8d4b3dd
add different shapes test case
fsx950223 55f4d0b
format code
fsx950223 f830058
fix docs
fsx950223 e42ccd5
make private
fsx950223 a03a320
add interger test
fsx950223 4b915d1
format code
fsx950223 f8ae33c
change expression
fsx950223 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,6 +5,7 @@ | |
|:---------- |:----------- |:------------- | | ||
| contrastive | @WindQAQ | [email protected] | | ||
| focal_loss | @SSaishruthi | [email protected] | | ||
| giou_loss | @fsx950223 | [email protected] | | ||
| lifted | @rahulunair | [email protected] | | ||
| npairs | @WindQAQ | [email protected] | | ||
| sparsemax_loss | @AndreasMadsen | [email protected] | | ||
|
@@ -15,6 +16,7 @@ | |
|:----------------------- |:---------------------|:--------------------------| | ||
| contrastive | ContrastiveLoss | http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf | | ||
| focal_loss | SigmoidFocalCrossEntropy | https://arxiv.org/abs/1708.02002 | | ||
| giou_loss | GIoULoss | https://giou.stanford.edu/GIoU.pdf | | ||
| lifted | LiftedStructLoss | https://arxiv.org/abs/1511.06452 | | ||
| npairs | NpairsLoss | http://www.nec-labs.com/uploads/images/Department-Images/MediaAnalytics/papers/nips16_npairmetriclearning.pdf | | ||
| npairs | NpairsMultilabelLoss | http://www.nec-labs.com/uploads/images/Department-Images/MediaAnalytics/papers/nips16_npairmetriclearning.pdf | | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""Implements GIoU loss.""" | ||
|
||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import tensorflow as tf | ||
|
||
|
||
@tf.keras.utils.register_keras_serializable(package='Addons') | ||
class GIoULoss(tf.keras.losses.Loss): | ||
"""Implements the GIoU loss function. | ||
|
||
GIoU loss was first introduced in the | ||
[Generalized Intersection over Union: | ||
A Metric and A Loss for Bounding Box Regression] | ||
(https://giou.stanford.edu/GIoU.pdf). | ||
GIoU is an enhancement for models which use IoU in object detection. | ||
|
||
Usage: | ||
|
||
```python | ||
gl = tfa.losses.GIoU() | ||
boxes1 = tf.constant([[4.0, 3.0, 7.0, 5.0], [5.0, 6.0, 10.0, 7.0]]) | ||
boxes2 = tf.constant([[3.0, 4.0, 6.0, 8.0], [14.0, 14.0, 15.0, 15.0]]) | ||
loss = gl(boxes1, boxes2) | ||
print('Loss: ', loss.numpy()) # Loss: [1.07500000298023224, 1.9333333373069763] | ||
``` | ||
Usage with tf.keras API: | ||
|
||
```python | ||
model = tf.keras.Model(inputs, outputs) | ||
model.compile('sgd', loss=tfa.losses.GIoULoss()) | ||
``` | ||
|
||
Args: | ||
mode: one of ['giou', 'iou'], decided to calculate GIoU or IoU loss. | ||
""" | ||
|
||
def __init__(self, | ||
mode='giou', | ||
reduction=tf.keras.losses.Reduction.AUTO, | ||
name='giou_loss'): | ||
if mode not in ['giou', 'iou']: | ||
raise ValueError("Value of mode should be 'iou' or 'giou'") | ||
super(GIoULoss, self).__init__(name=name, reduction=reduction) | ||
fsx950223 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
self.mode = mode | ||
|
||
def get_config(self): | ||
base_config = super(GIoULoss, self).get_config() | ||
base_config['mode'] = self.mode | ||
return base_config | ||
fsx950223 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
def call(self, y_true, y_pred): | ||
return giou_loss(y_true, y_pred, mode=self.mode) | ||
|
||
|
||
@tf.keras.utils.register_keras_serializable(package='Addons') | ||
@tf.function | ||
def giou_loss(y_true, y_pred, mode='giou'): | ||
""" | ||
Args: | ||
y_true: true targets tensor. The coordinates of the each bounding | ||
box in boxes are encoded as [y_min, x_min, y_max, x_max]. | ||
y_pred: predictions tensor. The coordinates of the each bounding | ||
box in boxes are encoded as [y_min, x_min, y_max, x_max]. | ||
mode: one of ['giou', 'iou'], decided to calculate GIoU or IoU loss. | ||
|
||
Returns: | ||
GIoU loss float `Tensor`. | ||
""" | ||
if mode not in ['giou', 'iou']: | ||
raise ValueError("Value of mode should be 'iou' or 'giou'") | ||
y_pred = tf.convert_to_tensor(y_pred) | ||
if not y_pred.dtype.is_floating: | ||
y_pred = tf.cast(y_pred, tf.float32) | ||
y_true = tf.cast(y_true, y_pred.dtype) | ||
giou = _calculate_giou(y_pred, y_true, mode) | ||
|
||
return 1 - giou | ||
|
||
|
||
def _calculate_giou(b1, b2, mode='giou'): | ||
""" | ||
Args: | ||
b1: bounding box. The coordinates of the each bounding box in boxes are | ||
encoded as [y_min, x_min, y_max, x_max]. | ||
b2: the other bounding box. The coordinates of the each bounding box | ||
in boxes are encoded as [y_min, x_min, y_max, x_max]. | ||
mode: one of ['giou', 'iou'], decided to calculate GIoU or IoU loss. | ||
|
||
Returns: | ||
GIoU loss float `Tensor`. | ||
""" | ||
fsx950223 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
zero = tf.convert_to_tensor(0., b1.dtype) | ||
b1_ymin, b1_xmin, b1_ymax, b1_xmax = tf.unstack(b1, 4, axis=-1) | ||
b2_ymin, b2_xmin, b2_ymax, b2_xmax = tf.unstack(b2, 4, axis=-1) | ||
b1_width = tf.maximum(zero, b1_xmax - b1_xmin) | ||
b1_height = tf.maximum(zero, b1_ymax - b1_ymin) | ||
b2_width = tf.maximum(zero, b2_xmax - b2_xmin) | ||
b2_height = tf.maximum(zero, b2_ymax - b2_ymin) | ||
b1_area = b1_width * b1_height | ||
b2_area = b2_width * b2_height | ||
|
||
intersect_ymin = tf.maximum(b1_ymin, b2_ymin) | ||
intersect_xmin = tf.maximum(b1_xmin, b2_xmin) | ||
intersect_ymax = tf.minimum(b1_ymax, b2_ymax) | ||
intersect_xmax = tf.minimum(b1_xmax, b2_xmax) | ||
intersect_width = tf.maximum(zero, intersect_xmax - intersect_xmin) | ||
intersect_height = tf.maximum(zero, intersect_ymax - intersect_ymin) | ||
intersect_area = intersect_width * intersect_height | ||
|
||
union_area = b1_area + b2_area - intersect_area | ||
iou = tf.math.divide_no_nan(intersect_area, union_area) | ||
if mode == 'iou': | ||
return iou | ||
|
||
enclose_ymin = tf.minimum(b1_ymin, b2_ymin) | ||
enclose_xmin = tf.minimum(b1_xmin, b2_xmin) | ||
enclose_ymax = tf.maximum(b1_ymax, b2_ymax) | ||
enclose_xmax = tf.maximum(b1_xmax, b2_xmax) | ||
enclose_width = tf.maximum(zero, enclose_xmax - enclose_xmin) | ||
enclose_height = tf.maximum(zero, enclose_ymax - enclose_ymin) | ||
enclose_area = enclose_width * enclose_height | ||
giou = iou - tf.math.divide_no_nan( | ||
(enclose_area - union_area), enclose_area) | ||
return giou |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""Tests for GIoU loss.""" | ||
|
||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
from absl.testing import parameterized | ||
|
||
import numpy as np | ||
import tensorflow as tf | ||
from tensorflow_addons.utils import test_utils | ||
from tensorflow_addons.losses import giou_loss, GIoULoss | ||
|
||
|
||
@test_utils.run_all_in_graph_and_eager_modes | ||
class GIoULossTest(tf.test.TestCase, parameterized.TestCase): | ||
"""GIoU test class.""" | ||
|
||
def test_config(self): | ||
gl_obj = GIoULoss( | ||
reduction=tf.keras.losses.Reduction.NONE, name='giou_loss') | ||
self.assertEqual(gl_obj.name, 'giou_loss') | ||
self.assertEqual(gl_obj.reduction, tf.keras.losses.Reduction.NONE) | ||
|
||
@parameterized.named_parameters(("float16", np.float16), | ||
("float32", np.float32), | ||
("float64", np.float64)) | ||
def test_iou(self, dtype): | ||
boxes1 = tf.constant([[4.0, 3.0, 7.0, 5.0], [5.0, 6.0, 10.0, 7.0]], | ||
dtype=dtype) | ||
boxes2 = tf.constant([[3.0, 4.0, 6.0, 8.0], [14.0, 14.0, 15.0, 15.0]], | ||
dtype=dtype) | ||
expected_result = tf.constant([0.875, 1.], dtype=dtype) | ||
loss = giou_loss(boxes1, boxes2, mode='iou') | ||
self.assertAllCloseAccordingToType(loss, expected_result) | ||
|
||
@parameterized.named_parameters(("float16", np.float16), | ||
("float32", np.float32), | ||
("float64", np.float64)) | ||
def test_giou_loss(self, dtype): | ||
boxes1 = tf.constant([[4.0, 3.0, 7.0, 5.0], [5.0, 6.0, 10.0, 7.0]], | ||
dtype=dtype) | ||
boxes2 = tf.constant([[3.0, 4.0, 6.0, 8.0], [14.0, 14.0, 15.0, 15.0]], | ||
dtype=dtype) | ||
expected_result = tf.constant( | ||
[1.07500000298023224, 1.9333333373069763], dtype=dtype) | ||
loss = giou_loss(boxes1, boxes2) | ||
self.assertAllCloseAccordingToType(loss, expected_result) | ||
|
||
def test_with_integer(self): | ||
boxes1 = tf.constant([[4, 3, 7, 5], [5, 6, 10, 7]], dtype=tf.int32) | ||
boxes2 = tf.constant([[3, 4, 6, 8], [14, 14, 15, 15]], dtype=tf.int32) | ||
expected_result = tf.constant( | ||
[1.07500000298023224, 1.9333333373069763], dtype=tf.float32) | ||
loss = giou_loss(boxes1, boxes2) | ||
self.assertAllCloseAccordingToType(loss, expected_result) | ||
|
||
@parameterized.named_parameters(("float16", np.float16), | ||
("float32", np.float32), | ||
("float64", np.float64)) | ||
def test_different_shapes(self, dtype): | ||
boxes1 = tf.constant([[4.0, 3.0, 7.0, 5.0], [5.0, 6.0, 10.0, 7.0]], | ||
dtype=dtype) | ||
boxes2 = tf.constant([[3.0, 4.0, 6.0, 8.0]], dtype=dtype) | ||
tf.expand_dims(boxes1, -2) | ||
tf.expand_dims(boxes2, 0) | ||
expected_result = tf.constant([1.07500000298023224, 1.366071], | ||
dtype=dtype) | ||
loss = giou_loss(boxes1, boxes2) | ||
self.assertAllCloseAccordingToType(loss, expected_result) | ||
|
||
boxes1 = tf.constant([[4.0, 3.0, 7.0, 5.0], [5.0, 6.0, 10.0, 7.0]], | ||
dtype=dtype) | ||
boxes2 = tf.constant([[3.0, 4.0, 6.0, 8.0]], dtype=dtype) | ||
tf.expand_dims(boxes1, 0) | ||
tf.expand_dims(boxes2, -2) | ||
expected_result = tf.constant([1.07500000298023224, 1.366071], | ||
dtype=dtype) | ||
loss = giou_loss(boxes1, boxes2) | ||
self.assertAllCloseAccordingToType(loss, expected_result) | ||
|
||
@parameterized.named_parameters(("float16", np.float16), | ||
("float32", np.float32), | ||
("float64", np.float64)) | ||
def test_one_bbox(self, dtype): | ||
boxes1 = tf.constant([4.0, 3.0, 7.0, 5.0], dtype=dtype) | ||
boxes2 = tf.constant([3.0, 4.0, 6.0, 8.0], dtype=dtype) | ||
expected_result = tf.constant(1.07500000298023224, dtype=dtype) | ||
loss = giou_loss(boxes1, boxes2) | ||
self.assertAllCloseAccordingToType(loss, expected_result) | ||
|
||
fsx950223 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
@parameterized.named_parameters(("float16", np.float16), | ||
("float32", np.float32), | ||
("float64", np.float64)) | ||
def test_keras_model(self, dtype): | ||
boxes1 = tf.constant([[4.0, 3.0, 7.0, 5.0], [5.0, 6.0, 10.0, 7.0]], | ||
dtype=dtype) | ||
boxes2 = tf.constant([[3.0, 4.0, 6.0, 8.0], [14.0, 14.0, 15.0, 15.0]], | ||
dtype=dtype) | ||
expected_result = tf.constant( | ||
[1.07500000298023224, 1.9333333373069763], dtype=dtype) | ||
model = tf.keras.Sequential() | ||
model.compile( | ||
optimizer='adam', | ||
loss=GIoULoss(reduction=tf.keras.losses.Reduction.NONE)) | ||
loss = model.evaluate(boxes1, boxes2, batch_size=2, steps=1) | ||
self.assertAllCloseAccordingToType(loss, expected_result) | ||
|
||
|
||
if __name__ == '__main__': | ||
tf.test.main() |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.