This repository was archived by the owner on Jan 21, 2025. It is now read-only.

Commit a3d42a3

Author: allen-qin

Change likelihood to log likelihood in the output to be consistent with other scores. Create a compute_score function to remove duplicate code.

1 parent: 7a96222

File tree: 1 file changed, +34 -29 lines

mesh_tensorflow/transformer/utils.py (+34 -29)

@@ -483,16 +483,9 @@ def _verify_feature_exists(feature_name, should_exist):
         compute_loss=False,
         mode=mode,
         variable_dtype=get_variable_dtype())
-    batch_dim, length_dim, vocab_dim = logits.shape.dims
-    cross_entropy = mtf.layers.softmax_cross_entropy_with_logits(
-        logits, mtf_features["targets"], vocab_dim)
-    cross_entropy *= mtf.cast(
-        mtf.not_equal(targets, 0), cross_entropy.dtype)
-    if model_type == "delimited_lm":
-      cross_entropy *= mtf.cast(mtf.logical_not(
-          transformer.delimited_lm_inputs_mask(targets)), cross_entropy.dtype)
-    scores = -mtf.reduce_sum(cross_entropy, reduced_dim=length_dim)
-    scores = mtf.anonymize(scores)
+
+    # calculate log likelihood
+    scores = compute_score(logits, targets, model_type)
     lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)
     predictions = {
         "scores": lowering.export_to_tf_tensor(scores)
@@ -533,29 +526,15 @@ def _verify_feature_exists(feature_name, should_exist):
         mode='score',
         variable_dtype=get_variable_dtype())
 
-    # calculate log probability
-    targets = mtf_features["targets"] = targets_for_score
-
-    batch_dim, length_dim, vocab_dim = logits.shape.dims
-    cross_entropy = mtf.layers.softmax_cross_entropy_with_logits(
-        logits, mtf_features["targets"], vocab_dim)
-    cross_entropy *= mtf.cast(
-        mtf.not_equal(targets, 0), cross_entropy.dtype)
-    if mode == "delimited_lm":
-      cross_entropy *= mtf.cast(mtf.logical_not(
-          transformer.delimited_lm_inputs_mask(targets)), cross_entropy.dtype)
-    scores = -mtf.reduce_sum(cross_entropy, reduced_dim=length_dim)
-
-    # convert log prob to prob
-    probabilities = mtf.exp(scores)
-    probabilities = mtf.anonymize(probabilities)
+    # calculate log likelihood
+    scores = compute_score(logits, targets_for_score, model_type)
 
     mtf_samples = mtf.anonymize(mtf_samples)
     inputs = mtf.anonymize(inputs)
     lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)
     inputs = clean_decodes(lowering.export_to_tf_tensor(inputs))
     outputs = clean_decodes(lowering.export_to_tf_tensor(mtf_samples))
-    probabilities = lowering.export_to_tf_tensor(probabilities)
+    scores = lowering.export_to_tf_tensor(scores)
 
     # Detokenize in the graph if supported by vocabulary and accelerator.
     def _maybe_detokenize(ids, vocab):
@@ -569,9 +548,9 @@ def _maybe_detokenize(ids, vocab):
     predictions = {
         "inputs": inputs,
         "outputs": outputs,
-        "probabilities": probabilities
+        "scores": scores
     }
-
+
     if mode in ["score", tf.estimator.ModeKeys.PREDICT]:
       # When exporting a model, we need to communicate to TF-Serving that
       # master variables need to be copied to their slave slice variables.
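
This is the user-visible part of the change: the predict path used to export "probabilities" (the exponentiated log likelihood) and now exports "scores" (the log likelihood itself), matching the score path above. A consumer that still wants a likelihood can exponentiate the score itself; a hedged sketch, where `estimator` and `input_fn` stand in for whatever estimator and input function the caller already has:

```python
import numpy as np

for prediction in estimator.predict(input_fn):
  log_likelihood = prediction["scores"]   # new output
  probability = np.exp(log_likelihood)    # equivalent of the old "probabilities"
```
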
@@ -1238,6 +1217,32 @@ def clean_decodes(ids, eos_id=1, pad_id=0, length_axis=-1):
   return tf.where_v2(valid_ids, ids, pad_id)
 
 
+def compute_score(logits, targets, model_type):
+  """Compute the log likelihood given logits and targets.
+
+  Args:
+    logits: A mtf Tensor with floating-point dtype, containing the predicted
+      relative log probabilities of the classes.
+    targets: A mtf Tensor with integer dtype whose values are in the range
+      [0, vocab_dim.size).
+    model_type: a string. One of "bitransformer", "lm", "delimited_lm",
+      "aligned", or "bi_teacher_student"
+
+  Returns:
+    a float mtf.Tensor with the log likelihood.
+  """
+  batch_dim, length_dim, vocab_dim = logits.shape.dims
+  cross_entropy = mtf.layers.softmax_cross_entropy_with_logits(
+      logits, targets, vocab_dim)
+  cross_entropy *= mtf.cast(
+      mtf.not_equal(targets, 0), cross_entropy.dtype)
+  if model_type == "delimited_lm":
+    cross_entropy *= mtf.cast(mtf.logical_not(
+        transformer.delimited_lm_inputs_mask(targets)), cross_entropy.dtype)
+  scores = -mtf.reduce_sum(cross_entropy, reduced_dim=length_dim)
+  scores = mtf.anonymize(scores)
+  return scores
+
 
 def _score_with_estimator(estimator, input_fn, eval_checkpoint_step, model_dir,
                           scores_filename, num_examples=None):
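
Outside the estimator, the new helper could be exercised directly on small Mesh TensorFlow tensors. The sketch below is an assumption-laden illustration, not part of the commit: the dimension sizes, toy values, "lm" model type, and the single-device PlacementMeshImpl setup are all chosen just to show the call and the shape of the result.

```python
import numpy as np
import tensorflow.compat.v1 as tf
import mesh_tensorflow as mtf
from mesh_tensorflow.transformer import utils

tf.disable_v2_behavior()

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "score_mesh")

batch_dim = mtf.Dimension("batch", 2)
length_dim = mtf.Dimension("length", 4)
vocab_dim = mtf.Dimension("vocab", 8)

# toy logits and targets; pad id 0 positions do not contribute to the score
logits_tf = tf.constant(np.random.randn(2, 4, 8), dtype=tf.float32)
targets_tf = tf.constant([[3, 5, 1, 0], [2, 1, 0, 0]], dtype=tf.int32)

logits = mtf.import_tf_tensor(
    mesh, logits_tf, shape=mtf.Shape([batch_dim, length_dim, vocab_dim]))
targets = mtf.import_tf_tensor(
    mesh, targets_tf, shape=mtf.Shape([batch_dim, length_dim]))

# one log likelihood per example
scores = utils.compute_score(logits, targets, model_type="lm")

# lower on a single device and evaluate
mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
    shape=[], layout={}, devices=[""])
lowering = mtf.Lowering(graph, {mesh: mesh_impl})
with tf.Session() as sess:
  print(sess.run(lowering.export_to_tf_tensor(scores)))  # shape [2]
```
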
