diff --git a/tensorflow_addons/losses/BUILD b/tensorflow_addons/losses/BUILD index 927f821fe2..c97351f0e4 100644 --- a/tensorflow_addons/losses/BUILD +++ b/tensorflow_addons/losses/BUILD @@ -8,6 +8,7 @@ py_library( "__init__.py", "contrastive.py", "focal_loss.py", + "giou_loss.py", "lifted.py", "metric_learning.py", "npairs.py", @@ -47,6 +48,19 @@ py_test( ], ) +py_test( + name = "giou_loss_test", + size = "small", + srcs = [ + "giou_loss_test.py", + ], + main = "giou_loss_test.py", + srcs_version = "PY2AND3", + deps = [ + ":losses", + ], +) + py_test( name = "npairs_test", size = "small", diff --git a/tensorflow_addons/losses/README.md b/tensorflow_addons/losses/README.md index ab54fef1bf..2b40ad2f5e 100644 --- a/tensorflow_addons/losses/README.md +++ b/tensorflow_addons/losses/README.md @@ -5,6 +5,7 @@ |:---------- |:----------- |:------------- | | contrastive | @WindQAQ | windqaq@gmail.com | | focal_loss | @SSaishruthi | saishruthi.tn@gmail.com | +| giou_loss | @fsx950223 | fsx950223@gmail.com | | lifted | @rahulunair | rahulunair@gmail.com | | npairs | @WindQAQ | windqaq@gmail.com | | sparsemax_loss | @AndreasMadsen | amwwebdk+github@gmail.com | @@ -15,6 +16,7 @@ |:----------------------- |:---------------------|:--------------------------| | contrastive | ContrastiveLoss | http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf | | focal_loss | SigmoidFocalCrossEntropy | https://arxiv.org/abs/1708.02002 | +| giou_loss | GIoULoss | https://giou.stanford.edu/GIoU.pdf | | lifted | LiftedStructLoss | https://arxiv.org/abs/1511.06452 | | npairs | NpairsLoss | http://www.nec-labs.com/uploads/images/Department-Images/MediaAnalytics/papers/nips16_npairmetriclearning.pdf | | npairs | NpairsMultilabelLoss | http://www.nec-labs.com/uploads/images/Department-Images/MediaAnalytics/papers/nips16_npairmetriclearning.pdf | diff --git a/tensorflow_addons/losses/__init__.py b/tensorflow_addons/losses/__init__.py index ff8e5094fa..7aa292527f 100644 --- 
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implements GIoU loss."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf


@tf.keras.utils.register_keras_serializable(package='Addons')
class GIoULoss(tf.keras.losses.Loss):
    """Implements the GIoU loss function.

    GIoU loss was first introduced in
    [Generalized Intersection over Union: A Metric and A Loss for Bounding
    Box Regression](https://giou.stanford.edu/GIoU.pdf).
    GIoU is an enhancement for models which use IoU in object detection.

    Usage:

    ```python
    boxes1 = tf.constant([[4.0, 3.0, 7.0, 5.0], [5.0, 6.0, 10.0, 7.0]])
    boxes2 = tf.constant([[3.0, 4.0, 6.0, 8.0], [14.0, 14.0, 15.0, 15.0]])

    # The default reduction averages the per-box losses into a scalar.
    gl = tfa.losses.GIoULoss()
    loss = gl(boxes1, boxes2)
    print('Loss: ', loss.numpy())  # Loss: 1.5041667

    # Use `Reduction.NONE` to keep one loss value per box.
    gl = tfa.losses.GIoULoss(reduction=tf.keras.losses.Reduction.NONE)
    loss = gl(boxes1, boxes2)
    print('Loss: ', loss.numpy())  # Loss: [1.075, 1.9333334]
    ```

    Usage with tf.keras API:

    ```python
    model = tf.keras.Model(inputs, outputs)
    model.compile('sgd', loss=tfa.losses.GIoULoss())
    ```

    Args:
      mode: one of ['giou', 'iou'], decides whether the GIoU or the IoU loss
        is calculated.
      reduction: the `tf.keras.losses.Reduction` applied to the per-box
        losses.
      name: name of the loss instance.

    Raises:
      ValueError: if `mode` is neither 'giou' nor 'iou'.
    """

    def __init__(self,
                 mode='giou',
                 reduction=tf.keras.losses.Reduction.AUTO,
                 name='giou_loss'):
        # Validate eagerly so that a bad mode fails at construction time
        # rather than on the first training step.
        if mode not in ['giou', 'iou']:
            raise ValueError("Value of mode should be 'iou' or 'giou'")
        super(GIoULoss, self).__init__(name=name, reduction=reduction)
        self.mode = mode

    def get_config(self):
        """Returns the base loss config extended with the `mode` entry."""
        base_config = super(GIoULoss, self).get_config()
        base_config['mode'] = self.mode
        return base_config

    def call(self, y_true, y_pred):
        """Returns the unreduced loss computed by `giou_loss`."""
        return giou_loss(y_true, y_pred, mode=self.mode)


@tf.keras.utils.register_keras_serializable(package='Addons')
@tf.function
def giou_loss(y_true, y_pred, mode='giou'):
    """Computes the GIoU (or IoU) loss between two sets of bounding boxes.

    The loss is `1 - giou` (or `1 - iou` when `mode='iou'`). No reduction is
    applied here.

    Args:
      y_true: true targets tensor. The coordinates of each bounding box are
        encoded along the last axis as [y_min, x_min, y_max, x_max].
      y_pred: predictions tensor. The coordinates of each bounding box are
        encoded along the last axis as [y_min, x_min, y_max, x_max].
      mode: one of ['giou', 'iou'], decides whether the GIoU or the IoU loss
        is calculated.

    Returns:
      GIoU loss float `Tensor`.

    Raises:
      ValueError: if `mode` is neither 'giou' nor 'iou'.
    """
    if mode not in ['giou', 'iou']:
        raise ValueError("Value of mode should be 'iou' or 'giou'")
    y_pred = tf.convert_to_tensor(y_pred)
    # Non-floating predictions (e.g. integer pixel coordinates) are promoted
    # to float32 so that the area ratios below are not truncated.
    if not y_pred.dtype.is_floating:
        y_pred = tf.cast(y_pred, tf.float32)
    # The targets always follow the dtype chosen for the predictions.
    y_true = tf.cast(y_true, y_pred.dtype)
    giou = _calculate_giou(y_pred, y_true, mode)

    return 1 - giou


def _calculate_giou(b1, b2, mode='giou'):
    """Computes the GIoU (or IoU) score between two sets of bounding boxes.

    Note that this returns the score itself, not the loss; `giou_loss`
    derives the loss as `1 - score`.

    Args:
      b1: bounding boxes. The coordinates of each bounding box are encoded
        along the last axis as [y_min, x_min, y_max, x_max].
      b2: the other bounding boxes, encoded the same way as `b1`.
      mode: 'iou' returns the plain IoU; any other value returns the GIoU.

    Returns:
      GIoU (or IoU) float `Tensor`.
    """
    zero = tf.convert_to_tensor(0., b1.dtype)
    b1_ymin, b1_xmin, b1_ymax, b1_xmax = tf.unstack(b1, 4, axis=-1)
    b2_ymin, b2_xmin, b2_ymax, b2_xmax = tf.unstack(b2, 4, axis=-1)

    # Extents are clamped at zero so that boxes whose max is below their min
    # contribute an empty area instead of a negative one.
    b1_width = tf.maximum(zero, b1_xmax - b1_xmin)
    b1_height = tf.maximum(zero, b1_ymax - b1_ymin)
    b2_width = tf.maximum(zero, b2_xmax - b2_xmin)
    b2_height = tf.maximum(zero, b2_ymax - b2_ymin)
    b1_area = b1_width * b1_height
    b2_area = b2_width * b2_height

    # Overlap rectangle; the clamp makes disjoint boxes yield a zero area.
    intersect_ymin = tf.maximum(b1_ymin, b2_ymin)
    intersect_xmin = tf.maximum(b1_xmin, b2_xmin)
    intersect_ymax = tf.minimum(b1_ymax, b2_ymax)
    intersect_xmax = tf.minimum(b1_xmax, b2_xmax)
    intersect_width = tf.maximum(zero, intersect_xmax - intersect_xmin)
    intersect_height = tf.maximum(zero, intersect_ymax - intersect_ymin)
    intersect_area = intersect_width * intersect_height

    union_area = b1_area + b2_area - intersect_area
    # divide_no_nan keeps the score finite when both boxes are empty.
    iou = tf.math.divide_no_nan(intersect_area, union_area)
    if mode == 'iou':
        return iou

    # Smallest axis-aligned rectangle that contains both boxes.
    enclose_ymin = tf.minimum(b1_ymin, b2_ymin)
    enclose_xmin = tf.minimum(b1_xmin, b2_xmin)
    enclose_ymax = tf.maximum(b1_ymax, b2_ymax)
    enclose_xmax = tf.maximum(b1_xmax, b2_xmax)
    enclose_width = tf.maximum(zero, enclose_xmax - enclose_xmin)
    enclose_height = tf.maximum(zero, enclose_ymax - enclose_ymin)
    enclose_area = enclose_width * enclose_height

    # GIoU subtracts the fraction of the enclosing rectangle that is not
    # covered by the union of the two boxes.
    giou = iou - tf.math.divide_no_nan(
        (enclose_area - union_area), enclose_area)
    return giou
# ==============================================================================
"""Tests for GIoU loss."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized

import numpy as np
import tensorflow as tf
from tensorflow_addons.utils import test_utils
from tensorflow_addons.losses import giou_loss, GIoULoss


@test_utils.run_all_in_graph_and_eager_modes
class GIoULossTest(tf.test.TestCase, parameterized.TestCase):
    """GIoU test class."""

    def test_config(self):
        """The constructor arguments are stored on the loss object."""
        gl_obj = GIoULoss(
            reduction=tf.keras.losses.Reduction.NONE, name='giou_loss')
        self.assertEqual(gl_obj.name, 'giou_loss')
        self.assertEqual(gl_obj.reduction, tf.keras.losses.Reduction.NONE)

    @parameterized.named_parameters(("float16", np.float16),
                                    ("float32", np.float32),
                                    ("float64", np.float64))
    def test_iou(self, dtype):
        """`mode='iou'` yields `1 - iou` for each pair of boxes."""
        boxes1 = tf.constant([[4.0, 3.0, 7.0, 5.0], [5.0, 6.0, 10.0, 7.0]],
                             dtype=dtype)
        boxes2 = tf.constant([[3.0, 4.0, 6.0, 8.0], [14.0, 14.0, 15.0, 15.0]],
                             dtype=dtype)
        expected_result = tf.constant([0.875, 1.], dtype=dtype)
        loss = giou_loss(boxes1, boxes2, mode='iou')
        self.assertAllCloseAccordingToType(loss, expected_result)

    @parameterized.named_parameters(("float16", np.float16),
                                    ("float32", np.float32),
                                    ("float64", np.float64))
    def test_giou_loss(self, dtype):
        """The default mode yields `1 - giou` for each pair of boxes."""
        boxes1 = tf.constant([[4.0, 3.0, 7.0, 5.0], [5.0, 6.0, 10.0, 7.0]],
                             dtype=dtype)
        boxes2 = tf.constant([[3.0, 4.0, 6.0, 8.0], [14.0, 14.0, 15.0, 15.0]],
                             dtype=dtype)
        expected_result = tf.constant(
            [1.07500000298023224, 1.9333333373069763], dtype=dtype)
        loss = giou_loss(boxes1, boxes2)
        self.assertAllCloseAccordingToType(loss, expected_result)

    def test_with_integer(self):
        """Integer boxes are promoted to float32 before the computation."""
        boxes1 = tf.constant([[4, 3, 7, 5], [5, 6, 10, 7]], dtype=tf.int32)
        boxes2 = tf.constant([[3, 4, 6, 8], [14, 14, 15, 15]], dtype=tf.int32)
        expected_result = tf.constant(
            [1.07500000298023224, 1.9333333373069763], dtype=tf.float32)
        loss = giou_loss(boxes1, boxes2)
        self.assertAllCloseAccordingToType(loss, expected_result)

    @parameterized.named_parameters(("float16", np.float16),
                                    ("float32", np.float32),
                                    ("float64", np.float64))
    def test_different_shapes(self, dtype):
        """Boxes of different but broadcastable shapes are supported."""
        boxes1 = tf.constant([[4.0, 3.0, 7.0, 5.0], [5.0, 6.0, 10.0, 7.0]],
                             dtype=dtype)
        boxes2 = tf.constant([[3.0, 4.0, 6.0, 8.0]], dtype=dtype)

        # [2, 4] against [1, 4]: one loss per box of `boxes1`.
        expected_result = tf.constant([1.07500000298023224, 1.366071],
                                      dtype=dtype)
        loss = giou_loss(boxes1, boxes2)
        self.assertAllCloseAccordingToType(loss, expected_result)

        # [2, 1, 4] against [1, 1, 4]: the losses broadcast to [2, 1].
        # The expanded tensors must be passed on; calling tf.expand_dims
        # without using its result leaves the inputs untouched.
        expected_result = tf.constant([[1.07500000298023224], [1.366071]],
                                      dtype=dtype)
        loss = giou_loss(
            tf.expand_dims(boxes1, -2), tf.expand_dims(boxes2, 0))
        self.assertAllCloseAccordingToType(loss, expected_result)

        # [1, 2, 4] against [1, 1, 4]: the losses broadcast to [1, 2].
        expected_result = tf.constant([[1.07500000298023224, 1.366071]],
                                      dtype=dtype)
        loss = giou_loss(
            tf.expand_dims(boxes1, 0), tf.expand_dims(boxes2, -2))
        self.assertAllCloseAccordingToType(loss, expected_result)

    @parameterized.named_parameters(("float16", np.float16),
                                    ("float32", np.float32),
                                    ("float64", np.float64))
    def test_one_bbox(self, dtype):
        """A single rank-1 box per side yields a scalar loss."""
        boxes1 = tf.constant([4.0, 3.0, 7.0, 5.0], dtype=dtype)
        boxes2 = tf.constant([3.0, 4.0, 6.0, 8.0], dtype=dtype)
        expected_result = tf.constant(1.07500000298023224, dtype=dtype)
        loss = giou_loss(boxes1, boxes2)
        self.assertAllCloseAccordingToType(loss, expected_result)

    @parameterized.named_parameters(("float16", np.float16),
                                    ("float32", np.float32),
                                    ("float64", np.float64))
    def test_keras_model(self, dtype):
        """The loss object can be passed to `Model.compile`."""
        boxes1 = tf.constant([[4.0, 3.0, 7.0, 5.0], [5.0, 6.0, 10.0, 7.0]],
                             dtype=dtype)
        boxes2 = tf.constant([[3.0, 4.0, 6.0, 8.0], [14.0, 14.0, 15.0, 15.0]],
                             dtype=dtype)
        expected_result = tf.constant(
            [1.07500000298023224, 1.9333333373069763], dtype=dtype)
        model = tf.keras.Sequential()
        model.compile(
            optimizer='adam',
            loss=GIoULoss(reduction=tf.keras.losses.Reduction.NONE))
        loss = model.evaluate(boxes1, boxes2, batch_size=2, steps=1)
        self.assertAllCloseAccordingToType(loss, expected_result)


if __name__ == '__main__':
    tf.test.main()