diff --git a/mesh_tensorflow/transformer/utils.py b/mesh_tensorflow/transformer/utils.py
index 392dbdcd..c03a7712 100644
--- a/mesh_tensorflow/transformer/utils.py
+++ b/mesh_tensorflow/transformer/utils.py
@@ -491,16 +491,9 @@ def _maybe_detokenize(ids, vocab):
           compute_loss=False,
           mode=mode,
           variable_dtype=get_variable_dtype())
-      batch_dim, length_dim, vocab_dim = logits.shape.dims
-      cross_entropy = mtf.layers.softmax_cross_entropy_with_logits(
-          logits, mtf_features["targets"], vocab_dim)
-      cross_entropy *= mtf.cast(
-          mtf.not_equal(targets, 0), cross_entropy.dtype)
-      if model_type == "delimited_lm":
-        cross_entropy *= mtf.cast(mtf.logical_not(
-            transformer.delimited_lm_inputs_mask(targets)), cross_entropy.dtype)
-      scores = -mtf.reduce_sum(cross_entropy, reduced_dim=length_dim)
-      scores = mtf.anonymize(scores)
+
+      # Compute the log likelihood of the targets.
+      scores = compute_scores(logits, targets, model_type)
       targets = mtf.anonymize(targets)
       lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)
       targets = clean_decodes(lowering.export_to_tf_tensor(targets))
@@ -531,18 +524,39 @@ def _maybe_detokenize(ids, vocab):
             inputs,
             variable_dtype=get_variable_dtype())
       else:
         raise ValueError("unrecognized class")
+
+      # Compute log-likelihood scores for the decoded output texts.
+      # Replace everything after the first EOS (id 1) with 0 (pad).
+      eos_and_after = mtf.cumsum(mtf.cast(mtf.equal(mtf_samples, 1), tf.int32),
+                                 exclusive=True, dim=mtf_samples.shape[1])
+      valid_ids = mtf.equal(eos_and_after, 0)
+      targets_for_score = mtf.where(valid_ids, mtf_samples, 0)
+
+      logits, _ = transformer_model.call_simple(
+          inputs=inputs,
+          targets=targets_for_score,
+          compute_loss=False,
+          mode="score",
+          variable_dtype=get_variable_dtype())
+
+      # Compute the log likelihood of the sampled targets.
+      scores = compute_scores(logits, targets_for_score, model_type)
+
      mtf_samples = mtf.anonymize(mtf_samples)
      inputs = mtf.anonymize(inputs)
      lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)
      inputs = clean_decodes(lowering.export_to_tf_tensor(inputs))
      outputs = clean_decodes(lowering.export_to_tf_tensor(mtf_samples))
+      scores = lowering.export_to_tf_tensor(scores)
      inputs = _maybe_detokenize(inputs, inputs_vocabulary(vocabulary))
      outputs = _maybe_detokenize(outputs, targets_vocabulary(vocabulary))
 
      predictions = {
          "inputs": inputs,
-          "outputs": outputs}
+          "outputs": outputs,
+          "scores": scores
+      }
 
    if mode in ["score", tf.estimator.ModeKeys.PREDICT]:
      # When exporting a model, we need to communicate to TF-Serving that
@@ -1210,6 +1224,33 @@ def clean_decodes(ids, eos_id=1, pad_id=0, length_axis=-1):
   return tf.where_v2(valid_ids, ids, pad_id)
 
 
+def compute_scores(logits, targets, model_type):
+  """Compute the log likelihood given logits and targets.
+
+  Args:
+    logits: a mtf.Tensor with floating-point dtype, containing the predicted
+      relative log probabilities of the classes.
+    targets: a mtf.Tensor with integer dtype whose values are in the range
+      [0, vocab_dim.size).
+    model_type: a string. One of "bitransformer", "lm", "delimited_lm",
+      "aligned", or "bi_teacher_student".
+
+  Returns:
+    a float mtf.Tensor with the per-example log likelihood.
+  """
+  batch_dim, length_dim, vocab_dim = logits.shape.dims
+  cross_entropy = mtf.layers.softmax_cross_entropy_with_logits(
+      logits, targets, vocab_dim)
+  cross_entropy *= mtf.cast(
+      mtf.not_equal(targets, 0), cross_entropy.dtype)
+  if model_type == "delimited_lm":
+    cross_entropy *= mtf.cast(mtf.logical_not(
+        transformer.delimited_lm_inputs_mask(targets)), cross_entropy.dtype)
+  scores = -mtf.reduce_sum(cross_entropy, reduced_dim=length_dim)
+  scores = mtf.anonymize(scores)
+  return scores
+
+
 def _score_with_estimator(estimator, input_fn, eval_checkpoint_step, model_dir,
                           scores_filename, num_examples=None):
   """For each example returned by input_fn, compute log likelihood.