This repository was archived by the owner on Jan 21, 2025. It is now read-only.

Add loss functions for multiple-target objectives for distillation. #291

Open · wants to merge 1 commit into base: master
99 changes: 99 additions & 0 deletions mesh_tensorflow/layers.py
@@ -1101,6 +1101,105 @@ def softmax_cross_entropy_with_logits(logits, targets, vocab_dim, z_loss=0.0):
  return loss


def kl_divergence(y_true, y_pred, reduced_dim, weights=None, epsilon=1e-6):
"""Kullback-Leibler-Divergence between `y_true` and `y_pred`.

Computes: `loss = y_true * log(y_true / y_pred)`
From: tf.keras.losses.KLDivergence (Custom implementation with mtf)
See: https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence

Args:
y_true: mtf.Tensor, target predictions (distribution).
y_pred: mtf.Tensor, actual predictions (distribution).
reduced_dim: mtf.Dimension, reduction dimension for sum.
weights: Optional mtf.Tensor, indicator for padded regions.
epsilon: float, minimum value for numerical stability.
Returns:
scalar: K-L Divergence loss.
Raises:
ValueError: if the shapes do not match or reduced_dim is not valid.
"""
if set(y_true.shape.dims) != set(y_pred.shape.dims):
raise ValueError(
"`y_true` and `y_pred` must be of the same shape. "
f"Currently they are {y_true.shape.dims} and {y_pred.shape.dims}")
if reduced_dim not in y_true.shape.dims:
raise ValueError(
f"`reduced_dim` must be a valid dimension (from {y_true.shape.dims}).")
if weights is None:
weights = 1.

def _clip(x, min_value, max_value):
# Clip values for numerical stability.
x = mtf.maximum(x, min_value)
x = mtf.minimum(x, max_value)
return x

y_true = _clip(y_true, epsilon, 1.)
y_pred = _clip(y_pred, epsilon, 1.)
return mtf.reduce_sum(weights * y_true * mtf.log(y_true / y_pred))
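

# A minimal usage sketch for `kl_divergence` (an illustration, not part of
# this diff): the function expects probability distributions, so normalize
# raw logits over the vocabulary dimension first. The helper name and all
# argument names below are hypothetical mtf.Tensors / mtf.Dimensions.
def _example_kl_distillation(teacher_logits, student_logits, vocab_dim):
  teacher_probs = mtf.softmax(teacher_logits, vocab_dim)
  student_probs = mtf.softmax(student_logits, vocab_dim)
  return kl_divergence(teacher_probs, student_probs, reduced_dim=vocab_dim)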


def mean_squared_error(y_true, y_pred, weights=None):
"""L2-Loss between `y_true` and `y_pred`.

Args:
y_true: mtf.Tensor, target logits.
y_pred: mtf.Tensor, actual logits.
weights: Optional mtf.Tensor, indicator for padded regions.
Returns:
scalar: L2 loss.
Raises:
ValueError: if the shapes do not match or reduced_dim is not valid.
"""
if set(y_true.shape.dims) != set(y_pred.shape.dims):
raise ValueError(
"`y_true` and `y_pred` must be of the same shape. "
f"Currently they are {y_true.shape.dims} and {y_pred.shape.dims}")
if weights is None:
weights = 1.
return mtf.reduce_sum(weights * mtf.square(y_true - y_pred))
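

# A brief hedged sketch (illustration only): match student activations to
# teacher activations while zeroing out padded positions. `pad_mask` is a
# hypothetical mtf.Tensor holding 1.0 on real tokens and 0.0 on padding;
# mtf's elementwise multiply broadcasts it against the remaining dimensions.
def _example_hidden_state_matching(teacher_hidden, student_hidden, pad_mask):
  return mean_squared_error(teacher_hidden, student_hidden, weights=pad_mask)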


def cosine_embedding_distill(y_true, y_pred, reduced_dim, weights=None,
                             epsilon=1e-6):
  """Cosine embedding loss for distillation from teacher to student logits.

  See: https://arxiv.org/abs/1910.01108 (DistilBERT) and
  https://github.com/huggingface/transformers/tree/master/examples/research_projects/distillation

  Args:
    y_true: mtf.Tensor, teacher logits.
    y_pred: mtf.Tensor, student logits.
    reduced_dim: mtf.Dimension, reduction dimension for the sum.
    weights: Optional mtf.Tensor, indicator for padded regions.
    epsilon: float, for numerical stability.
  Returns:
    scalar: cosine embedding loss (a weighted sum of one minus the cosine
      similarity over all non-reduced dimensions).
  Raises:
    ValueError: if the shapes do not match or `reduced_dim` is not valid.
  """
  if set(y_true.shape.dims) != set(y_pred.shape.dims):
    raise ValueError(
        "`y_true` and `y_pred` must be of the same shape. "
        f"Currently they are {y_true.shape.dims} and {y_pred.shape.dims}")
  if reduced_dim not in y_true.shape.dims:
    raise ValueError(
        f"`reduced_dim` must be a valid dimension (from {y_true.shape.dims}).")
  if weights is None:
    weights = 1.

  # Cosine similarity along `reduced_dim`: <y_true, y_pred> divided by the
  # product of the norms, with `epsilon` guarding against division by zero.
  prod_sum = mtf.reduce_sum(y_true * y_pred, reduced_dim=reduced_dim)
  y_true_sq_sum = mtf.reduce_sum(y_true * y_true, reduced_dim=reduced_dim)
  y_pred_sq_sum = mtf.reduce_sum(y_pred * y_pred, reduced_dim=reduced_dim)
  inv_denom = mtf.rsqrt(y_true_sq_sum * y_pred_sq_sum + epsilon)
  cos = prod_sum * inv_denom
  # TODO(vinaysrao): Turn this into a more general cosine embedding loss with
  # a `targets` tensor.
  return mtf.reduce_sum(weights * (1. - cos))
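

# A hedged sketch of a combined multiple-target distillation objective in the
# spirit of DistilBERT (https://arxiv.org/abs/1910.01108): a KL term on the
# output distributions, an L2 term on the raw logits, and a cosine term
# aligning teacher and student directions. The helper, its argument names,
# and the default 1.0 weights are hypothetical illustrations, not part of
# this diff.
def _example_multi_target_distillation(teacher_logits, student_logits,
                                       vocab_dim, pad_mask,
                                       alpha=1.0, beta=1.0, gamma=1.0):
  kl = kl_divergence(mtf.softmax(teacher_logits, vocab_dim),
                     mtf.softmax(student_logits, vocab_dim),
                     reduced_dim=vocab_dim, weights=pad_mask)
  l2 = mean_squared_error(teacher_logits, student_logits, weights=pad_mask)
  cos_loss = cosine_embedding_distill(teacher_logits, student_logits,
                                      reduced_dim=vocab_dim, weights=pad_mask)
  return alpha * kl + beta * l2 + gamma * cos_loss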


def sigmoid_cross_entropy_with_logits(logits, targets):
"""Sigmoid cross-entropy loss.
