@@ -483,16 +483,9 @@ def _verify_feature_exists(feature_name, should_exist):
           compute_loss=False,
           mode=mode,
           variable_dtype=get_variable_dtype())
-      batch_dim, length_dim, vocab_dim = logits.shape.dims
-      cross_entropy = mtf.layers.softmax_cross_entropy_with_logits(
-          logits, mtf_features["targets"], vocab_dim)
-      cross_entropy *= mtf.cast(
-          mtf.not_equal(targets, 0), cross_entropy.dtype)
-      if model_type == "delimited_lm":
-        cross_entropy *= mtf.cast(mtf.logical_not(
-            transformer.delimited_lm_inputs_mask(targets)), cross_entropy.dtype)
-      scores = -mtf.reduce_sum(cross_entropy, reduced_dim=length_dim)
-      scores = mtf.anonymize(scores)
+
+      # calculate log likelihood
+      scores = compute_score(logits, targets, model_type)
       lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)
       predictions = {
           "scores": lowering.export_to_tf_tensor(scores)
@@ -533,29 +526,15 @@ def _verify_feature_exists(feature_name, should_exist):
           mode='score',
           variable_dtype=get_variable_dtype())
 
-      # calculate log probability
-      targets = mtf_features["targets"] = targets_for_score
-
-      batch_dim, length_dim, vocab_dim = logits.shape.dims
-      cross_entropy = mtf.layers.softmax_cross_entropy_with_logits(
-          logits, mtf_features["targets"], vocab_dim)
-      cross_entropy *= mtf.cast(
-          mtf.not_equal(targets, 0), cross_entropy.dtype)
-      if mode == "delimited_lm":
-        cross_entropy *= mtf.cast(mtf.logical_not(
-            transformer.delimited_lm_inputs_mask(targets)), cross_entropy.dtype)
-      scores = -mtf.reduce_sum(cross_entropy, reduced_dim=length_dim)
-
-      # convert log prob to prob
-      probabilities = mtf.exp(scores)
-      probabilities = mtf.anonymize(probabilities)
+      # calculate log likelihood
+      scores = compute_score(logits, targets_for_score, model_type)
 
       mtf_samples = mtf.anonymize(mtf_samples)
       inputs = mtf.anonymize(inputs)
       lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)
       inputs = clean_decodes(lowering.export_to_tf_tensor(inputs))
       outputs = clean_decodes(lowering.export_to_tf_tensor(mtf_samples))
-      probabilities = lowering.export_to_tf_tensor(probabilities)
+      scores = lowering.export_to_tf_tensor(scores)
 
       # Detokenize in the graph if supported by vocabulary and accelerator.
       def _maybe_detokenize(ids, vocab):
@@ -569,9 +548,9 @@ def _maybe_detokenize(ids, vocab):
       predictions = {
           "inputs": inputs,
           "outputs": outputs,
-          "probabilities": probabilities
+          "scores": scores
       }
-
+
       if mode in ["score", tf.estimator.ModeKeys.PREDICT]:
         # When exporting a model, we need to communicate to TF-Serving that
         # master variables need to be copied to their slave slice variables.
@@ -1238,6 +1217,32 @@ def clean_decodes(ids, eos_id=1, pad_id=0, length_axis=-1):
   return tf.where_v2(valid_ids, ids, pad_id)
 
 
+def compute_score(logits, targets, model_type):
+  """Compute the log likelihood given logits and targets.
+
+  Args:
+    logits: A mtf Tensor with floating-point dtype, containing the predicted
+      relative log probabilities of the classes.
+    targets: A mtf Tensor with integer dtype whose values are in the range
+      [0, vocab_dim.size).
+    model_type: a string. One of "bitransformer", "lm", "delimited_lm",
+      "aligned", or "bi_teacher_student"
+
+  Returns:
+    a float mtf.Tensor with the log likelihood.
+  """
+  batch_dim, length_dim, vocab_dim = logits.shape.dims
+  cross_entropy = mtf.layers.softmax_cross_entropy_with_logits(
+      logits, targets, vocab_dim)
+  cross_entropy *= mtf.cast(
+      mtf.not_equal(targets, 0), cross_entropy.dtype)
+  if model_type == "delimited_lm":
+    cross_entropy *= mtf.cast(mtf.logical_not(
+        transformer.delimited_lm_inputs_mask(targets)), cross_entropy.dtype)
+  scores = -mtf.reduce_sum(cross_entropy, reduced_dim=length_dim)
+  scores = mtf.anonymize(scores)
+  return scores
+
 
 def _score_with_estimator(estimator, input_fn, eval_checkpoint_step, model_dir,
                           scores_filename, num_examples=None):
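For readers checking the masking semantics, here is a minimal NumPy sketch of the arithmetic compute_score performs in the default (non-delimited_lm) case: a log-softmax over the vocab dimension, per-token cross entropy at the target indices, a mask that zeroes out padding positions (pad id 0), and a negated sum over the length dimension. The names and shapes are illustrative only, not part of this change.

import numpy as np

def compute_score_sketch(logits, targets):
  """Sequence log likelihood: negated sum of masked per-token cross entropy."""
  # Numerically stable log-softmax over the vocab dimension.
  shifted = logits - logits.max(axis=-1, keepdims=True)
  log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
  # Cross entropy of each target token: -log p(target).
  batch_idx = np.arange(logits.shape[0])[:, None]
  length_idx = np.arange(logits.shape[1])[None, :]
  cross_entropy = -log_probs[batch_idx, length_idx, targets]
  # Zero out padding positions, mirroring mtf.not_equal(targets, 0).
  cross_entropy *= (targets != 0)
  # Sum over the length dimension and negate, as in compute_score.
  return -cross_entropy.sum(axis=-1)

# One example with two real tokens and one pad; prints one non-positive score.
logits = np.random.randn(1, 3, 5)
targets = np.array([[2, 4, 0]])
print(compute_score_sketch(logits, targets))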
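Note also that the exported prediction key changes from "probabilities" to "scores": the removed probabilities = mtf.exp(scores) step means the predict path now emits log likelihoods directly. A consumer that previously read probabilities can recover them outside the graph; a one-line sketch, assuming predictions is one element yielded by estimator.predict(...):

probabilities = np.exp(predictions["scores"])  # inverse of the removed mtf.exp step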