From 7a96222922f00ebf33e9d100605c269b1d960fb3 Mon Sep 17 00:00:00 2001
From: Allen Qin
Date: Thu, 20 Aug 2020 11:38:55 +0000
Subject: [PATCH 1/4] added probabilities for generated text in inference mode

---
 mesh_tensorflow/transformer/utils.py | 40 ++++++++++++++++++++++++++--
 1 file changed, 38 insertions(+), 2 deletions(-)

diff --git a/mesh_tensorflow/transformer/utils.py b/mesh_tensorflow/transformer/utils.py
index e414c34d..babc22ed 100644
--- a/mesh_tensorflow/transformer/utils.py
+++ b/mesh_tensorflow/transformer/utils.py
@@ -518,11 +518,44 @@ def _verify_feature_exists(feature_name, should_exist):
             inputs, variable_dtype=get_variable_dtype())
       else:
         raise ValueError("unrecognized class")
+
+      # calculate probabilities for the output texts
+      # Replaces everything after EOS with 0 (along last dim).
+      eos_and_after = mtf.cumsum(mtf.cast(mtf.equal(mtf_samples, 1), tf.int32),
+                                 exclusive=True, dim=mtf_samples.shape[1])
+      valid_ids = mtf.equal(eos_and_after, 0)
+      targets_for_score = mtf.where(valid_ids, mtf_samples, 0)
+
+      logits, _ = transformer_model.call_simple(
+          inputs=inputs,
+          targets=targets_for_score,
+          compute_loss=False,
+          mode='score',
+          variable_dtype=get_variable_dtype())
+
+      # calculate log probability
+      targets = mtf_features["targets"] = targets_for_score
+
+      batch_dim, length_dim, vocab_dim = logits.shape.dims
+      cross_entropy = mtf.layers.softmax_cross_entropy_with_logits(
+          logits, mtf_features["targets"], vocab_dim)
+      cross_entropy *= mtf.cast(
+          mtf.not_equal(targets, 0), cross_entropy.dtype)
+      if mode == "delimited_lm":
+        cross_entropy *= mtf.cast(mtf.logical_not(
+            transformer.delimited_lm_inputs_mask(targets)), cross_entropy.dtype)
+      scores = -mtf.reduce_sum(cross_entropy, reduced_dim=length_dim)
+
+      # convert log prob to prob
+      probabilities = mtf.exp(scores)
+      probabilities = mtf.anonymize(probabilities)
+
       mtf_samples = mtf.anonymize(mtf_samples)
       inputs = mtf.anonymize(inputs)
       lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)
       inputs = clean_decodes(lowering.export_to_tf_tensor(inputs))
       outputs = clean_decodes(lowering.export_to_tf_tensor(mtf_samples))
+      probabilities = lowering.export_to_tf_tensor(probabilities)
 
       # Detokenize in the graph if supported by vocabulary and accelerator.
       def _maybe_detokenize(ids, vocab):
@@ -535,7 +568,9 @@ def _maybe_detokenize(ids, vocab):
       predictions = {
           "inputs": inputs,
-          "outputs": outputs}
+          "outputs": outputs,
+          "probabilities": probabilities
+      }
 
   if mode in ["score", tf.estimator.ModeKeys.PREDICT]:
     # When exporting a model, we need to communicate to TF-Serving that
     # master variables need to be copied to their slave slice variables.
@@ -1203,6 +1238,7 @@ def clean_decodes(ids, eos_id=1, pad_id=0, length_axis=-1):
   return tf.where_v2(valid_ids, ids, pad_id)
 
 
+
 def _score_with_estimator(estimator, input_fn, eval_checkpoint_step, model_dir,
                           scores_filename, num_examples=None):
   """For each example returned by input_fn, compute log likelihood.
@@ -2217,4 +2253,4 @@ def _input_fn(params, eval_dataset):
   else:
     raise ValueError(
         "unknown mode %s - must be train/perplexity_eval/eval/infer/export"
-        % mode)
+        % mode)
\ No newline at end of file

From 326c8e4dc686fd708939f56f7dfdbc60e1183b2d Mon Sep 17 00:00:00 2001
From: Allen Qin
Date: Mon, 31 Aug 2020 11:05:32 +0000
Subject: [PATCH 2/4] change probabilities to scores(log likelihood) in the
 output to be consistent with other scores. create a compute_score function
 to remove duplicate code.

---
 mesh_tensorflow/transformer/utils.py | 63 +++++++++++++++-------------
 1 file changed, 34 insertions(+), 29 deletions(-)

diff --git a/mesh_tensorflow/transformer/utils.py b/mesh_tensorflow/transformer/utils.py
index babc22ed..477eacbd 100644
--- a/mesh_tensorflow/transformer/utils.py
+++ b/mesh_tensorflow/transformer/utils.py
@@ -483,16 +483,9 @@ def _verify_feature_exists(feature_name, should_exist):
           compute_loss=False,
           mode=mode,
           variable_dtype=get_variable_dtype())
-      batch_dim, length_dim, vocab_dim = logits.shape.dims
-      cross_entropy = mtf.layers.softmax_cross_entropy_with_logits(
-          logits, mtf_features["targets"], vocab_dim)
-      cross_entropy *= mtf.cast(
-          mtf.not_equal(targets, 0), cross_entropy.dtype)
-      if model_type == "delimited_lm":
-        cross_entropy *= mtf.cast(mtf.logical_not(
-            transformer.delimited_lm_inputs_mask(targets)), cross_entropy.dtype)
-      scores = -mtf.reduce_sum(cross_entropy, reduced_dim=length_dim)
-      scores = mtf.anonymize(scores)
+
+      # calculate log likelihood
+      scores = compute_score(logits, targets, model_type)
       lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)
       predictions = {
           "scores": lowering.export_to_tf_tensor(scores)
       }
@@ -533,29 +526,15 @@ def _verify_feature_exists(feature_name, should_exist):
           mode='score',
           variable_dtype=get_variable_dtype())
 
-      # calculate log probability
-      targets = mtf_features["targets"] = targets_for_score
-
-      batch_dim, length_dim, vocab_dim = logits.shape.dims
-      cross_entropy = mtf.layers.softmax_cross_entropy_with_logits(
-          logits, mtf_features["targets"], vocab_dim)
-      cross_entropy *= mtf.cast(
-          mtf.not_equal(targets, 0), cross_entropy.dtype)
-      if mode == "delimited_lm":
-        cross_entropy *= mtf.cast(mtf.logical_not(
-            transformer.delimited_lm_inputs_mask(targets)), cross_entropy.dtype)
-      scores = -mtf.reduce_sum(cross_entropy, reduced_dim=length_dim)
-
-      # convert log prob to prob
-      probabilities = mtf.exp(scores)
-      probabilities = mtf.anonymize(probabilities)
+      # calculate log likelihood
+      scores = compute_score(logits, targets_for_score, model_type)
 
       mtf_samples = mtf.anonymize(mtf_samples)
       inputs = mtf.anonymize(inputs)
       lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)
       inputs = clean_decodes(lowering.export_to_tf_tensor(inputs))
       outputs = clean_decodes(lowering.export_to_tf_tensor(mtf_samples))
-      probabilities = lowering.export_to_tf_tensor(probabilities)
+      scores = lowering.export_to_tf_tensor(scores)
 
       # Detokenize in the graph if supported by vocabulary and accelerator.
       def _maybe_detokenize(ids, vocab):
@@ -569,9 +548,9 @@ def _maybe_detokenize(ids, vocab):
       predictions = {
           "inputs": inputs,
           "outputs": outputs,
-          "probabilities": probabilities
+          "scores": scores
       }
-
+
   if mode in ["score", tf.estimator.ModeKeys.PREDICT]:
     # When exporting a model, we need to communicate to TF-Serving that
     # master variables need to be copied to their slave slice variables.
@@ -1238,6 +1217,32 @@ def clean_decodes(ids, eos_id=1, pad_id=0, length_axis=-1):
   return tf.where_v2(valid_ids, ids, pad_id)
 
 
+def compute_score(logits, targets, model_type):
+  """Compute the log likelihood given logits and targets.
+
+  Args:
+    logits: A mtf Tensor with floating-point dtype, containing the predicted
+      relative log probabilities of the classes.
+    targets: A mtf Tensor with integer dtype whose values are in the range
+      [0, vocab_dim.size).
+    model_type: a string. One of "bitransformer", "lm", "delimited_lm",
+      "aligned", or "bi_teacher_student"
+
+  Returns:
+    a float mtf.Tensor with the log likelihood.
+ """ + batch_dim, length_dim, vocab_dim = logits.shape.dims + cross_entropy = mtf.layers.softmax_cross_entropy_with_logits( + logits, targets, vocab_dim) + cross_entropy *= mtf.cast( + mtf.not_equal(targets, 0), cross_entropy.dtype) + if model_type == "delimited_lm": + cross_entropy *= mtf.cast(mtf.logical_not( + transformer.delimited_lm_inputs_mask(targets)), cross_entropy.dtype) + scores = -mtf.reduce_sum(cross_entropy, reduced_dim=length_dim) + scores = mtf.anonymize(scores) + return scores + def _score_with_estimator(estimator, input_fn, eval_checkpoint_step, model_dir, scores_filename, num_examples=None): From 7161da2313974aac11b7e23de55207d16ffce6e7 Mon Sep 17 00:00:00 2001 From: allen-q Date: Tue, 8 Sep 2020 20:04:40 +1000 Subject: [PATCH 3/4] update function compute_score to compute_scores --- mesh_tensorflow/transformer/utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mesh_tensorflow/transformer/utils.py b/mesh_tensorflow/transformer/utils.py index 477eacbd..be18c04d 100644 --- a/mesh_tensorflow/transformer/utils.py +++ b/mesh_tensorflow/transformer/utils.py @@ -485,7 +485,7 @@ def _verify_feature_exists(feature_name, should_exist): variable_dtype=get_variable_dtype()) # calculate log likelihood - scores = compute_score(logits, targets, model_type) + scores = compute_scores(logits, targets, model_type) lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack) predictions = { "scores": lowering.export_to_tf_tensor(scores) @@ -527,7 +527,7 @@ def _verify_feature_exists(feature_name, should_exist): variable_dtype=get_variable_dtype()) # calculate log likelihood - scores = compute_score(logits, targets_for_score, model_type) + scores = compute_scores(logits, targets_for_score, model_type) mtf_samples = mtf.anonymize(mtf_samples) inputs = mtf.anonymize(inputs) @@ -1217,7 +1217,7 @@ def clean_decodes(ids, eos_id=1, pad_id=0, length_axis=-1): return tf.where_v2(valid_ids, ids, pad_id) -def compute_score(logits, targets, model_type): +def compute_scores(logits, targets, model_type): """Compute the log likelihood given logits and targets. Args: @@ -2258,4 +2258,4 @@ def _input_fn(params, eval_dataset): else: raise ValueError( "unknown mode %s - must be train/perplexity_eval/eval/infer/export" - % mode) \ No newline at end of file + % mode) From 228fc925ab8cf95628c57135de5ea1366c82f32e Mon Sep 17 00:00:00 2001 From: allen-q Date: Sat, 12 Sep 2020 21:32:42 +1000 Subject: [PATCH 4/4] Resolve merge conflict. 
---
 mesh_tensorflow/transformer/utils.py | 23 ++++++++++++++---------
 1 file changed, 14 insertions(+), 9 deletions(-)

diff --git a/mesh_tensorflow/transformer/utils.py b/mesh_tensorflow/transformer/utils.py
index be18c04d..fa7e83f5 100644
--- a/mesh_tensorflow/transformer/utils.py
+++ b/mesh_tensorflow/transformer/utils.py
@@ -483,11 +483,16 @@ def _verify_feature_exists(feature_name, should_exist):
           compute_loss=False,
           mode=mode,
           variable_dtype=get_variable_dtype())
-
+
       # calculate log likelihood
       scores = compute_scores(logits, targets, model_type)
+      targets = mtf.anonymize(targets)
       lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)
+      targets = clean_decodes(lowering.export_to_tf_tensor(targets))
+      targets = _maybe_detokenize(targets, targets_vocabulary(vocabulary))
+
       predictions = {
+          "targets": targets,
           "scores": lowering.export_to_tf_tensor(scores)
       }
     elif mode == tf.estimator.ModeKeys.PREDICT:
@@ -511,10 +516,10 @@ def _verify_feature_exists(feature_name, should_exist):
             inputs, variable_dtype=get_variable_dtype())
       else:
         raise ValueError("unrecognized class")
-
-      # calculate probabilities for the output texts
+
+      # calculate probabilities for the output texts
       # Replaces everything after EOS with 0 (along last dim).
-      eos_and_after = mtf.cumsum(mtf.cast(mtf.equal(mtf_samples, 1), tf.int32),
+      eos_and_after = mtf.cumsum(mtf.cast(mtf.equal(mtf_samples, 1), tf.int32),
                                  exclusive=True, dim=mtf_samples.shape[1])
       valid_ids = mtf.equal(eos_and_after, 0)
       targets_for_score = mtf.where(valid_ids, mtf_samples, 0)
@@ -528,7 +533,7 @@ def _verify_feature_exists(feature_name, should_exist):
 
       # calculate log likelihood
       scores = compute_scores(logits, targets_for_score, model_type)
-
+
       mtf_samples = mtf.anonymize(mtf_samples)
       inputs = mtf.anonymize(inputs)
       lowering = mtf.Lowering(graph, {mesh: mesh_impl}, autostack=autostack)
@@ -550,7 +555,7 @@ def _maybe_detokenize(ids, vocab):
           "outputs": outputs,
           "scores": scores
       }
-
+
   if mode in ["score", tf.estimator.ModeKeys.PREDICT]:
     # When exporting a model, we need to communicate to TF-Serving that
     # master variables need to be copied to their slave slice variables.
@@ -1219,11 +1224,11 @@ def clean_decodes(ids, eos_id=1, pad_id=0, length_axis=-1):
 
 def compute_scores(logits, targets, model_type):
   """Compute the log likelihood given logits and targets.
-
+
   Args:
     logits: A mtf Tensor with floating-point dtype, containing the predicted
       relative log probabilities of the classes.
-    targets: A mtf Tensor with integer dtype whose values are in the range
+    targets: A mtf Tensor with integer dtype whose values are in the range
       [0, vocab_dim.size).
     model_type: a string. One of "bitransformer", "lm", "delimited_lm",
       "aligned", or "bi_teacher_student"
@@ -1242,7 +1247,7 @@ def compute_scores(logits, targets, model_type):
   scores = -mtf.reduce_sum(cross_entropy, reduced_dim=length_dim)
   scores = mtf.anonymize(scores)
   return scores
-
+
 def _score_with_estimator(estimator, input_fn, eval_checkpoint_step, model_dir,
                           scores_filename, num_examples=None):
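
Note on the scoring math (a hedged sketch, not part of the patches above): compute_scores returns a per-example log likelihood, i.e. the negative sum of token-level softmax cross-entropies, with padding (id 0) masked out and, for "delimited_lm", the input positions masked as well; PATCH 1 originally exported exp(score) as "probabilities" before PATCH 2 switched the output to the raw scores. The NumPy sketch below only illustrates that relationship on a toy example; the helper name sequence_log_likelihood and the toy arrays are illustrative assumptions, and the delimited_lm input masking is omitted.

# Illustrative only: mirrors the masking-and-sum behaviour of compute_scores and
# the exp(score) -> probability conversion from PATCH 1. Not code from this PR.
import numpy as np

def sequence_log_likelihood(logits, targets, pad_id=0):
  # log-softmax over the vocabulary axis
  log_probs = logits - np.log(np.sum(np.exp(logits), axis=-1, keepdims=True))
  # pick the log-probability assigned to each target token
  token_log_probs = np.take_along_axis(
      log_probs, targets[:, None], axis=-1)[:, 0]
  # ignore padding positions, as compute_scores does via mtf.not_equal(targets, 0)
  mask = (targets != pad_id).astype(token_log_probs.dtype)
  return float(np.sum(token_log_probs * mask))

toy_logits = np.array([[2.0, 0.5, 0.1],   # one decoded example, 3 steps,
                       [0.2, 1.5, 0.3],   # vocabulary of size 3
                       [0.0, 0.0, 0.0]])
toy_targets = np.array([1, 2, 0])          # trailing 0 is padding, ignored
score = sequence_log_likelihood(toy_logits, toy_targets)  # the "scores" value
probability = np.exp(score)                # what PATCH 1 exported before PATCH 2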