From 9f15745528d1ad636a9a6e2d90349b8dd3f17023 Mon Sep 17 00:00:00 2001
From: Mesh TensorFlow Team
Date: Mon, 8 Feb 2021 16:42:02 -0800
Subject: [PATCH] Add loss functions for multiple-target objectives for
 distillation.

PiperOrigin-RevId: 356382304
---
 mesh_tensorflow/layers.py | 99 +++++++++++++++++++++++++++++++++++++++
 1 file changed, 99 insertions(+)

diff --git a/mesh_tensorflow/layers.py b/mesh_tensorflow/layers.py
index 938fb43c..f99cdab9 100644
--- a/mesh_tensorflow/layers.py
+++ b/mesh_tensorflow/layers.py
@@ -1101,6 +1101,105 @@ def softmax_cross_entropy_with_logits(logits, targets, vocab_dim, z_loss=0.0):
   return loss
 
 
+def kl_divergence(y_true, y_pred, reduced_dim, weights=None, epsilon=1e-6):
+  """Kullback-Leibler divergence between `y_true` and `y_pred`.
+
+  Computes: `loss = y_true * log(y_true / y_pred)`
+  From: tf.keras.losses.KLDivergence (custom implementation with mtf)
+  See: https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence
+
+  Args:
+    y_true: mtf.Tensor, target distribution.
+    y_pred: mtf.Tensor, predicted distribution.
+    reduced_dim: mtf.Dimension, reduction dimension for sum.
+    weights: Optional mtf.Tensor, indicator for padded regions.
+    epsilon: float, minimum value for numerical stability.
+  Returns:
+    scalar: KL divergence loss.
+  Raises:
+    ValueError: if the shapes do not match or reduced_dim is not valid.
+  """
+  if set(y_true.shape.dims) != set(y_pred.shape.dims):
+    raise ValueError(
+        "`y_true` and `y_pred` must be of the same shape. "
+        f"Currently they are {y_true.shape.dims} and {y_pred.shape.dims}")
+  if reduced_dim not in y_true.shape.dims:
+    raise ValueError(
+        f"`reduced_dim` must be a valid dimension (from {y_true.shape.dims}).")
+  if weights is None:
+    weights = 1.
+
+  def _clip(x, min_value, max_value):
+    # Clip values for numerical stability.
+    x = mtf.maximum(x, min_value)
+    x = mtf.minimum(x, max_value)
+    return x
+
+  y_true = _clip(y_true, epsilon, 1.)
+  y_pred = _clip(y_pred, epsilon, 1.)
+  return mtf.reduce_sum(weights * y_true * mtf.log(y_true / y_pred))
+
+
+def mean_squared_error(y_true, y_pred, weights=None):
+  """L2 loss between `y_true` and `y_pred`.
+
+  Args:
+    y_true: mtf.Tensor, target logits.
+    y_pred: mtf.Tensor, predicted logits.
+    weights: Optional mtf.Tensor, indicator for padded regions.
+  Returns:
+    scalar: L2 loss.
+  Raises:
+    ValueError: if the shapes do not match.
+  """
+  if set(y_true.shape.dims) != set(y_pred.shape.dims):
+    raise ValueError(
+        "`y_true` and `y_pred` must be of the same shape. "
+        f"Currently they are {y_true.shape.dims} and {y_pred.shape.dims}")
+  if weights is None:
+    weights = 1.
+  return mtf.reduce_sum(weights * mtf.square(y_true - y_pred))
+
+
+def cosine_embedding_distill(y_true, y_pred, reduced_dim, weights=None,
+                             epsilon=1e-6):
+  """Cosine embedding loss for distillation from teacher to student logits.
+
+  See: https://arxiv.org/abs/1910.01108 (DistilBert) and
+  https://github.com/huggingface/transformers/tree/master/examples/
+  research_projects/distillation.
+
+  Args:
+    y_true: mtf.Tensor, teacher logits.
+    y_pred: mtf.Tensor, student logits.
+    reduced_dim: mtf.Dimension, reduction dimension for sum.
+    weights: Optional mtf.Tensor, indicator for padded regions.
+    epsilon: float, for numerical stability.
+  Returns:
+    scalar: sum of cosine embedding distances.
+  Raises:
+    ValueError: if the shapes do not match or reduced_dim is not valid.
+ """ + if set(y_true.shape.dims) != set(y_pred.shape.dims): + raise ValueError( + "`y_true` and `y_pred` must be of the same shape. " + f"Currently they are {y_true.shape.dims} and {y_pred.shape.dims}") + if reduced_dim not in y_true.shape.dims: + raise ValueError( + f"`reduced_dim` must be a valid dimension (from {y_true.shape.dims}).") + if weights is None: + weights = 1. + + prod_sum = mtf.reduce_sum(y_true * y_pred, reduced_dim=reduced_dim) + y_true_sq_sum = mtf.reduce_sum(y_true * y_true, reduced_dim=reduced_dim) + y_pred_sq_sum = mtf.reduce_sum(y_pred * y_pred, reduced_dim=reduced_dim) + inv_denom = mtf.rsqrt(y_true_sq_sum * y_pred_sq_sum + epsilon) + cos = prod_sum * inv_denom + # TODO(vinaysrao): Turn this into a more general cosine embedding loss with + # a `targets` tensor. + return mtf.reduce_sum(weights * (1. - cos)) + + def sigmoid_cross_entropy_with_logits(logits, targets): """Sigmoid cross-entropy loss.